Merge branch 'upstream/master' into addsub_16x8
Change-Id: I51baab7d117fb4656fa09f4b9ff78ef34cbbe8a0
This commit is contained in:
commit
4180f945a7
104
.bazelrc
104
.bazelrc
@ -46,7 +46,6 @@
|
|||||||
# sycl_asan:
|
# sycl_asan:
|
||||||
# sycl_trisycl:
|
# sycl_trisycl:
|
||||||
# mkl: Enable full mkl support.
|
# mkl: Enable full mkl support.
|
||||||
# mkl_open_source_only: Enable MKL support only using open source MKL libraries.
|
|
||||||
# tensorrt: Enable Tensorrt support.
|
# tensorrt: Enable Tensorrt support.
|
||||||
# ngraph: Enable ngraph support.
|
# ngraph: Enable ngraph support.
|
||||||
# numa: Enable numa using hwloc.
|
# numa: Enable numa using hwloc.
|
||||||
@ -69,6 +68,7 @@
|
|||||||
# rbe_linux_py3: Linux Python 3 RBE config
|
# rbe_linux_py3: Linux Python 3 RBE config
|
||||||
#
|
#
|
||||||
# rbe_win_py37: Windows Python 3.7 RBE config
|
# rbe_win_py37: Windows Python 3.7 RBE config
|
||||||
|
# rbe_win_py38: Windows Python 3.8 RBE config
|
||||||
#
|
#
|
||||||
# tensorflow_testing_rbe_linux: RBE options to use RBE with tensorflow-testing project on linux
|
# tensorflow_testing_rbe_linux: RBE options to use RBE with tensorflow-testing project on linux
|
||||||
# tensorflow_testing_rbe_win: RBE options to use RBE with tensorflow-testing project on windows
|
# tensorflow_testing_rbe_win: RBE options to use RBE with tensorflow-testing project on windows
|
||||||
@ -136,15 +136,9 @@ build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain
|
|||||||
# environment variable "TF_MKL_ROOT" every time before build.
|
# environment variable "TF_MKL_ROOT" every time before build.
|
||||||
build:mkl --define=build_with_mkl=true --define=enable_mkl=true
|
build:mkl --define=build_with_mkl=true --define=enable_mkl=true
|
||||||
build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
|
build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
|
||||||
|
build:mkl --define=build_with_mkl_dnn_v1_only=true
|
||||||
build:mkl -c opt
|
build:mkl -c opt
|
||||||
|
|
||||||
# This config option is used to enable MKL-DNN open source library only,
|
|
||||||
# without depending on MKL binary version.
|
|
||||||
build:mkl_open_source_only --define=build_with_mkl_dnn_only=true
|
|
||||||
build:mkl_open_source_only --define=build_with_mkl_dnn_v1_only=true
|
|
||||||
build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true
|
|
||||||
build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0
|
|
||||||
|
|
||||||
# This config refers to building with CUDA available. It does not necessarily
|
# This config refers to building with CUDA available. It does not necessarily
|
||||||
# mean that we build CUDA op kernels.
|
# mean that we build CUDA op kernels.
|
||||||
build:using_cuda --define=using_cuda=true
|
build:using_cuda --define=using_cuda=true
|
||||||
@ -221,6 +215,11 @@ build --define=grpc_no_ares=true
|
|||||||
# archives in -whole_archive -no_whole_archive.
|
# archives in -whole_archive -no_whole_archive.
|
||||||
build --noincompatible_remove_legacy_whole_archive
|
build --noincompatible_remove_legacy_whole_archive
|
||||||
|
|
||||||
|
# These are bazel 2.0's incompatible flags. Tensorflow needs to use bazel 2.0.0
|
||||||
|
# to use cc_shared_library, as part of the Tensorflow Build Improvements RFC:
|
||||||
|
# https://github.com/tensorflow/community/pull/179
|
||||||
|
build --noincompatible_prohibit_aapt1
|
||||||
|
|
||||||
# Modular TF build options
|
# Modular TF build options
|
||||||
build:dynamic_kernels --define=dynamic_loaded_kernels=true
|
build:dynamic_kernels --define=dynamic_loaded_kernels=true
|
||||||
build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
|
build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
|
||||||
@ -241,6 +240,7 @@ build:windows --copt=/w
|
|||||||
# Tensorflow uses M_* math constants that only get defined by MSVC headers if
|
# Tensorflow uses M_* math constants that only get defined by MSVC headers if
|
||||||
# _USE_MATH_DEFINES is defined.
|
# _USE_MATH_DEFINES is defined.
|
||||||
build:windows --copt=/D_USE_MATH_DEFINES
|
build:windows --copt=/D_USE_MATH_DEFINES
|
||||||
|
build:windows --host_copt=/D_USE_MATH_DEFINES
|
||||||
|
|
||||||
# Default paths for TF_SYSTEM_LIBS
|
# Default paths for TF_SYSTEM_LIBS
|
||||||
build:linux --define=PREFIX=/usr
|
build:linux --define=PREFIX=/usr
|
||||||
@ -313,22 +313,26 @@ build:xla --define=with_xla_support=true
|
|||||||
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS
|
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS
|
||||||
# Options when using remote execution
|
# Options when using remote execution
|
||||||
# WARNING: THESE OPTIONS WONT WORK IF YOU DO NOT HAVE PROPER AUTHENTICATION AND PERMISSIONS
|
# WARNING: THESE OPTIONS WONT WORK IF YOU DO NOT HAVE PROPER AUTHENTICATION AND PERMISSIONS
|
||||||
|
|
||||||
|
# Flag to enable remote config
|
||||||
|
common --experimental_repo_remote_exec
|
||||||
|
|
||||||
build:rbe --action_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
|
build:rbe --action_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
|
||||||
build:rbe --auth_enabled=true
|
build:rbe --google_default_credentials
|
||||||
build:rbe --auth_scope=https://www.googleapis.com/auth/cloud-source-tools
|
|
||||||
build:rbe --bes_backend=buildeventservice.googleapis.com
|
build:rbe --bes_backend=buildeventservice.googleapis.com
|
||||||
build:rbe --bes_best_effort=false
|
|
||||||
build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations"
|
build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations"
|
||||||
build:rbe --bes_timeout=600s
|
build:rbe --bes_timeout=600s
|
||||||
build:rbe --define=EXECUTOR=remote
|
build:rbe --define=EXECUTOR=remote
|
||||||
|
build:rbe --distinct_host_configuration=false
|
||||||
build:rbe --flaky_test_attempts=3
|
build:rbe --flaky_test_attempts=3
|
||||||
build:rbe --jobs=200
|
build:rbe --jobs=200
|
||||||
build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com
|
build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com
|
||||||
build:rbe --remote_timeout=3600
|
build:rbe --remote_timeout=3600
|
||||||
build:rbe --spawn_strategy=remote,worker,standalone,local
|
build:rbe --spawn_strategy=remote,worker,standalone,local
|
||||||
test:rbe --test_env=USER=anon
|
test:rbe --test_env=USER=anon
|
||||||
|
# Attempt to minimize the amount of data transfer between bazel and the remote
|
||||||
build:rbe --distinct_host_configuration=false
|
# workers:
|
||||||
|
build:rbe --remote_download_toplevel
|
||||||
|
|
||||||
build:rbe_linux --config=rbe
|
build:rbe_linux --config=rbe
|
||||||
build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin"
|
build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin"
|
||||||
@ -343,6 +347,7 @@ build:rbe_linux --config=avx_linux
|
|||||||
build:rbe_linux --config=short_logs
|
build:rbe_linux --config=short_logs
|
||||||
# TODO(gunan): Check why we need this specified in rbe, but not in other builds.
|
# TODO(gunan): Check why we need this specified in rbe, but not in other builds.
|
||||||
build:rbe_linux --linkopt=-lrt
|
build:rbe_linux --linkopt=-lrt
|
||||||
|
build:rbe_linux --linkopt=-lm
|
||||||
|
|
||||||
build:rbe_cpu_linux --config=rbe_linux
|
build:rbe_cpu_linux --config=rbe_linux
|
||||||
build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
|
build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
|
||||||
@ -351,21 +356,37 @@ build:rbe_cpu_linux --extra_execution_platforms"=@org_tensorflow//third_party/to
|
|||||||
build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
|
build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
|
||||||
build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
|
build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
|
||||||
|
|
||||||
build:rbe_linux_cuda_nvcc --config=rbe_linux
|
build:rbe_linux_cuda_base --config=rbe_linux
|
||||||
build:rbe_linux_cuda_nvcc --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain"
|
build:rbe_linux_cuda_base --repo_env=TF_NEED_TENSORRT=1
|
||||||
build:rbe_linux_cuda_nvcc --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain-linux-x86_64"
|
build:rbe_linux_cuda_base --repo_env=TF_CUDA_VERSION=10
|
||||||
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010,@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010-gpu"
|
build:rbe_linux_cuda_base --repo_env=TF_CUDNN_VERSION=7
|
||||||
build:rbe_linux_cuda_nvcc --host_platform="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
|
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
|
||||||
build:rbe_linux_cuda_nvcc --platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
|
build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1
|
||||||
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/cuda10.1-cudnn7"
|
test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
|
||||||
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/tensorrt6.0"
|
|
||||||
build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_TENSORRT=1
|
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
|
||||||
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_VERSION=10
|
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||||
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDNN_VERSION=7
|
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||||
build:rbe_linux_cuda_nvcc --repo_env=REMOTE_GPU_TESTING=1
|
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||||
build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_CUDA=1
|
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||||
|
build:rbe_linux_cuda_nvcc --platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||||
|
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||||
|
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||||
|
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||||
build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
|
build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
|
||||||
test:rbe_linux_cuda_nvcc --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
|
test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
|
||||||
|
|
||||||
|
build:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
|
||||||
|
build:rbe_linux_cuda_clang --crosstool_top="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||||
|
build:rbe_linux_cuda_clang --extra_toolchains="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||||
|
build:rbe_linux_cuda_clang --extra_execution_platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||||
|
build:rbe_linux_cuda_clang --host_platform="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||||
|
build:rbe_linux_cuda_clang --platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||||
|
build:rbe_linux_cuda_clang --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||||
|
build:rbe_linux_cuda_clang --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||||
|
build:rbe_linux_cuda_clang --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||||
|
build:rbe_linux_cuda_clang --define=using_cuda_clang=true
|
||||||
|
test:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
|
||||||
|
|
||||||
common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
|
common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
|
||||||
|
|
||||||
@ -375,29 +396,33 @@ build:rbe_linux_py2 --python_path="/usr/bin/python2"
|
|||||||
build:rbe_linux_py2 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py"
|
build:rbe_linux_py2 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py"
|
||||||
|
|
||||||
build:rbe_linux_py3 --config=rbe_linux
|
build:rbe_linux_py3 --config=rbe_linux
|
||||||
build:rbe_linux_py3 --repo_env=PYTHON_BIN_PATH="/usr/bin/python3"
|
|
||||||
build:rbe_linux_py3 --python_path="/usr/bin/python3"
|
build:rbe_linux_py3 --python_path="/usr/bin/python3"
|
||||||
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py3"
|
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-manylinux2010-py3_config_python"
|
||||||
|
|
||||||
build:rbe_win --config=rbe
|
build:rbe_win --config=rbe
|
||||||
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:toolchain"
|
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:toolchain"
|
||||||
build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
|
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:cc-toolchain-x64_windows"
|
||||||
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:cc-toolchain-x64_windows"
|
build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
|
||||||
build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"
|
build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
|
||||||
build:rbe_win --host_platform="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
|
build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"
|
||||||
build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"
|
build:rbe_win --host_platform="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"
|
||||||
build:rbe_win --platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
|
build:rbe_win --platforms="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"
|
||||||
build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
|
build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
|
||||||
|
|
||||||
# TODO(gunan): Remove once we use MSVC 2019 with latest patches.
|
# TODO(gunan): Remove once we use MSVC 2019 with latest patches.
|
||||||
build:rbe_win --define=override_eigen_strong_inline=true
|
build:rbe_win --define=override_eigen_strong_inline=true
|
||||||
|
build:rbe_win --jobs=500
|
||||||
|
|
||||||
build:rbe_win_py37 --config=rbe
|
build:rbe_win_py37 --config=rbe
|
||||||
build:rbe_win_py37 --repo_env=PYTHON_BIN_PATH=C:\\Python37\\python.exe
|
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO="@windows_py37_config_python"
|
||||||
build:rbe_win_py37 --repo_env=PYTHON_LIB_PATH=C:\\Python37\\lib\\site-packages
|
|
||||||
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py37
|
|
||||||
build:rbe_win_py37 --python_path=C:\\Python37\\python.exe
|
build:rbe_win_py37 --python_path=C:\\Python37\\python.exe
|
||||||
|
|
||||||
|
build:rbe_win_py38 --config=rbe
|
||||||
|
build:rbe_win_py38 --repo_env=PYTHON_BIN_PATH=C:\\Python38\\python.exe
|
||||||
|
build:rbe_win_py38 --repo_env=PYTHON_LIB_PATH=C:\\Python38\\lib\\site-packages
|
||||||
|
build:rbe_win_py38 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py38
|
||||||
|
build:rbe_win_py38 --python_path=C:\\Python38\\python.exe
|
||||||
|
|
||||||
# These you may need to change for your own GCP project.
|
# These you may need to change for your own GCP project.
|
||||||
build:tensorflow_testing_rbe --project_id=tensorflow-testing
|
build:tensorflow_testing_rbe --project_id=tensorflow-testing
|
||||||
common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance
|
common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance
|
||||||
@ -407,7 +432,6 @@ build:tensorflow_testing_rbe_linux --config=rbe_linux
|
|||||||
|
|
||||||
common:tensorflow_testing_rbe_win --remote_instance_name=projects/tensorflow-testing/instances/windows
|
common:tensorflow_testing_rbe_win --remote_instance_name=projects/tensorflow-testing/instances/windows
|
||||||
build:tensorflow_testing_rbe_win --config=tensorflow_testing_rbe
|
build:tensorflow_testing_rbe_win --config=tensorflow_testing_rbe
|
||||||
build:tensorflow_testing_rbe_win --config=rbe_win
|
|
||||||
# END TF REMOTE BUILD EXECUTION OPTIONS
|
# END TF REMOTE BUILD EXECUTION OPTIONS
|
||||||
|
|
||||||
# Default options should come above this line
|
# Default options should come above this line
|
||||||
|
@ -1 +1 @@
|
|||||||
1.2.1
|
2.0.0
|
||||||
|
44
.github/ISSUE_TEMPLATE/00-bug-issue.md
vendored
Normal file
44
.github/ISSUE_TEMPLATE/00-bug-issue.md
vendored
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
---
|
||||||
|
name: Bug Issue
|
||||||
|
about: Use this template for reporting a bug
|
||||||
|
labels: 'type:bug'
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
<em>Please make sure that this is a bug. As per our
|
||||||
|
[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),
|
||||||
|
we only address code/doc bugs, performance issues, feature requests and
|
||||||
|
build/installation issues on GitHub. tag:bug_template</em>
|
||||||
|
|
||||||
|
**System information**
|
||||||
|
- Have I written custom code (as opposed to using a stock
|
||||||
|
example script provided in TensorFlow):
|
||||||
|
- OS Platform and Distribution (e.g.,
|
||||||
|
Linux Ubuntu 16.04):
|
||||||
|
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
|
||||||
|
the issue happens on mobile device:
|
||||||
|
- TensorFlow installed from (source or
|
||||||
|
binary): - TensorFlow version (use command below):
|
||||||
|
- Python version: - Bazel
|
||||||
|
version (if compiling from source):
|
||||||
|
- GCC/Compiler version (if compiling from
|
||||||
|
source):
|
||||||
|
- CUDA/cuDNN version: - GPU model and memory:
|
||||||
|
|
||||||
|
You can collect some of this information using our environment capture
|
||||||
|
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
|
||||||
|
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
|
||||||
|
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
|
||||||
|
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
|
||||||
|
|
||||||
|
**Describe the current behavior**
|
||||||
|
|
||||||
|
**Describe the expected behavior**
|
||||||
|
|
||||||
|
**Standalone code to reproduce the issue**
|
||||||
|
Provide a reproducible test case that is the bare minimum necessary to generate
|
||||||
|
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
|
||||||
|
|
||||||
|
**Other info / logs** Include any logs or source code that would be helpful to
|
||||||
|
diagnose the problem. If including tracebacks, please include the full
|
||||||
|
traceback. Large logs and files should be attached.
|
@ -1,35 +0,0 @@
|
|||||||
---
|
|
||||||
name: Bug/Performance Issue
|
|
||||||
about: Use this template for reporting a bug or a performance issue.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
<em>Please make sure that this is a bug. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template</em>
|
|
||||||
|
|
||||||
**System information**
|
|
||||||
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
|
|
||||||
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
|
|
||||||
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
|
|
||||||
- TensorFlow installed from (source or binary):
|
|
||||||
- TensorFlow version (use command below):
|
|
||||||
- Python version:
|
|
||||||
- Bazel version (if compiling from source):
|
|
||||||
- GCC/Compiler version (if compiling from source):
|
|
||||||
- CUDA/cuDNN version:
|
|
||||||
- GPU model and memory:
|
|
||||||
|
|
||||||
You can collect some of this information using our environment capture
|
|
||||||
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
|
|
||||||
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
|
|
||||||
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
|
|
||||||
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
|
|
||||||
|
|
||||||
**Describe the current behavior**
|
|
||||||
|
|
||||||
**Describe the expected behavior**
|
|
||||||
|
|
||||||
**Code to reproduce the issue**
|
|
||||||
Provide a reproducible test case that is the bare minimum necessary to generate the problem.
|
|
||||||
|
|
||||||
**Other info / logs**
|
|
||||||
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
|
|
@ -1,6 +1,7 @@
|
|||||||
---
|
---
|
||||||
name: Build/Installation Issue
|
name: Build/Installation Issue
|
||||||
about: Use this template for build/installation issues
|
about: Use this template for build/installation issues
|
||||||
|
labels: 'type:build/install'
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
---
|
---
|
||||||
name: Documentation Issue
|
name: Documentation Issue
|
||||||
about: Use this template for documentation related
|
about: Use this template for documentation related issues
|
||||||
labels: 'type:docs'
|
labels: 'type:docs'
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
||||||
Thank you for submitting a TensorFlow documentation issue. Per our GitHub
|
Thank you for submitting a TensorFlow documentation issue. Per our GitHub
|
||||||
policy, we only address code/doc bugs, performance issues, feature requests, and
|
policy, we only address code/doc bugs, performance issues, feature requests, and
|
||||||
build/installation issues on GitHub.
|
build/installation issues on GitHub.
|
||||||
|
1
.github/ISSUE_TEMPLATE/30-feature-request.md
vendored
1
.github/ISSUE_TEMPLATE/30-feature-request.md
vendored
@ -1,6 +1,7 @@
|
|||||||
---
|
---
|
||||||
name: Feature Request
|
name: Feature Request
|
||||||
about: Use this template for raising a feature request
|
about: Use this template for raising a feature request
|
||||||
|
labels: 'type:feature'
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
12
.github/ISSUE_TEMPLATE/40-tflite-op-request.md
vendored
12
.github/ISSUE_TEMPLATE/40-tflite-op-request.md
vendored
@ -1,10 +1,10 @@
|
|||||||
---
|
---
|
||||||
name: TensorFlow Lite Op Request
|
name: TensorFlow Lite Op Request
|
||||||
about: Use this template for reporting ops you are using or missing.
|
about: Use this template for reporting Lite ops you are using or missing
|
||||||
|
labels: 'comp:lite'
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
||||||
**System information**
|
**System information**
|
||||||
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
|
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
|
||||||
- TensorFlow installed from (source or binary):
|
- TensorFlow installed from (source or binary):
|
||||||
@ -17,8 +17,14 @@ about: Use this template for reporting ops you are using or missing.
|
|||||||
# Copy and paste here
|
# Copy and paste here
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Standalone code to reproduce the issue**
|
||||||
|
Provide a reproducible test case that is the bare minimum necessary to generate
|
||||||
|
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
|
||||||
|
|
||||||
Also, please include a link to a GraphDef or the model if possible.
|
Also, please include a link to a GraphDef or the model if possible.
|
||||||
|
|
||||||
**Any other info / logs**
|
**Any other info / logs**
|
||||||
|
|
||||||
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
|
Include any logs or source code that would be helpful to diagnose the problem.
|
||||||
|
If including tracebacks, please include the full traceback. Large logs and files
|
||||||
|
should be attached.
|
||||||
|
1
.github/ISSUE_TEMPLATE/50-other-issues.md
vendored
1
.github/ISSUE_TEMPLATE/50-other-issues.md
vendored
@ -1,6 +1,7 @@
|
|||||||
---
|
---
|
||||||
name: Other Issues
|
name: Other Issues
|
||||||
about: Use this template for any other non-support related issues
|
about: Use this template for any other non-support related issues
|
||||||
|
labels: 'type:others'
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
---
|
---
|
||||||
name: TensorFlow Lite New Converter Issue
|
name: TensorFlow Lite New Converter Issue
|
||||||
about: Use this template for reporting issues during model conversion to TFLite.
|
about: Use this template for reporting issues during model conversion to TFLite
|
||||||
|
labels: 'TFLiteConverter'
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@ -12,6 +13,7 @@ about: Use this template for reporting issues during model conversion to TFLite.
|
|||||||
|
|
||||||
|
|
||||||
**Command used to run the converter or code if you’re using the Python API**
|
**Command used to run the converter or code if you’re using the Python API**
|
||||||
|
If possible, please share a link to Colab/Jupyter/any notebook.
|
||||||
|
|
||||||
```
|
```
|
||||||
# Copy and paste here the exact command
|
# Copy and paste here the exact command
|
||||||
|
45
.github/ISSUE_TEMPLATE/80-performance-issue.md
vendored
Normal file
45
.github/ISSUE_TEMPLATE/80-performance-issue.md
vendored
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
---
|
||||||
|
name: Performance Issue
|
||||||
|
about: Use this template for reporting a performance issue
|
||||||
|
labels: 'type:performance'
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
<em>Please make sure that this is an issue related to performance of TensorFlow.
|
||||||
|
As per our
|
||||||
|
[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),
|
||||||
|
we only address code/doc bugs, performance issues, feature requests and
|
||||||
|
build/installation issues on GitHub. tag:performance_template</em>
|
||||||
|
|
||||||
|
**System information**
|
||||||
|
- Have I written custom code (as opposed to using a stock
|
||||||
|
example script provided in TensorFlow):
|
||||||
|
- OS Platform and Distribution (e.g.,
|
||||||
|
Linux Ubuntu 16.04):
|
||||||
|
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
|
||||||
|
the issue happens on mobile device:
|
||||||
|
- TensorFlow installed from (source or
|
||||||
|
binary): - TensorFlow version (use command below):
|
||||||
|
- Python version: - Bazel
|
||||||
|
version (if compiling from source):
|
||||||
|
- GCC/Compiler version (if compiling from
|
||||||
|
source):
|
||||||
|
- CUDA/cuDNN version: - GPU model and memory:
|
||||||
|
|
||||||
|
You can collect some of this information using our environment capture
|
||||||
|
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
|
||||||
|
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
|
||||||
|
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
|
||||||
|
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
|
||||||
|
|
||||||
|
**Describe the current behavior**
|
||||||
|
|
||||||
|
**Describe the expected behavior**
|
||||||
|
|
||||||
|
**Standalone code to reproduce the issue**
|
||||||
|
Provide a reproducible test case that is the bare minimum necessary to generate
|
||||||
|
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
|
||||||
|
|
||||||
|
**Other info / logs** Include any logs or source code that would be helpful to
|
||||||
|
diagnose the problem. If including tracebacks, please include the full
|
||||||
|
traceback. Large logs and files should be attached.
|
5
.gitignore
vendored
5
.gitignore
vendored
@ -22,6 +22,7 @@ tensorflow/contrib/cmake/_build/
|
|||||||
/tensorflow/python/framework/fast_tensor_util.cpp
|
/tensorflow/python/framework/fast_tensor_util.cpp
|
||||||
/tensorflow/lite/gen/**
|
/tensorflow/lite/gen/**
|
||||||
/tensorflow/lite/tools/make/downloads/**
|
/tensorflow/lite/tools/make/downloads/**
|
||||||
|
/tensorflow/lite/tools/make/gen/**
|
||||||
/api_init_files_list.txt
|
/api_init_files_list.txt
|
||||||
/estimator_api_init_files_list.txt
|
/estimator_api_init_files_list.txt
|
||||||
*.whl
|
*.whl
|
||||||
@ -37,7 +38,9 @@ gradleBuild
|
|||||||
*.pbxproj
|
*.pbxproj
|
||||||
*.xcworkspace
|
*.xcworkspace
|
||||||
/*.podspec
|
/*.podspec
|
||||||
/tensorflow/lite/**/[ios|objc|swift]*/BUILD
|
/tensorflow/lite/**/ios/BUILD
|
||||||
|
/tensorflow/lite/**/objc/BUILD
|
||||||
|
/tensorflow/lite/**/swift/BUILD
|
||||||
/tensorflow/lite/examples/ios/simple/data/*.tflite
|
/tensorflow/lite/examples/ios/simple/data/*.tflite
|
||||||
/tensorflow/lite/examples/ios/simple/data/*.txt
|
/tensorflow/lite/examples/ios/simple/data/*.txt
|
||||||
Podfile.lock
|
Podfile.lock
|
||||||
|
18
README.md
18
README.md
@ -70,7 +70,7 @@ $ python
|
|||||||
3
|
3
|
||||||
>>> hello = tf.constant('Hello, TensorFlow!')
|
>>> hello = tf.constant('Hello, TensorFlow!')
|
||||||
>>> hello.numpy()
|
>>> hello.numpy()
|
||||||
'Hello, TensorFlow!'
|
b'Hello, TensorFlow!'
|
||||||
```
|
```
|
||||||
|
|
||||||
For more examples, see the
|
For more examples, see the
|
||||||
@ -130,18 +130,20 @@ Build Type | Status
|
|||||||
## Resources
|
## Resources
|
||||||
|
|
||||||
* [TensorFlow.org](https://www.tensorflow.org)
|
* [TensorFlow.org](https://www.tensorflow.org)
|
||||||
* [TensorFlow tutorials](https://www.tensorflow.org/tutorials/)
|
* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
|
||||||
* [TensorFlow official models](https://github.com/tensorflow/models/tree/master/official)
|
* [TensorFlow Official Models](https://github.com/tensorflow/models/tree/master/official)
|
||||||
* [TensorFlow examples](https://github.com/tensorflow/examples)
|
* [TensorFlow Examples](https://github.com/tensorflow/examples)
|
||||||
* [TensorFlow in Practice from Coursera](https://www.coursera.org/specializations/tensorflow-in-practice)
|
* [TensorFlow in Practice from Coursera](https://www.coursera.org/specializations/tensorflow-in-practice)
|
||||||
|
* [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment)
|
||||||
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
|
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
|
||||||
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
|
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
|
||||||
* [TensorFlow blog](https://blog.tensorflow.org)
|
* [TensorFlow Blog](https://blog.tensorflow.org)
|
||||||
|
* [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
|
||||||
* [TensorFlow Twitter](https://twitter.com/tensorflow)
|
* [TensorFlow Twitter](https://twitter.com/tensorflow)
|
||||||
* [TensorFlow YouTube](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
|
* [TensorFlow YouTube](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
|
||||||
* [TensorFlow roadmap](https://www.tensorflow.org/community/roadmap)
|
* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
|
||||||
* [TensorFlow white papers](https://www.tensorflow.org/about/bib)
|
* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
|
||||||
* [TensorBoard visualization toolkit](https://github.com/tensorflow/tensorboard)
|
* [TensorBoard Visualization Toolkit](https://github.com/tensorflow/tensorboard)
|
||||||
|
|
||||||
Learn more about the
|
Learn more about the
|
||||||
[TensorFlow community](https://www.tensorflow.org/community) and how to
|
[TensorFlow community](https://www.tensorflow.org/community) and how to
|
||||||
|
16
RELEASE.md
16
RELEASE.md
@ -1,3 +1,19 @@
|
|||||||
|
# Release 2.0.1
|
||||||
|
|
||||||
|
## Bug Fixes and Other Changes
|
||||||
|
* Fixes a security vulnerability where converting a Python string to a `tf.float16` value produces a segmentation fault ([CVE-2020-5215](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-5215))
|
||||||
|
* Updates `curl` to `7.66.0` to handle [CVE-2019-5482](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-5482) and [CVE-2019-5481](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-5481)
|
||||||
|
* Updates `sqlite3` to `3.30.01` to handle [CVE-2019-19646](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19646), [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645) and [CVE-2019-16168](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-16168)
|
||||||
|
|
||||||
|
|
||||||
|
# Release 1.15.2
|
||||||
|
|
||||||
|
## Bug Fixes and Other Changes
|
||||||
|
* Fixes a security vulnerability where converting a Python string to a `tf.float16` value produces a segmentation fault ([CVE-2020-5215](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-5215))
|
||||||
|
* Updates `curl` to `7.66.0` to handle [CVE-2019-5482](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-5482) and [CVE-2019-5481](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-5481)
|
||||||
|
* Updates `sqlite3` to `3.30.01` to handle [CVE-2019-19646](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19646), [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645) and [CVE-2019-16168](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-16168)
|
||||||
|
|
||||||
|
|
||||||
# Release 2.1.0
|
# Release 2.1.0
|
||||||
|
|
||||||
TensorFlow 2.1 will be the last TF release supporting Python 2. Python 2 support [officially ends an January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update). [As announced earlier](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ), TensorFlow will also stop supporting Python 2 starting January 1, 2020, and no more releases are expected in 2019.
|
TensorFlow 2.1 will be the last TF release supporting Python 2. Python 2 support [officially ends an January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update). [As announced earlier](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ), TensorFlow will also stop supporting Python 2 starting January 1, 2020, and no more releases are expected in 2019.
|
||||||
|
@ -245,4 +245,4 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
|
|||||||
### Known Vulnerabilities
|
### Known Vulnerabilities
|
||||||
|
|
||||||
For a list of known vulnerabilities and security advisories for TensorFlow,
|
For a list of known vulnerabilities and security advisories for TensorFlow,
|
||||||
[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md).
|
[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/README.md).
|
||||||
|
33
WORKSPACE
33
WORKSPACE
@ -1,13 +1,11 @@
|
|||||||
workspace(name = "org_tensorflow")
|
workspace(name = "org_tensorflow")
|
||||||
|
|
||||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||||
load("//third_party:repo.bzl", "tf_http_archive")
|
|
||||||
|
|
||||||
tf_http_archive(
|
http_archive(
|
||||||
name = "io_bazel_rules_closure",
|
name = "io_bazel_rules_closure",
|
||||||
sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
|
sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
|
||||||
strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
|
strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
|
||||||
patch_file = "@org_tensorflow//third_party:rules_closure.patch",
|
|
||||||
urls = [
|
urls = [
|
||||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
|
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
|
||||||
"https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13
|
"https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13
|
||||||
@ -115,3 +113,32 @@ http_archive(
|
|||||||
"https://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
|
"https://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Required for dependency @com_github_grpc_grpc
|
||||||
|
|
||||||
|
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
|
||||||
|
|
||||||
|
grpc_deps()
|
||||||
|
|
||||||
|
load(
|
||||||
|
"@build_bazel_rules_apple//apple:repositories.bzl",
|
||||||
|
"apple_rules_dependencies",
|
||||||
|
)
|
||||||
|
|
||||||
|
apple_rules_dependencies()
|
||||||
|
|
||||||
|
load(
|
||||||
|
"@build_bazel_apple_support//lib:repositories.bzl",
|
||||||
|
"apple_support_dependencies",
|
||||||
|
)
|
||||||
|
|
||||||
|
apple_support_dependencies()
|
||||||
|
|
||||||
|
load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")
|
||||||
|
|
||||||
|
bazel_version_repository(name = "bazel_version")
|
||||||
|
|
||||||
|
load("//third_party/googleapis:repository_rules.bzl", "config_googleapis")
|
||||||
|
|
||||||
|
config_googleapis()
|
||||||
|
|
||||||
|
19
configure.py
19
configure.py
@ -49,8 +49,8 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
|
|||||||
_TF_WORKSPACE_ROOT = ''
|
_TF_WORKSPACE_ROOT = ''
|
||||||
_TF_BAZELRC = ''
|
_TF_BAZELRC = ''
|
||||||
_TF_CURRENT_BAZEL_VERSION = None
|
_TF_CURRENT_BAZEL_VERSION = None
|
||||||
_TF_MIN_BAZEL_VERSION = '1.2.1'
|
_TF_MIN_BAZEL_VERSION = '2.0.0'
|
||||||
_TF_MAX_BAZEL_VERSION = '1.2.1'
|
_TF_MAX_BAZEL_VERSION = '2.0.0'
|
||||||
|
|
||||||
NCCL_LIB_PATHS = [
|
NCCL_LIB_PATHS = [
|
||||||
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
|
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
|
||||||
@ -1155,7 +1155,7 @@ def set_trisycl_include_dir(environ_cp):
|
|||||||
write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)
|
write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)
|
||||||
|
|
||||||
|
|
||||||
def system_specific_test_config(env):
|
def system_specific_test_config(environ_cp):
|
||||||
"""Add default build and test flags required for TF tests to bazelrc."""
|
"""Add default build and test flags required for TF tests to bazelrc."""
|
||||||
write_to_bazelrc('test --flaky_test_attempts=3')
|
write_to_bazelrc('test --flaky_test_attempts=3')
|
||||||
write_to_bazelrc('test --test_size_filters=small,medium')
|
write_to_bazelrc('test --test_size_filters=small,medium')
|
||||||
@ -1171,14 +1171,14 @@ def system_specific_test_config(env):
|
|||||||
test_only_filters = ['-oss_serial']
|
test_only_filters = ['-oss_serial']
|
||||||
if is_windows():
|
if is_windows():
|
||||||
test_and_build_filters.append('-no_windows')
|
test_and_build_filters.append('-no_windows')
|
||||||
if env.get('TF_NEED_CUDA', None) == '1':
|
if environ_cp.get('TF_NEED_CUDA', None) == '1':
|
||||||
test_and_build_filters += ['-no_windows_gpu', '-no_gpu']
|
test_and_build_filters += ['-no_windows_gpu', '-no_gpu']
|
||||||
else:
|
else:
|
||||||
test_and_build_filters.append('-gpu')
|
test_and_build_filters.append('-gpu')
|
||||||
elif is_macos():
|
elif is_macos():
|
||||||
test_and_build_filters += ['-gpu', '-nomac', '-no_mac']
|
test_and_build_filters += ['-gpu', '-nomac', '-no_mac']
|
||||||
elif is_linux():
|
elif is_linux():
|
||||||
if env.get('TF_NEED_CUDA', None) == '1':
|
if environ_cp.get('TF_NEED_CUDA', None) == '1':
|
||||||
test_and_build_filters.append('-no_gpu')
|
test_and_build_filters.append('-no_gpu')
|
||||||
write_to_bazelrc('test --test_env=LD_LIBRARY_PATH')
|
write_to_bazelrc('test --test_env=LD_LIBRARY_PATH')
|
||||||
else:
|
else:
|
||||||
@ -1221,7 +1221,7 @@ def is_reduced_optimize_huge_functions_available(environ_cp):
|
|||||||
only, as of 2019-11-19). TensorFlow needs this flag to massively reduce
|
only, as of 2019-11-19). TensorFlow needs this flag to massively reduce
|
||||||
compile times, but until 16.4 is officially released, we can't depend on it.
|
compile times, but until 16.4 is officially released, we can't depend on it.
|
||||||
|
|
||||||
See also https://groups.google.com/a/tensorflow.org/g/build/c/SsW98Eo7l3o
|
See also https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
|
||||||
|
|
||||||
Because it's very annoying to check this manually (to check the MSVC installed
|
Because it's very annoying to check this manually (to check the MSVC installed
|
||||||
versions, you need to use the registry, and it's not clear if Bazel will be
|
versions, you need to use the registry, and it's not clear if Bazel will be
|
||||||
@ -1390,9 +1390,8 @@ def main():
|
|||||||
else:
|
else:
|
||||||
environ_cp['TF_CONFIGURE_IOS'] = '0'
|
environ_cp['TF_CONFIGURE_IOS'] = '0'
|
||||||
|
|
||||||
xla_enabled_by_default = is_linux() or is_macos()
|
if environ_cp.get('TF_ENABLE_XLA', '1') == '1':
|
||||||
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
|
write_to_bazelrc('build --config=xla')
|
||||||
xla_enabled_by_default, 'xla')
|
|
||||||
|
|
||||||
set_action_env_var(
|
set_action_env_var(
|
||||||
environ_cp,
|
environ_cp,
|
||||||
@ -1523,7 +1522,7 @@ def main():
|
|||||||
create_android_ndk_rule(environ_cp)
|
create_android_ndk_rule(environ_cp)
|
||||||
create_android_sdk_rule(environ_cp)
|
create_android_sdk_rule(environ_cp)
|
||||||
|
|
||||||
system_specific_test_config(os.environ)
|
system_specific_test_config(environ_cp)
|
||||||
|
|
||||||
set_action_env_var(environ_cp, 'TF_CONFIGURE_IOS', 'iOS', False)
|
set_action_env_var(environ_cp, 'TF_CONFIGURE_IOS', 'iOS', False)
|
||||||
if environ_cp.get('TF_CONFIGURE_IOS') == '1':
|
if environ_cp.get('TF_CONFIGURE_IOS') == '1':
|
||||||
|
@ -187,6 +187,12 @@ config_setting(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
config_setting(
|
||||||
|
name = "fuchsia",
|
||||||
|
values = {"cpu": "fuchsia"},
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
config_setting(
|
config_setting(
|
||||||
name = "ios_x86_64",
|
name = "ios_x86_64",
|
||||||
values = {
|
values = {
|
||||||
@ -448,19 +454,66 @@ config_setting(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Specifies via a config setting if this is a mobile build or not, makes
|
||||||
|
# it easier to combine settings later.
|
||||||
|
selects.config_setting_group(
|
||||||
|
name = "mobile",
|
||||||
|
match_any = [
|
||||||
|
":android",
|
||||||
|
":chromiumos",
|
||||||
|
":emscripten",
|
||||||
|
":ios",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
config_setting(
|
||||||
|
name = "lite_protos_legacy",
|
||||||
|
values = {"define": "TENSORFLOW_PROTOS=lite"},
|
||||||
|
visibility = ["//visibility:private"],
|
||||||
|
)
|
||||||
|
|
||||||
|
config_setting(
|
||||||
|
name = "full_protos",
|
||||||
|
values = {"define": "TENSORFLOW_PROTOS=full"},
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
|
selects.config_setting_group(
|
||||||
|
name = "lite_protos",
|
||||||
|
match_any = [":lite_protos_legacy"],
|
||||||
|
)
|
||||||
|
|
||||||
|
selects.config_setting_group(
|
||||||
|
name = "mobile_lite_protos",
|
||||||
|
match_all = [
|
||||||
|
":lite_protos",
|
||||||
|
":mobile",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
selects.config_setting_group(
|
||||||
|
name = "mobile_full_protos",
|
||||||
|
match_all = [
|
||||||
|
":full_protos",
|
||||||
|
":mobile",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
|
# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
|
||||||
# Instead, please use public APIs or public build rules TF provides.
|
# Instead, please use public APIs or public build rules TF provides.
|
||||||
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
|
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
|
||||||
package_group(
|
package_group(
|
||||||
name = "internal",
|
name = "internal",
|
||||||
packages = [
|
packages = [
|
||||||
|
# To pass open source testing in the pip Kokoros.
|
||||||
|
"//bazel_pip/tensorflow/...",
|
||||||
"//learning/brain/swift/x10/...",
|
"//learning/brain/swift/x10/...",
|
||||||
"//perftools/accelerators/xprof/api/...",
|
"//perftools/accelerators/xprof/api/...",
|
||||||
|
"//third_party/py/autograph/...",
|
||||||
|
"//third_party/swift/tensorflow/x10/...",
|
||||||
"//tensorflow/...",
|
"//tensorflow/...",
|
||||||
"//tensorflow_estimator/python/estimator/...",
|
"//tensorflow_estimator/python/estimator/...",
|
||||||
"//tensorflow_models/official/...",
|
"//tensorflow_models/official/...",
|
||||||
"//third_party/py/autograph/...",
|
|
||||||
"//third_party/swift/tensorflow/x10/...",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -479,6 +532,7 @@ bzl_library(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core/platform:build_config_root_bzl",
|
"//tensorflow/core/platform:build_config_root_bzl",
|
||||||
|
"//tensorflow/core/platform:rules_cc_bzl",
|
||||||
"//tensorflow/core/platform/default:cuda_build_defs_bzl",
|
"//tensorflow/core/platform/default:cuda_build_defs_bzl",
|
||||||
"//third_party/mkl:build_defs_bzl",
|
"//third_party/mkl:build_defs_bzl",
|
||||||
"//third_party/mkl_dnn:build_defs_bzl",
|
"//third_party/mkl_dnn:build_defs_bzl",
|
||||||
@ -493,8 +547,8 @@ cc_library(
|
|||||||
name = "grpc",
|
name = "grpc",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = select({
|
deps = select({
|
||||||
":linux_s390x": ["@grpc//:grpc_unsecure"],
|
":linux_s390x": ["@com_github_grpc_grpc//:grpc_unsecure"],
|
||||||
"//conditions:default": ["@grpc"],
|
"//conditions:default": ["@com_github_grpc_grpc//:grpc"],
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -502,8 +556,8 @@ cc_library(
|
|||||||
name = "grpc++",
|
name = "grpc++",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = select({
|
deps = select({
|
||||||
":linux_s390x": ["@grpc//:grpc++_unsecure"],
|
":linux_s390x": ["@com_github_grpc_grpc//:grpc++_unsecure"],
|
||||||
"//conditions:default": ["@grpc//:grpc++"],
|
"//conditions:default": ["@com_github_grpc_grpc//:grpc++"],
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -588,6 +642,7 @@ tf_cc_shared_object(
|
|||||||
"//tensorflow/core:gpu_runtime_impl",
|
"//tensorflow/core:gpu_runtime_impl",
|
||||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
|
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
|
||||||
"//tensorflow/core:lib_internal_impl",
|
"//tensorflow/core:lib_internal_impl",
|
||||||
|
"//tensorflow/core/profiler:profiler_impl",
|
||||||
"//tensorflow/stream_executor:stream_executor_impl",
|
"//tensorflow/stream_executor:stream_executor_impl",
|
||||||
"//tensorflow:tf_framework_version_script.lds",
|
"//tensorflow:tf_framework_version_script.lds",
|
||||||
] + tf_additional_binary_deps(),
|
] + tf_additional_binary_deps(),
|
||||||
@ -647,6 +702,7 @@ tf_cc_shared_object(
|
|||||||
"//tensorflow/c:exported_symbols.lds",
|
"//tensorflow/c:exported_symbols.lds",
|
||||||
"//tensorflow/c:version_script.lds",
|
"//tensorflow/c:version_script.lds",
|
||||||
"//tensorflow/c/eager:c_api",
|
"//tensorflow/c/eager:c_api",
|
||||||
|
"//tensorflow/c/eager:c_api_experimental",
|
||||||
"//tensorflow/core:tensorflow",
|
"//tensorflow/core:tensorflow",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
|
||||||
],
|
],
|
||||||
@ -907,7 +963,6 @@ py_library(
|
|||||||
"//conditions:default": [":tf_python_api_gen_v1"],
|
"//conditions:default": [":tf_python_api_gen_v1"],
|
||||||
}) + [
|
}) + [
|
||||||
":root_init_gen",
|
":root_init_gen",
|
||||||
":virtual_root_init_gen",
|
|
||||||
"//tensorflow/python/keras/api:keras_python_api_gen",
|
"//tensorflow/python/keras/api:keras_python_api_gen",
|
||||||
"//tensorflow/python/keras/api:keras_python_api_gen_compat_v1",
|
"//tensorflow/python/keras/api:keras_python_api_gen_compat_v1",
|
||||||
"//tensorflow/python/keras/api:keras_python_api_gen_compat_v2",
|
"//tensorflow/python/keras/api:keras_python_api_gen_compat_v2",
|
||||||
|
@ -35,9 +35,11 @@ import inspect as _inspect
|
|||||||
import logging as _logging
|
import logging as _logging
|
||||||
import os as _os
|
import os as _os
|
||||||
import site as _site
|
import site as _site
|
||||||
|
import six as _six
|
||||||
import sys as _sys
|
import sys as _sys
|
||||||
|
|
||||||
from tensorflow.python.tools import module_util as _module_util
|
from tensorflow.python.tools import module_util as _module_util
|
||||||
|
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
|
||||||
|
|
||||||
# API IMPORTS PLACEHOLDER
|
# API IMPORTS PLACEHOLDER
|
||||||
|
|
||||||
@ -69,13 +71,13 @@ except ImportError:
|
|||||||
_logging.warning(
|
_logging.warning(
|
||||||
"Limited tf.summary API due to missing TensorBoard installation.")
|
"Limited tf.summary API due to missing TensorBoard installation.")
|
||||||
|
|
||||||
try:
|
# Lazy-load estimator.
|
||||||
from tensorflow_estimator.python.estimator.api._v2 import estimator
|
_estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator"
|
||||||
_current_module.__path__ = (
|
estimator = _LazyLoader("estimator", globals(), _estimator_module)
|
||||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
|
||||||
setattr(_current_module, "estimator", estimator)
|
if _module_dir:
|
||||||
except ImportError:
|
_current_module.__path__ = [_module_dir] + _current_module.__path__
|
||||||
pass
|
setattr(_current_module, "estimator", estimator)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .python.keras.api._v2 import keras
|
from .python.keras.api._v2 import keras
|
||||||
@ -85,6 +87,13 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Explicitly import lazy-loaded modules to support autocompletion.
|
||||||
|
# pylint: disable=g-import-not-at-top
|
||||||
|
if not _six.PY2:
|
||||||
|
import typing as _typing
|
||||||
|
if _typing.TYPE_CHECKING:
|
||||||
|
from tensorflow_estimator.python.estimator.api._v2 import estimator
|
||||||
|
# pylint: enable=g-import-not-at-top
|
||||||
|
|
||||||
# Enable TF2 behaviors
|
# Enable TF2 behaviors
|
||||||
from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top
|
from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top
|
||||||
|
@ -22,12 +22,14 @@ import distutils as _distutils
|
|||||||
import inspect as _inspect
|
import inspect as _inspect
|
||||||
import os as _os
|
import os as _os
|
||||||
import site as _site
|
import site as _site
|
||||||
|
import six as _six
|
||||||
import sys as _sys
|
import sys as _sys
|
||||||
|
|
||||||
# pylint: disable=g-bad-import-order
|
# pylint: disable=g-bad-import-order
|
||||||
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
||||||
from tensorflow.python.tools import module_util as _module_util
|
from tensorflow.python.tools import module_util as _module_util
|
||||||
from tensorflow.python.platform import tf_logging as _logging
|
from tensorflow.python.platform import tf_logging as _logging
|
||||||
|
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
|
||||||
|
|
||||||
# API IMPORTS PLACEHOLDER
|
# API IMPORTS PLACEHOLDER
|
||||||
|
|
||||||
@ -64,13 +66,14 @@ elif _tf_api_dir not in __path__:
|
|||||||
# reexport_tf_summary can get compat from sys.modules. Only needed if using
|
# reexport_tf_summary can get compat from sys.modules. Only needed if using
|
||||||
# lazy loading.
|
# lazy loading.
|
||||||
_current_module.compat.v2 # pylint: disable=pointless-statement
|
_current_module.compat.v2 # pylint: disable=pointless-statement
|
||||||
try:
|
|
||||||
from tensorflow_estimator.python.estimator.api._v1 import estimator
|
# Lazy-load estimator.
|
||||||
_current_module.__path__ = (
|
_estimator_module = "tensorflow_estimator.python.estimator.api._v1.estimator"
|
||||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
estimator = _LazyLoader("estimator", globals(), _estimator_module)
|
||||||
setattr(_current_module, "estimator", estimator)
|
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
|
||||||
except ImportError:
|
if _module_dir:
|
||||||
pass
|
_current_module.__path__ = [_module_dir] + _current_module.__path__
|
||||||
|
setattr(_current_module, "estimator", estimator)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .python.keras.api._v1 import keras
|
from .python.keras.api._v1 import keras
|
||||||
@ -80,6 +83,13 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Explicitly import lazy-loaded modules to support autocompletion.
|
||||||
|
# pylint: disable=g-import-not-at-top
|
||||||
|
if not _six.PY2:
|
||||||
|
import typing as _typing
|
||||||
|
if _typing.TYPE_CHECKING:
|
||||||
|
from tensorflow_estimator.python.estimator.api._v1 import estimator
|
||||||
|
# pylint: enable=g-import-not-at-top
|
||||||
|
|
||||||
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
|
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
|
||||||
_CONTRIB_WARNING = """
|
_CONTRIB_WARNING = """
|
||||||
|
@ -54,9 +54,10 @@ filegroup(
|
|||||||
)
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "pywrap_eager_hdrs",
|
name = "pywrap_required_hdrs",
|
||||||
srcs = [
|
srcs = [
|
||||||
"c_api_internal.h",
|
"c_api_internal.h",
|
||||||
|
"python_api.h",
|
||||||
"tf_status_helper.h",
|
"tf_status_helper.h",
|
||||||
"tf_status_internal.h",
|
"tf_status_internal.h",
|
||||||
"tf_tensor_internal.h",
|
"tf_tensor_internal.h",
|
||||||
@ -98,6 +99,17 @@ tf_cuda_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "pywrap_tf_session_hdrs",
|
||||||
|
srcs = [
|
||||||
|
"python_api.h",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/core:__pkg__",
|
||||||
|
"//tensorflow/python:__pkg__",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tf_attrtype",
|
name = "tf_attrtype",
|
||||||
hdrs = ["tf_attrtype.h"],
|
hdrs = ["tf_attrtype.h"],
|
||||||
@ -142,7 +154,10 @@ tf_cuda_library(
|
|||||||
"c_api.h",
|
"c_api.h",
|
||||||
],
|
],
|
||||||
copts = tf_copts(),
|
copts = tf_copts(),
|
||||||
visibility = ["//tensorflow/c:__subpackages__"],
|
visibility = [
|
||||||
|
"//tensorflow/c:__subpackages__",
|
||||||
|
"//third_party/llvm/llvm-project:__subpackages__",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":c_api_internal",
|
":c_api_internal",
|
||||||
":tf_attrtype",
|
":tf_attrtype",
|
||||||
@ -230,7 +245,7 @@ cc_library(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = select({
|
deps = select({
|
||||||
"//tensorflow:android": [
|
"//tensorflow:android": [
|
||||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
"//tensorflow/core:android_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||||
],
|
],
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
@ -303,7 +318,6 @@ tf_cuda_library(
|
|||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||||
"//tensorflow/core/common_runtime/eager:context",
|
"//tensorflow/core/common_runtime/eager:context",
|
||||||
"//tensorflow/core/common_runtime/eager:execute",
|
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||||
"//tensorflow/core/platform",
|
"//tensorflow/core/platform",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
@ -525,6 +539,7 @@ tf_cuda_cc_test(
|
|||||||
"//tensorflow/core/kernels:array",
|
"//tensorflow/core/kernels:array",
|
||||||
"//tensorflow/core/kernels:control_flow_ops",
|
"//tensorflow/core/kernels:control_flow_ops",
|
||||||
"//tensorflow/core/kernels:math",
|
"//tensorflow/core/kernels:math",
|
||||||
|
"//tensorflow/core/platform:resource_loader",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -537,6 +552,7 @@ tf_cc_test(
|
|||||||
"//tensorflow:macos": ["-headerpad_max_install_names"],
|
"//tensorflow:macos": ["-headerpad_max_install_names"],
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
}),
|
}),
|
||||||
|
tags = ["notsan"], # b/149031034
|
||||||
# We must ensure that the dependencies can be dynamically linked since
|
# We must ensure that the dependencies can be dynamically linked since
|
||||||
# the shared library must be able to use core:framework.
|
# the shared library must be able to use core:framework.
|
||||||
# linkstatic = tf_kernel_tests_linkstatic(),
|
# linkstatic = tf_kernel_tests_linkstatic(),
|
||||||
@ -635,6 +651,7 @@ tf_cuda_cc_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":c_api",
|
":c_api",
|
||||||
":kernels",
|
":kernels",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
@ -642,6 +659,7 @@ tf_cuda_cc_test(
|
|||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core/kernels:ops_testutil",
|
"//tensorflow/core/kernels:ops_testutil",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
|
"@com_google_absl//absl/container:inlined_vector",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -683,4 +701,5 @@ tf_cuda_library(
|
|||||||
# TODO(b/74620627): remove when _USE_C_SHAPES is removed
|
# TODO(b/74620627): remove when _USE_C_SHAPES is removed
|
||||||
"//tensorflow/python:cpp_shape_inference_proto_cc",
|
"//tensorflow/python:cpp_shape_inference_proto_cc",
|
||||||
],
|
],
|
||||||
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
@ -774,7 +774,7 @@ extern "C" {
|
|||||||
static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
|
static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
|
||||||
const char* op_type,
|
const char* op_type,
|
||||||
const char* oper_name)
|
const char* oper_name)
|
||||||
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
|
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
|
||||||
return new TF_OperationDescription(graph, op_type, oper_name);
|
return new TF_OperationDescription(graph, op_type, oper_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1032,7 +1032,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
|
|||||||
|
|
||||||
static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
|
static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
|
||||||
TF_Status* status)
|
TF_Status* status)
|
||||||
EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
|
TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
|
||||||
Node* ret = nullptr;
|
Node* ret = nullptr;
|
||||||
|
|
||||||
if (desc->graph->name_map.count(desc->node_builder.node_name())) {
|
if (desc->graph->name_map.count(desc->node_builder.node_name())) {
|
||||||
@ -1706,7 +1706,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
|
|||||||
const TF_ImportGraphDefOptions* opts,
|
const TF_ImportGraphDefOptions* opts,
|
||||||
TF_ImportGraphDefResults* tf_results,
|
TF_ImportGraphDefResults* tf_results,
|
||||||
TF_Status* status)
|
TF_Status* status)
|
||||||
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
|
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
|
||||||
const int last_node_id = graph->graph.num_node_ids();
|
const int last_node_id = graph->graph.num_node_ids();
|
||||||
tensorflow::ImportGraphDefResults results;
|
tensorflow::ImportGraphDefResults results;
|
||||||
status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
|
status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
|
||||||
|
@ -24,7 +24,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/jit/flags.h"
|
#include "tensorflow/compiler/jit/flags.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/execute.h"
|
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/shape_inference.h"
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
@ -32,6 +31,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/node_builder.h"
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
#include "tensorflow/core/platform/init_main.h"
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
#include "tensorflow/core/platform/net.h"
|
#include "tensorflow/core/platform/net.h"
|
||||||
#include "tensorflow/core/platform/platform.h"
|
#include "tensorflow/core/platform/platform.h"
|
||||||
@ -520,72 +520,6 @@ TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
|
|||||||
return createTFEDequeue(ctx, TF_VARIANT, queue, status);
|
return createTFEDequeue(ctx, TF_VARIANT, queue, status);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
|
|
||||||
auto* status = TF_NewStatus();
|
|
||||||
TF_Tensor* t = TFE_TensorHandleResolve(handle, status);
|
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
|
|
||||||
tensorflow::Tensor dst;
|
|
||||||
TF_CHECK_OK(TF_TensorToTensor(t, &dst));
|
|
||||||
LOG(INFO) << dst.DebugString();
|
|
||||||
|
|
||||||
TF_DeleteTensor(t);
|
|
||||||
TF_DeleteStatus(status);
|
|
||||||
}
|
|
||||||
|
|
||||||
void TFE_OpPrintDebugString(TFE_Op* op) {
|
|
||||||
VLOG(1) << "TFE_OpPrintDebugString() over " << op;
|
|
||||||
LOG(INFO) << op->operation.DebugString();
|
|
||||||
}
|
|
||||||
|
|
||||||
struct TFE_ExecuteOpNotification {
|
|
||||||
TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {}
|
|
||||||
tensorflow::Notification n;
|
|
||||||
std::unique_ptr<tensorflow::Thread> thread;
|
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status;
|
|
||||||
};
|
|
||||||
|
|
||||||
TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
|
|
||||||
TFE_TensorHandle** retvals,
|
|
||||||
int* num_retvals,
|
|
||||||
TF_Status* status) {
|
|
||||||
TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;
|
|
||||||
|
|
||||||
n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread(
|
|
||||||
tensorflow::ThreadOptions(), "ExecuteOpThread",
|
|
||||||
[op, retvals, num_retvals, n]() {
|
|
||||||
TFE_Execute(op, retvals, num_retvals, n->status.get());
|
|
||||||
n->n.Notify();
|
|
||||||
}));
|
|
||||||
|
|
||||||
return n;
|
|
||||||
}
|
|
||||||
|
|
||||||
void TFE_ExecuteOpNotificationWaitAndDelete(
|
|
||||||
TFE_ExecuteOpNotification* notification, TF_Status* status) {
|
|
||||||
if (notification == nullptr) {
|
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
|
||||||
"Passed in notification is a nullptr.");
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (notification->thread == nullptr) {
|
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
|
||||||
"Passed in notification didn't start a thread correctly. Cleaning up "
|
|
||||||
"this notification. Please re-execute the operation to get a new "
|
|
||||||
"notification.");
|
|
||||||
|
|
||||||
delete notification;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
notification->n.WaitForNotification();
|
|
||||||
|
|
||||||
status->status = notification->status->status;
|
|
||||||
|
|
||||||
delete notification;
|
|
||||||
}
|
|
||||||
|
|
||||||
void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
|
void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
|
||||||
status->status = tensorflow::errors::Internal(errMsg);
|
status->status = tensorflow::errors::Internal(errMsg);
|
||||||
}
|
}
|
||||||
@ -810,113 +744,6 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
|
|||||||
status->status = EnableCollectiveOps(server_def, ctx);
|
status->status = EnableCollectiveOps(server_def, ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MakeTPUInitializationFunctionDef(
|
|
||||||
const tensorflow::string& tpu_system_device_name,
|
|
||||||
tensorflow::FunctionDef* function_def) {
|
|
||||||
tensorflow::OpDef* signature_def(function_def->mutable_signature());
|
|
||||||
signature_def->set_name("_eager_context_tpu_initialization");
|
|
||||||
signature_def->set_is_stateful(true);
|
|
||||||
signature_def->add_control_output("ConfigureDistributedTPU");
|
|
||||||
tensorflow::OpDef_ArgDef* arg_def(signature_def->add_output_arg());
|
|
||||||
arg_def->set_name("topology_proto");
|
|
||||||
arg_def->set_type(tensorflow::DataType::DT_STRING);
|
|
||||||
tensorflow::NodeDef* configure_node_def(function_def->add_node_def());
|
|
||||||
configure_node_def->set_name("ConfigureDistributedTPU");
|
|
||||||
configure_node_def->set_op("ConfigureDistributedTPU");
|
|
||||||
(*configure_node_def->mutable_attr())["compilation_failure_closes_chips"]
|
|
||||||
.set_b(false);
|
|
||||||
configure_node_def->set_device(tpu_system_device_name);
|
|
||||||
tensorflow::NodeDef* identity_node_def(function_def->add_node_def());
|
|
||||||
identity_node_def->set_name("Identity");
|
|
||||||
identity_node_def->set_op("Identity");
|
|
||||||
identity_node_def->add_input("ConfigureDistributedTPU:topology:0");
|
|
||||||
(*identity_node_def->mutable_attr())["T"].set_type(
|
|
||||||
tensorflow::DataType::DT_STRING);
|
|
||||||
(*function_def->mutable_ret())["topology_proto"] = "Identity:output:0";
|
|
||||||
(*function_def->mutable_control_ret())["ConfigureDistributedTPU"] =
|
|
||||||
"ConfigureDistributedTPU";
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOTE(iga): ConfigureDistributedTPU is dummy op whose sole purpose is to
|
|
||||||
// trigger DistributedTPURewritePass. This pass actually adds real ops that
|
|
||||||
// initialize the TPU system. Thus, we can't simply run ConfigureDistributedTPU
|
|
||||||
// eagerly. We need to wrap it in a function and trigger the rewrite passes on
|
|
||||||
// it. The easiest way to trigger a rewrite is to run it in a function.
|
|
||||||
|
|
||||||
// Running initialization as an operation rather than calling the underlying C++
|
|
||||||
// implementation directly allows us to run initialization on a remote device
|
|
||||||
// without a separate communication channel.
|
|
||||||
TF_CAPI_EXPORT extern void TFE_InitializeTPUSystem(TFE_Context* ctx,
|
|
||||||
const char* job,
|
|
||||||
TF_Buffer* tpu_topology,
|
|
||||||
TF_Status* status) {
|
|
||||||
if (tpu_topology->data != nullptr) {
|
|
||||||
status->status = InvalidArgument("Passing non-empty TF_Buffer is invalid.");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
tensorflow::string tpu_system_device_name = tensorflow::strings::StrCat(
|
|
||||||
"/job:", job, "/replica:0/task:0/device:TPU_SYSTEM:0");
|
|
||||||
tensorflow::Device* tpu_system_device = nullptr;
|
|
||||||
tensorflow::Status lookup_status = ctx->context->FindDeviceFromName(
|
|
||||||
tpu_system_device_name.c_str(), &tpu_system_device);
|
|
||||||
if (!lookup_status.ok() || tpu_system_device == nullptr) {
|
|
||||||
// There are no TPUs to initialize.
|
|
||||||
status->status = tensorflow::errors::NotFound(tensorflow::strings::StrCat(
|
|
||||||
"No TPUs are associated with the specified job '", job, "'"));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
tensorflow::FunctionDef function_def;
|
|
||||||
MakeTPUInitializationFunctionDef(tpu_system_device->name().c_str(),
|
|
||||||
&function_def);
|
|
||||||
tensorflow::string function_name = function_def.signature().name();
|
|
||||||
status->status = ctx->context->AddFunctionDef(function_def);
|
|
||||||
if (!status->status.ok()) return;
|
|
||||||
// Run the function, which may be a remote call. It returns a serialized
|
|
||||||
// topology proto.
|
|
||||||
const tensorflow::AttrTypeMap* attr_map;
|
|
||||||
bool is_function;
|
|
||||||
status->status = tensorflow::AttrTypeMapForOp(function_name.c_str(),
|
|
||||||
&attr_map, &is_function);
|
|
||||||
if (!status->status.ok()) return;
|
|
||||||
tensorflow::EagerOperation call_op(ctx->context, function_name.c_str(),
|
|
||||||
is_function, attr_map);
|
|
||||||
status->status = call_op.SetDeviceName(tpu_system_device_name.c_str());
|
|
||||||
if (!status->status.ok()) return;
|
|
||||||
tensorflow::TensorHandle* remote_topology_handle;
|
|
||||||
int num_retvals = 1;
|
|
||||||
status->status =
|
|
||||||
tensorflow::EagerExecute(&call_op, &remote_topology_handle, &num_retvals);
|
|
||||||
if (!status->status.ok()) return;
|
|
||||||
tensorflow::TensorHandle* local_topology_handle = nullptr;
|
|
||||||
status->status = tensorflow::EagerCopyToDevice(
|
|
||||||
remote_topology_handle, ctx->context, &ctx->context->Executor(),
|
|
||||||
ctx->context->HostCPU(), false, &local_topology_handle);
|
|
||||||
remote_topology_handle->Unref();
|
|
||||||
if (!status->status.ok()) return;
|
|
||||||
const tensorflow::Tensor* topology_proto_tensor;
|
|
||||||
status->status = local_topology_handle->Tensor(&topology_proto_tensor);
|
|
||||||
if (!status->status.ok()) return;
|
|
||||||
status->status = ctx->context->RemoveFunction(function_name);
|
|
||||||
if (!status->status.ok()) return;
|
|
||||||
// The function ran, so we put the result in the return buffer.
|
|
||||||
tensorflow::string result =
|
|
||||||
topology_proto_tensor->flat<tensorflow::tstring>()(0);
|
|
||||||
local_topology_handle->Unref();
|
|
||||||
void* topology_data = tensorflow::port::Malloc(result.size());
|
|
||||||
tpu_topology->data = topology_data;
|
|
||||||
if (tpu_topology->data == nullptr) {
|
|
||||||
status->status = tensorflow::errors::ResourceExhausted(
|
|
||||||
"Failed to allocate memory for topology proto (", result.size(),
|
|
||||||
" bytes)");
|
|
||||||
}
|
|
||||||
memcpy(topology_data, result.c_str(), result.size());
|
|
||||||
tpu_topology->length = result.size();
|
|
||||||
tpu_topology->data_deallocator = [](void* data, size_t length) {
|
|
||||||
tensorflow::port::Free(data);
|
|
||||||
};
|
|
||||||
status->status = tensorflow::Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
|
TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
|
||||||
TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
|
TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
|
||||||
result->num_items = num_items;
|
result->num_items = num_items;
|
||||||
@ -990,12 +817,15 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
|||||||
|
|
||||||
const int num_inputs = input_shapes->num_items;
|
const int num_inputs = input_shapes->num_items;
|
||||||
NodeDef node_def;
|
NodeDef node_def;
|
||||||
node_def.set_name(tfe_op->operation.Name());
|
node_def.set_name(tfe_op->operation->Name());
|
||||||
node_def.set_op(tfe_op->operation.Name());
|
node_def.set_op(tfe_op->operation->Name());
|
||||||
for (int i = 0; i < num_inputs; ++i) {
|
for (int i = 0; i < num_inputs; ++i) {
|
||||||
node_def.add_input("dummy_input");
|
node_def.add_input("dummy_input");
|
||||||
}
|
}
|
||||||
tfe_op->operation.Attrs().FillAttrValueMap(node_def.mutable_attr());
|
tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||||
|
tfe_op->operation.get())
|
||||||
|
->Attrs()
|
||||||
|
.FillAttrValueMap(node_def.mutable_attr());
|
||||||
|
|
||||||
const tensorflow::OpRegistrationData* op_reg_data;
|
const tensorflow::OpRegistrationData* op_reg_data;
|
||||||
status->status =
|
status->status =
|
||||||
|
@ -188,31 +188,6 @@ TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session,
|
|||||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
|
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
|
||||||
TF_Session* session, int tensor_id, TF_Status* status);
|
TF_Session* session, int tensor_id, TF_Status* status);
|
||||||
|
|
||||||
// Prints `handle` in a human readable format to standard output for debugging.
|
|
||||||
TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString(
|
|
||||||
TFE_TensorHandle* handle);
|
|
||||||
|
|
||||||
TF_CAPI_EXPORT extern void TFE_OpPrintDebugString(TFE_Op* op);
|
|
||||||
|
|
||||||
typedef struct TFE_ExecuteOpNotification TFE_ExecuteOpNotification;
|
|
||||||
|
|
||||||
// Allows invoking a kernel asynchronously, and explicitly returns a
|
|
||||||
// notification that can be waited upon. This always executes the kernel in a
|
|
||||||
// new thread.
|
|
||||||
// 1. `retvals` and `num_retvals` can only be consumed after
|
|
||||||
// `TFE_ExecuteOp` returns successfully. They shouldn't be used
|
|
||||||
// if the return is unsuccessful
|
|
||||||
// 2. These new APIs cannot be used together with the TFE context level async
|
|
||||||
// support.
|
|
||||||
TF_CAPI_EXPORT extern TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(
|
|
||||||
TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
|
|
||||||
TF_Status* status);
|
|
||||||
|
|
||||||
// Waits to complete the op execution, and cleans up the notification.
|
|
||||||
// Errors reported by op execution are set in `status`.
|
|
||||||
TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete(
|
|
||||||
TFE_ExecuteOpNotification* notification, TF_Status* status);
|
|
||||||
|
|
||||||
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
|
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
|
||||||
const char* errMsg);
|
const char* errMsg);
|
||||||
|
|
||||||
@ -297,20 +272,6 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
|
|||||||
size_t proto_len,
|
size_t proto_len,
|
||||||
TF_Status* status);
|
TF_Status* status);
|
||||||
|
|
||||||
// Runs operations necessary to initialize TPU devices associated with `job`
|
|
||||||
// (e.g. "localhost" for local TPUs), returning a serialized TopologyProto (same
|
|
||||||
// result as the "ConfigureDistributedTPU" operation) if TPUs were
|
|
||||||
// available. Sets a NotFound status if no TPUs were found associated with
|
|
||||||
// the job specified.
|
|
||||||
//
|
|
||||||
// TFE_InitializeTPUSystem should only be run once for a given TPU system;
|
|
||||||
// running it multiple times will invalidate tensors/variables placed on the
|
|
||||||
// affected TPUs.
|
|
||||||
TF_CAPI_EXPORT extern void TFE_InitializeTPUSystem(TFE_Context* ctx,
|
|
||||||
const char* job,
|
|
||||||
TF_Buffer* tpu_topology,
|
|
||||||
TF_Status* status);
|
|
||||||
|
|
||||||
// Information about the shape of a Tensor and its type.
|
// Information about the shape of a Tensor and its type.
|
||||||
struct TF_ShapeAndType {
|
struct TF_ShapeAndType {
|
||||||
// Number of dimensions. -1 indicates unknown rank.
|
// Number of dimensions. -1 indicates unknown rank.
|
||||||
|
@ -73,21 +73,6 @@ protocol: "grpc"
|
|||||||
TF_DeleteStatus(status);
|
TF_DeleteStatus(status);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CAPI_EXPERIMENTAL, InitializeTPUSystemTest) {
|
|
||||||
TF_Status* status = TF_NewStatus();
|
|
||||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
|
||||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
TFE_DeleteContextOptions(opts);
|
|
||||||
TF_Buffer* buf = TF_NewBuffer();
|
|
||||||
TFE_InitializeTPUSystem(ctx, "localhost", buf, status);
|
|
||||||
// Note that this assumes TPUs are not available for this test.
|
|
||||||
CHECK_EQ(TF_NOT_FOUND, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
TF_DeleteBuffer(buf);
|
|
||||||
TF_DeleteStatus(status);
|
|
||||||
TFE_DeleteContext(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(CAPI_EXPERIMENTAL, IsStateful) {
|
TEST(CAPI_EXPERIMENTAL, IsStateful) {
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
TF_NewStatus(), TF_DeleteStatus);
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
@ -99,127 +84,6 @@ TEST(CAPI_EXPERIMENTAL, IsStateful) {
|
|||||||
EXPECT_EQ(id, 0);
|
EXPECT_EQ(id, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Simple) {
|
|
||||||
TF_Status* status = TF_NewStatus();
|
|
||||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
|
||||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
TFE_DeleteContextOptions(opts);
|
|
||||||
|
|
||||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
|
||||||
|
|
||||||
TFE_Op* matmul_op = MatMulOp(ctx, m, m);
|
|
||||||
|
|
||||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
|
||||||
int num_retvals = 1;
|
|
||||||
|
|
||||||
auto* r =
|
|
||||||
TFE_ExecuteOpInNewThread(matmul_op, &retvals[0], &num_retvals, status);
|
|
||||||
|
|
||||||
TFE_ExecuteOpNotificationWaitAndDelete(r, status);
|
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
|
|
||||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
float product[4] = {0};
|
|
||||||
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
|
|
||||||
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
|
|
||||||
TF_DeleteTensor(t);
|
|
||||||
EXPECT_EQ(7, product[0]);
|
|
||||||
EXPECT_EQ(10, product[1]);
|
|
||||||
EXPECT_EQ(15, product[2]);
|
|
||||||
EXPECT_EQ(22, product[3]);
|
|
||||||
|
|
||||||
TFE_DeleteOp(matmul_op);
|
|
||||||
TFE_DeleteTensorHandle(m);
|
|
||||||
|
|
||||||
TFE_DeleteTensorHandle(retvals[0]);
|
|
||||||
TFE_DeleteContext(ctx);
|
|
||||||
TF_DeleteStatus(status);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Perform a send/recv test. Recv blocks, so they need to be executed
|
|
||||||
// asynchronously.
|
|
||||||
TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Blocking) {
|
|
||||||
TF_Status* status = TF_NewStatus();
|
|
||||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
TFE_DeleteContextOptions(opts);
|
|
||||||
|
|
||||||
// Returns a 2x2 float32 Tensor on the CPU, with data 1., 2., 3., 4.
|
|
||||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
|
||||||
|
|
||||||
// Build a send op.
|
|
||||||
TFE_Op* send_op = TFE_NewOp(ctx, "_Send", status);
|
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
TFE_OpAddInput(send_op, m, status);
|
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
|
|
||||||
string tensor_name = "Tensor";
|
|
||||||
TFE_OpSetAttrType(send_op, "T", TF_FLOAT);
|
|
||||||
TFE_OpSetAttrString(send_op, "tensor_name", tensor_name.c_str(),
|
|
||||||
tensor_name.size());
|
|
||||||
string send_device = "/job:localhost/replica:0/task:0/device:CPU:0";
|
|
||||||
TFE_OpSetAttrString(send_op, "send_device", send_device.c_str(),
|
|
||||||
send_device.size());
|
|
||||||
TFE_OpSetAttrInt(send_op, "send_device_incarnation", 1234);
|
|
||||||
string recv_device = "/job:localhost/replica:0/task:0/device:CPU:0";
|
|
||||||
TFE_OpSetAttrString(send_op, "recv_device", recv_device.c_str(),
|
|
||||||
recv_device.size());
|
|
||||||
TFE_OpSetAttrBool(send_op, "client_terminated", true);
|
|
||||||
|
|
||||||
// Build a recv op.
|
|
||||||
TFE_Op* recv_op = TFE_NewOp(ctx, "_Recv", status);
|
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
|
|
||||||
TFE_OpSetAttrType(recv_op, "tensor_type", TF_FLOAT);
|
|
||||||
TFE_OpSetAttrString(recv_op, "tensor_name", tensor_name.c_str(),
|
|
||||||
tensor_name.size());
|
|
||||||
TFE_OpSetAttrString(recv_op, "send_device", send_device.c_str(),
|
|
||||||
send_device.size());
|
|
||||||
TFE_OpSetAttrInt(recv_op, "send_device_incarnation", 1234);
|
|
||||||
TFE_OpSetAttrString(recv_op, "recv_device", recv_device.c_str(),
|
|
||||||
recv_device.size());
|
|
||||||
TFE_OpSetAttrBool(recv_op, "client_terminated", true);
|
|
||||||
|
|
||||||
TFE_TensorHandle* send_retvals;
|
|
||||||
int send_num_retvals = 0;
|
|
||||||
auto* send_result = TFE_ExecuteOpInNewThread(send_op, &send_retvals,
|
|
||||||
&send_num_retvals, status);
|
|
||||||
|
|
||||||
TFE_TensorHandle* recv_retvals[1] = {nullptr};
|
|
||||||
int recv_num_retvals = 1;
|
|
||||||
auto* recv_result = TFE_ExecuteOpInNewThread(recv_op, &recv_retvals[0],
|
|
||||||
&recv_num_retvals, status);
|
|
||||||
|
|
||||||
TFE_ExecuteOpNotificationWaitAndDelete(send_result, status);
|
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
TFE_ExecuteOpNotificationWaitAndDelete(recv_result, status);
|
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
|
|
||||||
TF_Tensor* t = TFE_TensorHandleResolve(recv_retvals[0], status);
|
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
|
|
||||||
float product[4] = {0};
|
|
||||||
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
|
|
||||||
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
|
|
||||||
TF_DeleteTensor(t);
|
|
||||||
EXPECT_EQ(1, product[0]);
|
|
||||||
EXPECT_EQ(2, product[1]);
|
|
||||||
EXPECT_EQ(3, product[2]);
|
|
||||||
EXPECT_EQ(4, product[3]);
|
|
||||||
|
|
||||||
TFE_DeleteOp(send_op);
|
|
||||||
TFE_DeleteOp(recv_op);
|
|
||||||
TFE_DeleteTensorHandle(m);
|
|
||||||
|
|
||||||
TFE_DeleteTensorHandle(recv_retvals[0]);
|
|
||||||
TFE_DeleteContext(ctx);
|
|
||||||
TF_DeleteStatus(status);
|
|
||||||
}
|
|
||||||
|
|
||||||
class ShapeInferenceTest : public ::testing::Test {
|
class ShapeInferenceTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
ShapeInferenceTest()
|
ShapeInferenceTest()
|
||||||
|
@ -51,7 +51,7 @@ Status ProcessInputs(
|
|||||||
const TF_Graph* fn_body, const char* fn_name, int ninputs,
|
const TF_Graph* fn_body, const char* fn_name, int ninputs,
|
||||||
const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
|
const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
|
||||||
std::unordered_map<const Node*, std::vector<int>>* input_nodes)
|
std::unordered_map<const Node*, std::vector<int>>* input_nodes)
|
||||||
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||||
input_tensors->reserve(ninputs);
|
input_tensors->reserve(ninputs);
|
||||||
for (int i = 0; i < ninputs; ++i) {
|
for (int i = 0; i < ninputs; ++i) {
|
||||||
Node* node = &inputs[i].oper->node;
|
Node* node = &inputs[i].oper->node;
|
||||||
@ -87,7 +87,7 @@ Status ProcessInputs(
|
|||||||
Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
|
Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
|
||||||
int noutputs, const TF_Output* outputs,
|
int noutputs, const TF_Output* outputs,
|
||||||
std::vector<OutputTensor>* output_tensors)
|
std::vector<OutputTensor>* output_tensors)
|
||||||
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||||
output_tensors->reserve(noutputs);
|
output_tensors->reserve(noutputs);
|
||||||
for (int i = 0; i < noutputs; ++i) {
|
for (int i = 0; i < noutputs; ++i) {
|
||||||
Node* node = &outputs[i].oper->node;
|
Node* node = &outputs[i].oper->node;
|
||||||
@ -111,7 +111,7 @@ Status ComputeBodyNodes(
|
|||||||
const TF_Operation* const* opers,
|
const TF_Operation* const* opers,
|
||||||
const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
|
const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
|
||||||
std::vector<const Node*>* body_nodes)
|
std::vector<const Node*>* body_nodes)
|
||||||
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||||
if (num_opers == -1) {
|
if (num_opers == -1) {
|
||||||
for (const Node* node : fn_body->graph.op_nodes()) {
|
for (const Node* node : fn_body->graph.op_nodes()) {
|
||||||
const auto& iter = input_nodes.find(node);
|
const auto& iter = input_nodes.find(node);
|
||||||
|
@ -71,14 +71,14 @@ struct TF_Graph {
|
|||||||
TF_Graph();
|
TF_Graph();
|
||||||
|
|
||||||
tensorflow::mutex mu;
|
tensorflow::mutex mu;
|
||||||
tensorflow::Graph graph GUARDED_BY(mu);
|
tensorflow::Graph graph TF_GUARDED_BY(mu);
|
||||||
|
|
||||||
// Runs shape inference.
|
// Runs shape inference.
|
||||||
tensorflow::ShapeRefiner refiner GUARDED_BY(mu);
|
tensorflow::ShapeRefiner refiner TF_GUARDED_BY(mu);
|
||||||
|
|
||||||
// Maps from name of an operation to the Node* in 'graph'.
|
// Maps from name of an operation to the Node* in 'graph'.
|
||||||
std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
|
std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
|
||||||
GUARDED_BY(mu);
|
TF_GUARDED_BY(mu);
|
||||||
|
|
||||||
// The keys of this map are all the active sessions using this graph. Each
|
// The keys of this map are all the active sessions using this graph. Each
|
||||||
// value records whether the graph has been mutated since the corresponding
|
// value records whether the graph has been mutated since the corresponding
|
||||||
@ -94,8 +94,8 @@ struct TF_Graph {
|
|||||||
// TODO(b/74949947): mutations currently trigger a warning instead of a bad
|
// TODO(b/74949947): mutations currently trigger a warning instead of a bad
|
||||||
// status, this should be reverted when possible.
|
// status, this should be reverted when possible.
|
||||||
tensorflow::gtl::FlatMap<TF_Session*, tensorflow::string> sessions
|
tensorflow::gtl::FlatMap<TF_Session*, tensorflow::string> sessions
|
||||||
GUARDED_BY(mu);
|
TF_GUARDED_BY(mu);
|
||||||
bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph
|
bool delete_requested TF_GUARDED_BY(mu); // set true by TF_DeleteGraph
|
||||||
|
|
||||||
// Used to link graphs contained in TF_WhileParams to the parent graph that
|
// Used to link graphs contained in TF_WhileParams to the parent graph that
|
||||||
// will eventually contain the full while loop.
|
// will eventually contain the full while loop.
|
||||||
@ -123,7 +123,7 @@ struct TF_Session {
|
|||||||
tensorflow::Session* session;
|
tensorflow::Session* session;
|
||||||
TF_Graph* const graph;
|
TF_Graph* const graph;
|
||||||
|
|
||||||
tensorflow::mutex mu ACQUIRED_AFTER(TF_Graph::mu);
|
tensorflow::mutex mu TF_ACQUIRED_AFTER(TF_Graph::mu);
|
||||||
int last_num_graph_nodes;
|
int last_num_graph_nodes;
|
||||||
|
|
||||||
// If true, TF_SessionRun and similar methods will call
|
// If true, TF_SessionRun and similar methods will call
|
||||||
@ -169,9 +169,9 @@ struct TF_ApiDefMap {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||||
tensorflow::ApiDefMap api_def_map GUARDED_BY(lock);
|
tensorflow::ApiDefMap api_def_map TF_GUARDED_BY(lock);
|
||||||
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||||
bool update_docs_called GUARDED_BY(lock);
|
bool update_docs_called TF_GUARDED_BY(lock);
|
||||||
tensorflow::mutex lock;
|
tensorflow::mutex lock;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -210,10 +210,10 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
|
|||||||
|
|
||||||
void RecordMutation(TF_Graph* graph, const TF_Operation& op,
|
void RecordMutation(TF_Graph* graph, const TF_Operation& op,
|
||||||
const char* mutation_type)
|
const char* mutation_type)
|
||||||
EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
|
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
|
||||||
|
|
||||||
bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status)
|
bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status)
|
||||||
LOCKS_EXCLUDED(session->graph->mu, session->mu);
|
TF_LOCKS_EXCLUDED(session->graph->mu, session->mu);
|
||||||
|
|
||||||
std::string getTF_OutputDebugString(TF_Output node);
|
std::string getTF_OutputDebugString(TF_Output node);
|
||||||
|
|
||||||
|
@ -45,6 +45,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
#include "tensorflow/core/platform/path.h"
|
||||||
|
#include "tensorflow/core/platform/resource_loader.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||||
@ -193,8 +195,9 @@ TEST(CAPI, LibraryLoadFunctions) {
|
|||||||
{
|
{
|
||||||
// Load the library.
|
// Load the library.
|
||||||
TF_Status* status = TF_NewStatus();
|
TF_Status* status = TF_NewStatus();
|
||||||
TF_Library* lib =
|
string lib_path = tensorflow::GetDataDependencyFilepath(
|
||||||
TF_LoadLibrary("tensorflow/c/test_op1.so", status);
|
tensorflow::io::JoinPath("tensorflow", "c", "test_op1.so"));
|
||||||
|
TF_Library* lib = TF_LoadLibrary(lib_path.c_str(), status);
|
||||||
TF_Code code = TF_GetCode(status);
|
TF_Code code = TF_GetCode(status);
|
||||||
string status_msg(TF_Message(status));
|
string status_msg(TF_Message(status));
|
||||||
TF_DeleteStatus(status);
|
TF_DeleteStatus(status);
|
||||||
@ -1350,9 +1353,9 @@ TEST_F(CApiColocationTest, ClearViaProto) {
|
|||||||
|
|
||||||
TEST(CAPI, SavedModel) {
|
TEST(CAPI, SavedModel) {
|
||||||
// Load the saved model.
|
// Load the saved model.
|
||||||
const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
|
const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
|
||||||
const string saved_model_dir = tensorflow::io::JoinPath(
|
tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
|
||||||
tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
|
"half_plus_two", "00000123"));
|
||||||
TF_SessionOptions* opt = TF_NewSessionOptions();
|
TF_SessionOptions* opt = TF_NewSessionOptions();
|
||||||
TF_Buffer* run_options = TF_NewBufferFromString("", 0);
|
TF_Buffer* run_options = TF_NewBufferFromString("", 0);
|
||||||
TF_Buffer* metagraph = TF_NewBuffer();
|
TF_Buffer* metagraph = TF_NewBuffer();
|
||||||
@ -1426,9 +1429,9 @@ TEST(CAPI, SavedModel) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(CAPI, SavedModelNullArgsAreValid) {
|
TEST(CAPI, SavedModelNullArgsAreValid) {
|
||||||
const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
|
const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
|
||||||
const string saved_model_dir = tensorflow::io::JoinPath(
|
tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
|
||||||
tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
|
"half_plus_two", "00000123"));
|
||||||
TF_SessionOptions* opt = TF_NewSessionOptions();
|
TF_SessionOptions* opt = TF_NewSessionOptions();
|
||||||
TF_Status* s = TF_NewStatus();
|
TF_Status* s = TF_NewStatus();
|
||||||
const char* tags[] = {tensorflow::kSavedModelTagServe};
|
const char* tags[] = {tensorflow::kSavedModelTagServe};
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
load(
|
load(
|
||||||
"//tensorflow:tensorflow.bzl",
|
"//tensorflow:tensorflow.bzl",
|
||||||
|
"tf_cc_test",
|
||||||
"tf_copts",
|
"tf_copts",
|
||||||
"tf_cuda_cc_test",
|
"tf_cuda_cc_test",
|
||||||
"tf_cuda_library",
|
"tf_cuda_library",
|
||||||
@ -26,8 +27,9 @@ tf_cuda_library(
|
|||||||
"c_api.cc",
|
"c_api.cc",
|
||||||
"c_api_debug.cc",
|
"c_api_debug.cc",
|
||||||
"c_api_experimental.h",
|
"c_api_experimental.h",
|
||||||
"c_api_internal.cc",
|
|
||||||
"c_api_internal.h",
|
"c_api_internal.h",
|
||||||
|
"operation_interface.cc",
|
||||||
|
"operation_interface.h",
|
||||||
"tensor_handle_interface.h",
|
"tensor_handle_interface.h",
|
||||||
],
|
],
|
||||||
hdrs = ["c_api.h"],
|
hdrs = ["c_api.h"],
|
||||||
@ -56,6 +58,7 @@ tf_cuda_library(
|
|||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core/platform:casts",
|
||||||
"//tensorflow/core/platform:errors",
|
"//tensorflow/core/platform:errors",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/profiler/lib:traceme",
|
"//tensorflow/core/profiler/lib:traceme",
|
||||||
@ -82,18 +85,18 @@ tf_cuda_library(
|
|||||||
"//tensorflow/core/distributed_runtime:remote_device",
|
"//tensorflow/core/distributed_runtime:remote_device",
|
||||||
"//tensorflow/core/distributed_runtime:server_lib",
|
"//tensorflow/core/distributed_runtime:server_lib",
|
||||||
"//tensorflow/core/distributed_runtime:worker_env",
|
"//tensorflow/core/distributed_runtime:worker_env",
|
||||||
"//tensorflow/core/profiler/lib:profiler_lib",
|
|
||||||
"//tensorflow/core/profiler/lib:profiler_session",
|
|
||||||
"//tensorflow/core:gpu_runtime",
|
"//tensorflow/core:gpu_runtime",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "pywrap_eager_hdrs",
|
name = "pywrap_required_hdrs",
|
||||||
srcs = [
|
srcs = [
|
||||||
"c_api_experimental.h",
|
"c_api_experimental.h",
|
||||||
"c_api_internal.h",
|
"c_api_internal.h",
|
||||||
|
"dlpack.h",
|
||||||
|
"operation_interface.h",
|
||||||
"tensor_handle_interface.h",
|
"tensor_handle_interface.h",
|
||||||
],
|
],
|
||||||
visibility = [
|
visibility = [
|
||||||
@ -106,6 +109,7 @@ tf_cuda_library(
|
|||||||
name = "c_api_internal",
|
name = "c_api_internal",
|
||||||
srcs = [
|
srcs = [
|
||||||
"c_api_experimental.h",
|
"c_api_experimental.h",
|
||||||
|
"operation_interface.h",
|
||||||
"tensor_handle_interface.h",
|
"tensor_handle_interface.h",
|
||||||
],
|
],
|
||||||
hdrs = ["c_api_internal.h"],
|
hdrs = ["c_api_internal.h"],
|
||||||
@ -128,22 +132,9 @@ tf_cuda_library(
|
|||||||
"//tensorflow/core/common_runtime/eager:context",
|
"//tensorflow/core/common_runtime/eager:context",
|
||||||
"//tensorflow/core/common_runtime/eager:eager_executor",
|
"//tensorflow/core/common_runtime/eager:eager_executor",
|
||||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||||
"//tensorflow/core/common_runtime/eager:execute",
|
|
||||||
"//tensorflow/core/common_runtime/eager:kernel_and_device",
|
"//tensorflow/core/common_runtime/eager:kernel_and_device",
|
||||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||||
"//tensorflow/core/distributed_runtime:remote_device",
|
"@com_google_absl//absl/container:fixed_array",
|
||||||
"//tensorflow/core/distributed_runtime:server_lib",
|
|
||||||
"//tensorflow/core/distributed_runtime:worker_env",
|
|
||||||
"//tensorflow/core/distributed_runtime/eager:eager_client",
|
|
||||||
"//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
|
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
|
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
|
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
|
|
||||||
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
|
|
||||||
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
|
|
||||||
"//tensorflow/core/profiler/lib:profiler_lib",
|
|
||||||
"//tensorflow/core/profiler/lib:profiler_session",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -215,6 +206,7 @@ tf_cuda_cc_test(
|
|||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||||
|
"//tensorflow/core/platform:casts",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -272,8 +264,6 @@ tf_cuda_library(
|
|||||||
"//tensorflow/core/distributed_runtime:remote_device",
|
"//tensorflow/core/distributed_runtime:remote_device",
|
||||||
"//tensorflow/core/distributed_runtime:server_lib",
|
"//tensorflow/core/distributed_runtime:server_lib",
|
||||||
"//tensorflow/core/distributed_runtime:worker_env",
|
"//tensorflow/core/distributed_runtime:worker_env",
|
||||||
"//tensorflow/core/profiler/rpc:profiler_server",
|
|
||||||
"//tensorflow/core/profiler/rpc/client:capture_profile",
|
|
||||||
"//tensorflow/core:gpu_runtime",
|
"//tensorflow/core:gpu_runtime",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
@ -303,6 +293,27 @@ tf_cuda_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "custom_device_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = [
|
||||||
|
"custom_device_test.cc",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":c_api",
|
||||||
|
":c_api_experimental",
|
||||||
|
":c_api_test_util",
|
||||||
|
"//tensorflow/c:c_api",
|
||||||
|
"//tensorflow/c:c_test_util",
|
||||||
|
"//tensorflow/cc/profiler",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tape",
|
name = "tape",
|
||||||
hdrs = ["tape.h"],
|
hdrs = ["tape.h"],
|
||||||
@ -315,10 +326,37 @@ cc_library(
|
|||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "headers",
|
name = "headers",
|
||||||
srcs = ["c_api.h"],
|
srcs = [
|
||||||
|
"c_api.h",
|
||||||
|
"c_api_experimental.h",
|
||||||
|
"dlpack.h",
|
||||||
|
],
|
||||||
visibility = ["//tensorflow:__subpackages__"],
|
visibility = ["//tensorflow:__subpackages__"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "dlpack",
|
||||||
|
srcs = ["dlpack.cc"],
|
||||||
|
hdrs = ["dlpack.h"],
|
||||||
|
copts = [
|
||||||
|
"-fexceptions",
|
||||||
|
"-fno-strict-aliasing",
|
||||||
|
],
|
||||||
|
features = ["-use_header_modules"],
|
||||||
|
visibility = ["//tensorflow:__subpackages__"],
|
||||||
|
deps = [
|
||||||
|
":c_api",
|
||||||
|
":c_api_experimental",
|
||||||
|
":c_api_internal",
|
||||||
|
"//tensorflow/c:tf_status_helper",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"@dlpack",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime
|
# TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime
|
||||||
# right now, remove this public rule when no longer needed (it should be
|
# right now, remove this public rule when no longer needed (it should be
|
||||||
# replaced by TF Lite)
|
# replaced by TF Lite)
|
||||||
@ -332,6 +370,7 @@ filegroup(
|
|||||||
exclude = [
|
exclude = [
|
||||||
"c_api_experimental.cc",
|
"c_api_experimental.cc",
|
||||||
"*test*",
|
"*test*",
|
||||||
|
"*dlpack*",
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -213,7 +213,7 @@ TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
|
|||||||
TFE_TensorDebugInfo* debug_info);
|
TFE_TensorDebugInfo* debug_info);
|
||||||
|
|
||||||
// Returns the number of dimensions used to represent the tensor on its device.
|
// Returns the number of dimensions used to represent the tensor on its device.
|
||||||
// The number of dimensions used to reprensent the tensor on device can be
|
// The number of dimensions used to represent the tensor on device can be
|
||||||
// different from the number returned by TFE_TensorHandleNumDims.
|
// different from the number returned by TFE_TensorHandleNumDims.
|
||||||
// The return value was current at the time of TFE_TensorDebugInfo creation.
|
// The return value was current at the time of TFE_TensorDebugInfo creation.
|
||||||
TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(
|
TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(
|
||||||
|
@ -66,7 +66,7 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||||
tensorflow::Device* device = handle_->device();
|
tensorflow::Device* device = absl::get<Device*>(handle_->device());
|
||||||
|
|
||||||
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
|
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
|
||||||
tensorflow::XlaDevice* xla_device =
|
tensorflow::XlaDevice* xla_device =
|
||||||
|
@ -18,62 +18,27 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/platform/casts.h"
|
#include "tensorflow/core/platform/casts.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
|
|
||||||
#include "tensorflow/core/profiler/rpc/profiler_server.h"
|
|
||||||
|
|
||||||
using tensorflow::string;
|
using tensorflow::string;
|
||||||
|
|
||||||
void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
|
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
||||||
const char* raw_device_name, TF_Status* status,
|
const char* raw_device_name, TF_Status* status) {
|
||||||
TFE_Op* op_to_reset) {
|
|
||||||
if (op_to_reset) {
|
if (op_to_reset) {
|
||||||
NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
|
status->status =
|
||||||
op_to_reset);
|
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
|
||||||
} else {
|
} else {
|
||||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||||
"op_to_reset should not be nullptr");
|
"op_to_reset should not be nullptr");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
|
|
||||||
op->operation.ConsumeInput(
|
|
||||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
|
||||||
->Handle());
|
|
||||||
}
|
|
||||||
|
|
||||||
TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); }
|
|
||||||
|
|
||||||
bool TFE_ProfilerIsOk(TFE_Profiler* profiler) {
|
|
||||||
return profiler->profiler->Status().ok();
|
|
||||||
}
|
|
||||||
|
|
||||||
void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; }
|
|
||||||
|
|
||||||
void TFE_ProfilerSerializeToString(TFE_Profiler* profiler, TF_Buffer* buf,
|
|
||||||
TF_Status* status) {
|
|
||||||
string content;
|
|
||||||
status->status = profiler->profiler->SerializeToString(&content);
|
|
||||||
void* data = tensorflow::port::Malloc(content.length());
|
|
||||||
content.copy(static_cast<char*>(data), content.length(), 0);
|
|
||||||
buf->data = data;
|
|
||||||
buf->length = content.length();
|
|
||||||
buf->data_deallocator = [](void* data, size_t length) {
|
|
||||||
tensorflow::port::Free(data);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
void TFE_StartProfilerServer(int port) {
|
|
||||||
// Release child thread intentionally. The child thread can be terminated by
|
|
||||||
// terminating the main thread.
|
|
||||||
tensorflow::StartProfilerServer(port).release();
|
|
||||||
}
|
|
||||||
|
|
||||||
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
||||||
ctx->context->SetShouldStoreGraphs(true);
|
ctx->context->SetShouldStoreGraphs(true);
|
||||||
}
|
}
|
||||||
@ -82,46 +47,6 @@ void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
|
|||||||
ctx->context->SetShouldStoreGraphs(false);
|
ctx->context->SetShouldStoreGraphs(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TFE_ProfilerClientStartTracing(const char* service_addr,
|
|
||||||
const char* logdir, const char* worker_list,
|
|
||||||
bool include_dataset_ops, int duration_ms,
|
|
||||||
int num_tracing_attempts,
|
|
||||||
TF_Status* status) {
|
|
||||||
tensorflow::Status s =
|
|
||||||
tensorflow::profiler::client::ValidateHostPortPair(service_addr);
|
|
||||||
if (!s.ok()) {
|
|
||||||
Set_TF_Status_from_Status(status, s);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
s = tensorflow::profiler::client::StartTracing(
|
|
||||||
service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
|
|
||||||
num_tracing_attempts);
|
|
||||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
|
||||||
return s.ok();
|
|
||||||
}
|
|
||||||
|
|
||||||
void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
|
|
||||||
int monitoring_level, bool display_timestamp,
|
|
||||||
TF_Buffer* result, TF_Status* status) {
|
|
||||||
tensorflow::Status s =
|
|
||||||
tensorflow::profiler::client::ValidateHostPortPair(service_addr);
|
|
||||||
if (!s.ok()) {
|
|
||||||
Set_TF_Status_from_Status(status, s);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
string content;
|
|
||||||
s = tensorflow::profiler::client::Monitor(
|
|
||||||
service_addr, duration_ms, monitoring_level, display_timestamp, &content);
|
|
||||||
void* data = tensorflow::port::Malloc(content.length());
|
|
||||||
content.copy(static_cast<char*>(data), content.length(), 0);
|
|
||||||
result->data = data;
|
|
||||||
result->length = content.length();
|
|
||||||
result->data_deallocator = [](void* data, size_t length) {
|
|
||||||
tensorflow::port::Free(data);
|
|
||||||
};
|
|
||||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
|
||||||
}
|
|
||||||
|
|
||||||
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
|
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
|
||||||
int64_t value) {
|
int64_t value) {
|
||||||
cell->cell.IncrementBy(value);
|
cell->cell.IncrementBy(value);
|
||||||
@ -589,8 +514,7 @@ void TFE_DeleteCancellationManager(
|
|||||||
void TFE_OpSetCancellationManager(TFE_Op* op,
|
void TFE_OpSetCancellationManager(TFE_Op* op,
|
||||||
TFE_CancellationManager* cancellation_manager,
|
TFE_CancellationManager* cancellation_manager,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
op->operation.SetCancellationManager(
|
status->status = op->operation->SetCancellationManager(cancellation_manager);
|
||||||
&cancellation_manager->cancellation_manager);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TFE_Executor* TFE_NewExecutor(bool is_async) {
|
TFE_Executor* TFE_NewExecutor(bool is_async) {
|
||||||
@ -619,3 +543,41 @@ void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
|
|||||||
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
|
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
|
||||||
return new TFE_Executor(&ctx->context->Executor());
|
return new TFE_Executor(&ctx->context->Executor());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
||||||
|
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
|
||||||
|
ctx->context->HostCPU()->parsed_name());
|
||||||
|
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
|
||||||
|
void* data = tensorflow::port::Malloc(str.length());
|
||||||
|
str.copy(static_cast<char*>(data), str.length(), 0);
|
||||||
|
buf->data = data;
|
||||||
|
buf->length = str.length();
|
||||||
|
buf->data_deallocator = [](void* data, size_t length) {
|
||||||
|
tensorflow::port::Free(data);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
|
||||||
|
TF_Status* status) {
|
||||||
|
h->handle->EnableImplicitMirroring();
|
||||||
|
status->status = tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
||||||
|
TF_Buffer* buf, TF_Status* status) {
|
||||||
|
auto* function_def = ctx->context->FindFunctionDef(function_name);
|
||||||
|
if (function_def == nullptr) {
|
||||||
|
status->status = tensorflow::errors::NotFound(
|
||||||
|
"Unable to find FunctionDef with name: ", function_name);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
string str = function_def->SerializeAsString();
|
||||||
|
void* data = tensorflow::port::Malloc(str.length());
|
||||||
|
str.copy(static_cast<char*>(data), str.length(), 0);
|
||||||
|
buf->data = data;
|
||||||
|
buf->length = str.length();
|
||||||
|
buf->data_deallocator = [](void* data, size_t length) {
|
||||||
|
tensorflow::port::Free(data);
|
||||||
|
};
|
||||||
|
status->status = tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
@ -27,41 +27,12 @@ extern "C" {
|
|||||||
// creating a new op every time. If `raw_device_name` is `NULL` or empty, it
|
// creating a new op every time. If `raw_device_name` is `NULL` or empty, it
|
||||||
// does not set the device name. If it's not `NULL`, then it attempts to parse
|
// does not set the device name. If it's not `NULL`, then it attempts to parse
|
||||||
// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
|
// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
|
||||||
// than seperately calling it because if the existing op has the same
|
// than separately calling it because if the existing op has the same
|
||||||
// `raw_device_name`, it skips parsing and just leave as it is.
|
// `raw_device_name`, it skips parsing and just leave as it is.
|
||||||
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Context* ctx,
|
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
|
||||||
const char* op_or_function_name,
|
const char* op_or_function_name,
|
||||||
const char* raw_device_name,
|
const char* raw_device_name,
|
||||||
TF_Status* status, TFE_Op* op_to_reset);
|
TF_Status* status);
|
||||||
|
|
||||||
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
|
|
||||||
TF_Status* status);
|
|
||||||
|
|
||||||
// A profiler which will start profiling when creating the object and will stop
|
|
||||||
// when the object is destroyed. It will profile all operations run under the
|
|
||||||
// given TFE_Context. Multiple instance of it can be created, but at most one
|
|
||||||
// of them will profile for each TFE_Context.
|
|
||||||
// Thread-safety: TFE_Profiler is thread-safe.
|
|
||||||
typedef struct TFE_Profiler TFE_Profiler;
|
|
||||||
|
|
||||||
TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler();
|
|
||||||
TF_CAPI_EXPORT extern bool TFE_ProfilerIsOk(TFE_Profiler* profiler);
|
|
||||||
TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler);
|
|
||||||
|
|
||||||
// The output string is a binary string of tensorflow.tpu.Trace. User can write
|
|
||||||
// the string to file for offline analysis by tensorboard.
|
|
||||||
TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Profiler* profiler,
|
|
||||||
TF_Buffer* buf,
|
|
||||||
TF_Status* status);
|
|
||||||
|
|
||||||
// Start a profiler grpc server which listens to specified port. It will start
|
|
||||||
// the server on its own thread. It can be shutdown by terminating tensorflow.
|
|
||||||
// It can be used in both Eager mode and graph mode. Creating multiple profiler
|
|
||||||
// server is allowed. The service defined in
|
|
||||||
// tensorflow/contrib/tpu/profiler/tpu_profiler.proto. Please use
|
|
||||||
// tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture trace file
|
|
||||||
// following https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
|
|
||||||
TF_CAPI_EXPORT extern void TFE_StartProfilerServer(int port);
|
|
||||||
|
|
||||||
// Enables only graph collection in RunMetadata on the functions executed from
|
// Enables only graph collection in RunMetadata on the functions executed from
|
||||||
// this context.
|
// this context.
|
||||||
@ -71,29 +42,6 @@ TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
|
|||||||
// this context.
|
// this context.
|
||||||
TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx);
|
TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx);
|
||||||
|
|
||||||
// Send a grpc request to profiler server (service_addr) to perform on-demand
|
|
||||||
// profiling and save the result into logdir which can be visualized by
|
|
||||||
// TensorBoard. worker_list is the list of worker TPUs separated by ','. Set
|
|
||||||
// include_dataset_opts to false to profile longer traces. It will block the
|
|
||||||
// caller thread until receives tracing result.
|
|
||||||
// This API is designed for TensorBoard, for end user, please use
|
|
||||||
// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead following
|
|
||||||
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
|
|
||||||
TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
|
|
||||||
const char* service_addr, const char* logdir, const char* worker_list,
|
|
||||||
bool include_dataset_ops, int duration_ms, int num_tracing_attempts,
|
|
||||||
TF_Status* status);
|
|
||||||
|
|
||||||
// Send a grpc request to profiler server (service_addr) to perform on-demand
|
|
||||||
// monitoring and return the result in a string. It will block the
|
|
||||||
// caller thread until receiving the monitoring result.
|
|
||||||
// This API is designed for TensorBoard, for end user, please use
|
|
||||||
// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead following
|
|
||||||
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
|
|
||||||
TF_CAPI_EXPORT extern void TFE_ProfilerClientMonitor(
|
|
||||||
const char* service_addr, int duration_ms, int monitoring_level,
|
|
||||||
bool display_timestamp, TF_Buffer* result, TF_Status* status);
|
|
||||||
|
|
||||||
// TODO(fishx): Move these monitoring APIs into a separate file.
|
// TODO(fishx): Move these monitoring APIs into a separate file.
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
// Monitoring Counter APIs.
|
// Monitoring Counter APIs.
|
||||||
@ -434,6 +382,18 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
|||||||
const char* worker_name,
|
const char* worker_name,
|
||||||
TF_Status* status);
|
TF_Status* status);
|
||||||
|
|
||||||
|
// Sync pending nodes in local executors (including the context default executor
|
||||||
|
// and thread executors) and streaming requests to remote executors, and get the
|
||||||
|
// combined status.
|
||||||
|
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||||
|
TF_Status* status);
|
||||||
|
|
||||||
|
// If the TensorHandle is copied to another device as part of an op execution,
|
||||||
|
// the copy is destroyed after the op has executed. Enabling implicit mirroring
|
||||||
|
// causes the copy to be held as a mirror for the lifetime of the TensorHandle.
|
||||||
|
TF_CAPI_EXPORT extern void TFE_TensorHandleEnableImplicitMirroring(
|
||||||
|
TFE_TensorHandle*, TF_Status*);
|
||||||
|
|
||||||
// This function will block till the operation that produces `h` has
|
// This function will block till the operation that produces `h` has
|
||||||
// completed. This is only valid on local TFE_TensorHandles. The pointer
|
// completed. This is only valid on local TFE_TensorHandles. The pointer
|
||||||
// returned will be on the device in which the TFE_TensorHandle resides (so e.g.
|
// returned will be on the device in which the TFE_TensorHandle resides (so e.g.
|
||||||
@ -458,6 +418,114 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
|||||||
void (*deallocator)(void* data, size_t len, void* arg),
|
void (*deallocator)(void* data, size_t len, void* arg),
|
||||||
void* deallocator_arg, TF_Status* status);
|
void* deallocator_arg, TF_Status* status);
|
||||||
|
|
||||||
|
// Retrieves the address space (i.e. job, replia, task) of the local host and
|
||||||
|
// saves it in the buffer.
|
||||||
|
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
|
||||||
|
TF_Buffer* buf);
|
||||||
|
|
||||||
|
// APIs for generically dealing with op attributes (e.g. when forwarding them
|
||||||
|
// through custom device implementations).
|
||||||
|
//
|
||||||
|
// TODO(allenl): Currently these are black boxes, but we should have some way to
|
||||||
|
// inspect values. This would let people e.g. copy over most attributes and then
|
||||||
|
// modify some based on their values.
|
||||||
|
|
||||||
|
// A reference to an op's name -> attribute mapping
|
||||||
|
typedef struct TFE_OpAttrs TFE_OpAttrs;
|
||||||
|
|
||||||
|
// Fetch a struct with a reference to information about attributes of `op`.
|
||||||
|
//
|
||||||
|
// The `attrs` struct does not own any memory, and `op` must outlive it.
|
||||||
|
TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
|
||||||
|
|
||||||
|
// Add attributes in `attrs` to `op`.
|
||||||
|
//
|
||||||
|
// Does not overwrite or update existing attributes, but adds new ones.
|
||||||
|
TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs);
|
||||||
|
|
||||||
|
// Serialize `attrs` as a tensorflow::NameAttrList protocol buffer (into `buf`),
|
||||||
|
// containing the op name and a map of its attributes.
|
||||||
|
TF_CAPI_EXPORT extern void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs,
|
||||||
|
TF_Buffer* buf,
|
||||||
|
TF_Status* status);
|
||||||
|
|
||||||
|
// Set an op's attribute from a serialized AttrValue protocol buffer.
|
||||||
|
//
|
||||||
|
// Analogous to TF_SetAttrValueProto for building graph operations.
|
||||||
|
TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
|
||||||
|
const char* attr_name,
|
||||||
|
const void* proto,
|
||||||
|
size_t proto_len,
|
||||||
|
TF_Status* status);
|
||||||
|
|
||||||
|
#define TFE_CUSTOM_DEVICE_VERSION 2
|
||||||
|
|
||||||
|
// Struct to be filled in
|
||||||
|
typedef struct TFE_CustomDevice {
|
||||||
|
int version = TFE_CUSTOM_DEVICE_VERSION;
|
||||||
|
// Method to copy a tensor to the custom device.
|
||||||
|
TFE_TensorHandle* (*copy_tensor_to_device)(TFE_Context* context,
|
||||||
|
TFE_TensorHandle* tensor,
|
||||||
|
TF_Status* status,
|
||||||
|
void* device_info) = nullptr;
|
||||||
|
|
||||||
|
// Method to copy a tensor from the custom device to a target device.
|
||||||
|
TFE_TensorHandle* (*copy_tensor_from_device)(TFE_Context* context,
|
||||||
|
TFE_TensorHandle* tensor,
|
||||||
|
const char* target_device_name,
|
||||||
|
TF_Status* status,
|
||||||
|
void* device_info);
|
||||||
|
|
||||||
|
// Method to execute an operation.
|
||||||
|
void (*execute)(TFE_Context* context, int num_inputs,
|
||||||
|
TFE_TensorHandle** inputs, const char* operation_name,
|
||||||
|
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||||
|
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
|
||||||
|
|
||||||
|
// Method to delete a device.
|
||||||
|
void (*delete_device)(void* device_info);
|
||||||
|
} TFE_CustomDevice;
|
||||||
|
|
||||||
|
// Registers a custom device for use with eager execution.
|
||||||
|
//
|
||||||
|
// Eager operations may be placed on this device, e.g. `with
|
||||||
|
// tf.device("CUSTOM"):` from Python if `device_name` for this call is
|
||||||
|
// "/job:localhost/replica:0/task:0/device:CUSTOM:0".
|
||||||
|
//
|
||||||
|
// The custom device defines copy operations for moving TensorHandles on and
|
||||||
|
// off, and an an execution operation for named operations. Often execution will
|
||||||
|
// simply wrap op execution on one or more physical devices.
|
||||||
|
//
|
||||||
|
// device_info is an opaque caller-defined type stored with the custom device
|
||||||
|
// which is passed to the functions referenced in the TFE_CustomDevice struct
|
||||||
|
// `device` (execute, delete_device, etc.). It can for example contain the
|
||||||
|
// names of wrapped devices.
|
||||||
|
//
|
||||||
|
// There are currently no graph semantics implemented for registered custom
|
||||||
|
// devices, so executing tf.functions which contain operations placed on custom
|
||||||
|
// devices will fail.
|
||||||
|
//
|
||||||
|
// `device_name` must not name an existing physical or custom device. It must
|
||||||
|
// follow the format:
|
||||||
|
//
|
||||||
|
// /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num>
|
||||||
|
//
|
||||||
|
// If the device is successfully registered, `status` is set to TF_OK. Otherwise
|
||||||
|
// the device is not usable. In case of a bad status, `device.delete_device` is
|
||||||
|
// still called on `device_info` (i.e. the caller does not retain ownership).
|
||||||
|
//
|
||||||
|
// This API is highly experimental, and in particular is expected to change when
|
||||||
|
// it starts supporting operations with attributes and when tf.function support
|
||||||
|
// is added.
|
||||||
|
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
|
||||||
|
const char* device_name, void* device_info,
|
||||||
|
TF_Status* status);
|
||||||
|
|
||||||
|
TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
|
||||||
|
const char* function_name,
|
||||||
|
TF_Buffer* buf,
|
||||||
|
TF_Status* status);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} /* end extern "C" */
|
} /* end extern "C" */
|
||||||
#endif
|
#endif
|
||||||
|
@ -26,7 +26,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/protobuf.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/platform/test_benchmark.h"
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
#include "tensorflow/core/protobuf/trace_events.pb.h"
|
|
||||||
|
|
||||||
using tensorflow::string;
|
using tensorflow::string;
|
||||||
|
|
||||||
@ -39,88 +38,6 @@ static bool HasSubstr(absl::string_view base, absl::string_view substr) {
|
|||||||
return ok;
|
return ok;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ExecuteWithProfiling(bool async) {
|
|
||||||
TF_Status* status = TF_NewStatus();
|
|
||||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
|
||||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
|
||||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
|
||||||
TFE_Profiler* profiler = TFE_NewProfiler();
|
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
TFE_DeleteContextOptions(opts);
|
|
||||||
|
|
||||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
|
||||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
|
||||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
|
||||||
int num_retvals = 1;
|
|
||||||
|
|
||||||
// Run op on GPU if it is present.
|
|
||||||
string gpu_device_name;
|
|
||||||
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
|
|
||||||
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status);
|
|
||||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
|
||||||
}
|
|
||||||
|
|
||||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
|
||||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
TFE_DeleteOp(matmul);
|
|
||||||
TFE_DeleteTensorHandle(m);
|
|
||||||
|
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
ASSERT_EQ(1, num_retvals);
|
|
||||||
TF_Buffer* profiler_result = TF_NewBuffer();
|
|
||||||
if (async) {
|
|
||||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
|
||||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
TFE_DeleteExecutor(executor);
|
|
||||||
}
|
|
||||||
TFE_ProfilerSerializeToString(profiler, profiler_result, status);
|
|
||||||
TFE_DeleteProfiler(profiler);
|
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
profiler::Trace profile_proto;
|
|
||||||
EXPECT_TRUE(profile_proto.ParseFromString(
|
|
||||||
{reinterpret_cast<const char*>(profiler_result->data),
|
|
||||||
profiler_result->length}));
|
|
||||||
string profile_proto_str = profile_proto.DebugString();
|
|
||||||
#ifndef TENSORFLOW_USE_ROCM
|
|
||||||
// TODO(rocm): enable once GPU profiling is supported in ROCm mode
|
|
||||||
if (!gpu_device_name.empty()) {
|
|
||||||
EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0"));
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
// "/host:CPU" is collected by TraceMe
|
|
||||||
EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU"));
|
|
||||||
EXPECT_TRUE(HasSubstr(profile_proto_str, "MatMul"));
|
|
||||||
TF_DeleteBuffer(profiler_result);
|
|
||||||
|
|
||||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
|
||||||
TFE_DeleteTensorHandle(retvals[0]);
|
|
||||||
TFE_DeleteContext(ctx);
|
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
||||||
float product[4] = {0};
|
|
||||||
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
|
|
||||||
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
|
|
||||||
TF_DeleteTensor(t);
|
|
||||||
EXPECT_EQ(7, product[0]);
|
|
||||||
EXPECT_EQ(10, product[1]);
|
|
||||||
EXPECT_EQ(15, product[2]);
|
|
||||||
EXPECT_EQ(22, product[3]);
|
|
||||||
TF_DeleteStatus(status);
|
|
||||||
}
|
|
||||||
TEST(CAPI, ExecuteWithTracing) { ExecuteWithProfiling(false); }
|
|
||||||
TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithProfiling(true); }
|
|
||||||
|
|
||||||
TEST(CAPI, MultipleProfilerSession) {
|
|
||||||
TFE_Profiler* profiler1 = TFE_NewProfiler();
|
|
||||||
EXPECT_TRUE(TFE_ProfilerIsOk(profiler1));
|
|
||||||
|
|
||||||
TFE_Profiler* profiler2 = TFE_NewProfiler();
|
|
||||||
EXPECT_FALSE(TFE_ProfilerIsOk(profiler2));
|
|
||||||
|
|
||||||
TFE_DeleteProfiler(profiler1);
|
|
||||||
TFE_DeleteProfiler(profiler2);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(CAPI, MonitoringCounter0) {
|
TEST(CAPI, MonitoringCounter0) {
|
||||||
TF_Status* status = TF_NewStatus();
|
TF_Status* status = TF_NewStatus();
|
||||||
auto* counter =
|
auto* counter =
|
||||||
|
@ -1,66 +0,0 @@
|
|||||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
|
||||||
|
|
||||||
#include "tensorflow/core/platform/errors.h"
|
|
||||||
#include "tensorflow/core/platform/host_info.h"
|
|
||||||
|
|
||||||
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
|
|
||||||
const char* raw_device_name, TF_Status* status,
|
|
||||||
TFE_Op* op_to_reset) {
|
|
||||||
const char* name = op_or_function_name; // Shorthand
|
|
||||||
const tensorflow::AttrTypeMap* types;
|
|
||||||
bool is_function = false;
|
|
||||||
status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
|
|
||||||
if (!status->status.ok()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (op_to_reset && op_to_reset->ctx != ctx) {
|
|
||||||
status->status = tensorflow::errors::Internal(
|
|
||||||
"Cannot reset a TFE_Op from another TFE_Context");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
|
|
||||||
if (!is_function) {
|
|
||||||
const tensorflow::OpDef* op_def;
|
|
||||||
status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
|
|
||||||
if (!status->status.ok()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
inference_ctx.reset(new TFE_OpInferenceContext(op_def));
|
|
||||||
} else if (!ctx->context->FindFunctionByName(name)) {
|
|
||||||
status->status = tensorflow::errors::NotFound(
|
|
||||||
"'", name,
|
|
||||||
"' is neither a type of a primitive operation nor a name "
|
|
||||||
"of a function registered in binary running on ",
|
|
||||||
tensorflow::port::Hostname(),
|
|
||||||
". Make sure the operation or function is "
|
|
||||||
"registered in the binary running in this process.");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (op_to_reset) {
|
|
||||||
status->status = op_to_reset->Reset(
|
|
||||||
name, is_function, types, raw_device_name, std::move(inference_ctx));
|
|
||||||
return op_to_reset;
|
|
||||||
}
|
|
||||||
|
|
||||||
TFE_Op* new_op =
|
|
||||||
new TFE_Op(ctx, name, is_function, types, std::move(inference_ctx));
|
|
||||||
status->status = new_op->operation.SetDeviceName(raw_device_name);
|
|
||||||
return new_op;
|
|
||||||
}
|
|
@ -27,12 +27,12 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/c_api_internal.h"
|
#include "tensorflow/c/c_api_internal.h"
|
||||||
#include "tensorflow/c/eager/c_api.h"
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
|
#include "tensorflow/c/eager/operation_interface.h"
|
||||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
|
||||||
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
|
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
@ -48,7 +48,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/thread_annotations.h"
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
#include "tensorflow/core/profiler/lib/profiler_session.h"
|
|
||||||
#include "tensorflow/core/public/version.h"
|
#include "tensorflow/core/public/version.h"
|
||||||
|
|
||||||
struct TFE_ContextOptions {
|
struct TFE_ContextOptions {
|
||||||
@ -89,53 +88,8 @@ struct TFE_TensorDebugInfo {
|
|||||||
std::vector<tensorflow::int64> dev_dims;
|
std::vector<tensorflow::int64> dev_dims;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TFE_OpInferenceContext {
|
|
||||||
explicit TFE_OpInferenceContext(const tensorflow::OpDef* op_def)
|
|
||||||
: op_def(op_def) {}
|
|
||||||
|
|
||||||
const tensorflow::OpDef* op_def; // op definition from protobuf
|
|
||||||
int input_arg_idx = 0; // arg definition index for the next input to be added
|
|
||||||
tensorflow::gtl::FlatSet<std::string> attrs; // attributes inferred so far
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TFE_Op {
|
struct TFE_Op {
|
||||||
TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
|
std::unique_ptr<AbstractOperationInterface> operation;
|
||||||
const tensorflow::AttrTypeMap* t,
|
|
||||||
std::unique_ptr<TFE_OpInferenceContext> inference_ctx)
|
|
||||||
: ctx(ctx),
|
|
||||||
operation(ctx->context, op, is_function, t),
|
|
||||||
inference_ctx(std::move(inference_ctx)) {}
|
|
||||||
|
|
||||||
void Clear() {
|
|
||||||
operation.Clear();
|
|
||||||
inference_ctx.reset();
|
|
||||||
}
|
|
||||||
|
|
||||||
tensorflow::Status Reset(const char* op, bool is_function,
|
|
||||||
const tensorflow::AttrTypeMap* t,
|
|
||||||
const char* raw_device_name,
|
|
||||||
std::unique_ptr<TFE_OpInferenceContext> infer_ctx) {
|
|
||||||
inference_ctx = std::move(infer_ctx);
|
|
||||||
return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
|
|
||||||
nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
void AddInput(TFE_TensorHandle* input, TF_Status* status);
|
|
||||||
void Execute(TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status);
|
|
||||||
|
|
||||||
TFE_Context* ctx;
|
|
||||||
tensorflow::EagerOperation operation;
|
|
||||||
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
|
|
||||||
};
|
|
||||||
|
|
||||||
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
|
|
||||||
const char* raw_device_name, TF_Status* status,
|
|
||||||
TFE_Op* op_to_reset = nullptr);
|
|
||||||
|
|
||||||
struct TFE_Profiler {
|
|
||||||
explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }
|
|
||||||
|
|
||||||
std::unique_ptr<tensorflow::ProfilerSession> profiler;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TFE_MonitoringCounterCell {
|
struct TFE_MonitoringCounterCell {
|
||||||
@ -282,4 +236,17 @@ struct TFE_Executor {
|
|||||||
tensorflow::EagerExecutor* unowned_executor;
|
tensorflow::EagerExecutor* unowned_executor;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
|
||||||
|
// that sometimes do not require serialization.
|
||||||
|
struct TFE_OpAttrs {
|
||||||
|
explicit TFE_OpAttrs() : name(nullptr), attributes(nullptr) {}
|
||||||
|
|
||||||
|
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value,
|
||||||
|
const char* op_name)
|
||||||
|
: name(op_name), attributes(value) {}
|
||||||
|
|
||||||
|
const char* name;
|
||||||
|
const tensorflow::AttrBuilder* attributes;
|
||||||
|
};
|
||||||
|
|
||||||
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||||
@ -127,7 +128,7 @@ void TestRemoteExecute(bool async) {
|
|||||||
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
|
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
|
||||||
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
|
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
|
||||||
|
|
||||||
void TestRemoteExecuteSilentCopies(bool async) {
|
void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||||
|
|
||||||
// This server def has the task index set to 0.
|
// This server def has the task index set to 0.
|
||||||
@ -166,10 +167,14 @@ void TestRemoteExecuteSilentCopies(bool async) {
|
|||||||
auto* h1_task2 =
|
auto* h1_task2 =
|
||||||
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
|
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_TensorHandleEnableImplicitMirroring(h1_task2, status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
// Handles are on task0 (local), and task2, but op is on task1.
|
// Handles are on task0 (local), and task2, but op is on task1.
|
||||||
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
|
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
|
||||||
TFE_OpSetDevice(matmul, task1_name, status);
|
if (remote) {
|
||||||
|
TFE_OpSetDevice(matmul, task1_name, status);
|
||||||
|
}
|
||||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
TFE_TensorHandle* retvals[1];
|
TFE_TensorHandle* retvals[1];
|
||||||
@ -177,6 +182,17 @@ void TestRemoteExecuteSilentCopies(bool async) {
|
|||||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
// TODO(gjn): Add support for waiting on async local mirrors
|
||||||
|
if (!async) {
|
||||||
|
auto remote_arg = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||||
|
h1_task2->handle.get())
|
||||||
|
->Handle();
|
||||||
|
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||||
|
matmul->operation.get());
|
||||||
|
// The input handles should never change since they have been mirrored.
|
||||||
|
ASSERT_EQ(op->GetInput(1), remote_arg);
|
||||||
|
}
|
||||||
|
|
||||||
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
|
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
|
||||||
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
|
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
@ -213,9 +229,17 @@ void TestRemoteExecuteSilentCopies(bool async) {
|
|||||||
worker_server2.release();
|
worker_server2.release();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); }
|
TEST(CAPI, RemoteExecuteSilentCopies) {
|
||||||
|
TestRemoteExecuteSilentCopies(false, true);
|
||||||
|
}
|
||||||
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
|
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
|
||||||
TestRemoteExecuteSilentCopies(true);
|
TestRemoteExecuteSilentCopies(true, true);
|
||||||
|
}
|
||||||
|
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
|
||||||
|
TestRemoteExecuteSilentCopies(false, false);
|
||||||
|
}
|
||||||
|
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
|
||||||
|
TestRemoteExecuteSilentCopies(true, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||||
|
@ -17,12 +17,15 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "absl/strings/match.h"
|
#include "absl/strings/match.h"
|
||||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||||
#include "tensorflow/core/framework/function.pb.h"
|
#include "tensorflow/core/framework/function.pb.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
@ -363,34 +366,63 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) {
|
|||||||
TensorHandleCopyBetweenTwoGPUDevices(true);
|
TensorHandleCopyBetweenTwoGPUDevices(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TensorHandleSilentCopy(bool async) {
|
void TensorHandleSilentCopy(bool async,
|
||||||
|
TFE_ContextDevicePlacementPolicy global_policy,
|
||||||
|
TFE_ContextDevicePlacementPolicy thread_policy,
|
||||||
|
bool cpu_op) {
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
TF_NewStatus(), TF_DeleteStatus);
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
|
||||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||||
|
TFE_ContextOptionsSetDevicePlacementPolicy(opts, global_policy);
|
||||||
TFE_Context* ctx = TFE_NewContext(opts, status.get());
|
TFE_Context* ctx = TFE_NewContext(opts, status.get());
|
||||||
|
if (thread_policy != global_policy) {
|
||||||
|
TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx, thread_policy);
|
||||||
|
}
|
||||||
TFE_DeleteContextOptions(opts);
|
TFE_DeleteContextOptions(opts);
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||||
|
|
||||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||||
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
|
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||||
|
|
||||||
// Disable the test if no GPU is present.
|
// Disable the test if no GPU is present.
|
||||||
string gpu_device_name;
|
string gpu_device_name;
|
||||||
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
|
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
|
||||||
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
|
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
|
||||||
hcpu, ctx, gpu_device_name.c_str(), status.get());
|
hcpu, ctx, gpu_device_name.c_str(), status.get());
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||||
|
|
||||||
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
|
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
|
||||||
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
|
if (cpu_op) {
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
string cpu_device_name;
|
||||||
|
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
|
||||||
|
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status.get());
|
||||||
|
} else {
|
||||||
|
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
|
||||||
|
}
|
||||||
|
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||||
|
|
||||||
TFE_TensorHandle* retvals[1];
|
TFE_TensorHandle* retvals[1];
|
||||||
int num_retvals = 1;
|
int num_retvals = 1;
|
||||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
|
TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||||
|
|
||||||
|
// Validate if the input was replaced with a different TensorHandle
|
||||||
|
auto arg0 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||||
|
hcpu->handle.get())
|
||||||
|
->Handle();
|
||||||
|
auto arg1 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||||
|
hgpu->handle.get())
|
||||||
|
->Handle();
|
||||||
|
|
||||||
|
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||||
|
matmul->operation.get());
|
||||||
|
|
||||||
|
// The input handles should never change since they have been mirrored.
|
||||||
|
EXPECT_EQ(op->GetInput(0), arg0);
|
||||||
|
EXPECT_EQ(op->GetInput(1), arg1);
|
||||||
|
|
||||||
TFE_DeleteOp(matmul);
|
TFE_DeleteOp(matmul);
|
||||||
TFE_DeleteTensorHandle(retvals[0]);
|
TFE_DeleteTensorHandle(retvals[0]);
|
||||||
TFE_DeleteTensorHandle(hgpu);
|
TFE_DeleteTensorHandle(hgpu);
|
||||||
@ -404,57 +436,21 @@ void TensorHandleSilentCopy(bool async) {
|
|||||||
TFE_DeleteExecutor(executor);
|
TFE_DeleteExecutor(executor);
|
||||||
TFE_DeleteContext(ctx);
|
TFE_DeleteContext(ctx);
|
||||||
}
|
}
|
||||||
|
TEST(CAPI, TensorHandleSilentCopy) {
|
||||||
TEST(CAPI, TensorHandleSilentCopy) { TensorHandleSilentCopy(false); }
|
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
|
||||||
TEST(CAPI, TensorHandleSilentCopyAsync) { TensorHandleSilentCopy(true); }
|
TFE_DEVICE_PLACEMENT_SILENT, false);
|
||||||
|
|
||||||
void TensorHandleSilentCopyLocal(bool async) {
|
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
|
||||||
TF_NewStatus(), TF_DeleteStatus);
|
|
||||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
|
||||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
|
||||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
|
|
||||||
TFE_DEVICE_PLACEMENT_EXPLICIT);
|
|
||||||
TFE_Context* ctx = TFE_NewContext(opts, status.get());
|
|
||||||
TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx,
|
|
||||||
TFE_DEVICE_PLACEMENT_SILENT);
|
|
||||||
TFE_DeleteContextOptions(opts);
|
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
|
||||||
|
|
||||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
|
||||||
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
|
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
|
||||||
|
|
||||||
// Disable the test if no GPU is present.
|
|
||||||
string gpu_device_name;
|
|
||||||
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
|
|
||||||
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
|
|
||||||
hcpu, ctx, gpu_device_name.c_str(), status.get());
|
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
|
||||||
|
|
||||||
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
|
|
||||||
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
|
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
|
||||||
TFE_TensorHandle* retvals[1];
|
|
||||||
int num_retvals = 1;
|
|
||||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
|
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
|
||||||
TFE_DeleteOp(matmul);
|
|
||||||
TFE_DeleteTensorHandle(retvals[0]);
|
|
||||||
TFE_DeleteTensorHandle(hgpu);
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_DeleteTensor(t);
|
|
||||||
TFE_DeleteTensorHandle(hcpu);
|
|
||||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
|
||||||
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
|
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
|
||||||
TFE_DeleteExecutor(executor);
|
|
||||||
TFE_DeleteContext(ctx);
|
|
||||||
}
|
}
|
||||||
TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); }
|
TEST(CAPI, TensorHandleSilentCopyAsync) {
|
||||||
TEST(CAPI, TensorHandleSilentCopyLocalAsync) {
|
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_SILENT,
|
||||||
TensorHandleSilentCopyLocal(true);
|
TFE_DEVICE_PLACEMENT_SILENT, false);
|
||||||
|
}
|
||||||
|
TEST(CAPI, TensorHandleSilentCopyLocalPolicy) {
|
||||||
|
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_EXPLICIT,
|
||||||
|
TFE_DEVICE_PLACEMENT_SILENT, false);
|
||||||
|
}
|
||||||
|
TEST(CAPI, TensorHandleSilentCopyLocalPolicyAsync) {
|
||||||
|
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_EXPLICIT,
|
||||||
|
TFE_DEVICE_PLACEMENT_SILENT, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetAndGetOpDevices(bool async) {
|
void SetAndGetOpDevices(bool async) {
|
||||||
@ -590,6 +586,91 @@ TEST(CAPI, TensorHandleDevices) {
|
|||||||
TFE_DeleteContext(ctx);
|
TFE_DeleteContext(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ExecuteAdd(bool async, bool forward_input) {
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||||
|
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
|
||||||
|
TFE_TensorHandle* n = TestMatrixTensorHandle100x100();
|
||||||
|
// If a GPU exists, copy the handle to GPU so that we can exercise
|
||||||
|
// unprotecting a mirror.
|
||||||
|
std::string gpu_device_name;
|
||||||
|
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
|
||||||
|
TFE_TensorHandle* n_gpu =
|
||||||
|
TFE_TensorHandleCopyToDevice(n, ctx, gpu_device_name.c_str(), status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_TensorHandleEnableImplicitMirroring(n_gpu, status);
|
||||||
|
TFE_DeleteTensorHandle(n);
|
||||||
|
n = n_gpu;
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_TensorHandle* m = TestMatrixTensorHandle100x100();
|
||||||
|
|
||||||
|
// Store pointer to raw buffer for validation of forwarding behaviour.
|
||||||
|
TF_Tensor* orig = TFE_TensorHandleResolve(n, status);
|
||||||
|
void* orig_ptr = TF_TensorData(orig);
|
||||||
|
TF_DeleteTensor(orig);
|
||||||
|
|
||||||
|
TFE_Op* add_op = AddOp(ctx, n, m);
|
||||||
|
std::string cpu_device_name;
|
||||||
|
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
|
||||||
|
TFE_OpSetDevice(add_op, cpu_device_name.c_str(), status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
if (forward_input) {
|
||||||
|
TFE_DeleteTensorHandle(n);
|
||||||
|
}
|
||||||
|
|
||||||
|
int num_retvals = 1;
|
||||||
|
|
||||||
|
if (async) {
|
||||||
|
// Enqueue dummy ops so we backlog async execution & actually test async.
|
||||||
|
for (int i = 0; i < 10000; ++i) {
|
||||||
|
TFE_TensorHandle* dummy = nullptr;
|
||||||
|
TFE_Execute(add_op, &dummy, &num_retvals, status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteTensorHandle(dummy);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_TensorHandle* retval = nullptr;
|
||||||
|
TFE_Execute(add_op, &retval, &num_retvals, status);
|
||||||
|
EXPECT_EQ(1, num_retvals);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
if (!forward_input) {
|
||||||
|
TFE_DeleteTensorHandle(n);
|
||||||
|
}
|
||||||
|
TFE_DeleteOp(add_op);
|
||||||
|
|
||||||
|
TF_Tensor* t = TFE_TensorHandleResolve(retval, status);
|
||||||
|
if (forward_input || async) {
|
||||||
|
EXPECT_EQ(orig_ptr, TF_TensorData(t));
|
||||||
|
} else {
|
||||||
|
EXPECT_NE(orig_ptr, TF_TensorData(t));
|
||||||
|
}
|
||||||
|
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteTensorHandle(m);
|
||||||
|
TFE_DeleteTensorHandle(retval);
|
||||||
|
TFE_DeleteContext(ctx);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
float result[100 * 100] = {0};
|
||||||
|
EXPECT_EQ(sizeof(result), TF_TensorByteSize(t));
|
||||||
|
memcpy(&result[0], TF_TensorData(t), TF_TensorByteSize(t));
|
||||||
|
TF_DeleteTensor(t);
|
||||||
|
for (int i = 0; i < 100 * 100; ++i) {
|
||||||
|
EXPECT_EQ(2.0f, result[i]);
|
||||||
|
}
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
}
|
||||||
|
TEST(CAPI, ExecuteAdd) { ExecuteAdd(false, false); }
|
||||||
|
TEST(CAPI, ExecuteAddAsync) { ExecuteAdd(true, false); }
|
||||||
|
TEST(CAPI, ExecuteAddForward) { ExecuteAdd(false, true); }
|
||||||
|
TEST(CAPI, ExecuteAddForwardAsync) { ExecuteAdd(true, true); }
|
||||||
|
|
||||||
void Execute_MatMul_CPU(bool async) {
|
void Execute_MatMul_CPU(bool async) {
|
||||||
TF_Status* status = TF_NewStatus();
|
TF_Status* status = TF_NewStatus();
|
||||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
@ -1228,6 +1309,14 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
|
|||||||
TFE_DeleteTensorHandle(h_shares_tensor);
|
TFE_DeleteTensorHandle(h_shares_tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
|
||||||
|
tensorflow::AttrValueMap attr_values;
|
||||||
|
tensorflow::down_cast<tensorflow::OperationInterface*>(op->operation.get())
|
||||||
|
->Attrs()
|
||||||
|
.FillAttrValueMap(&attr_values);
|
||||||
|
return attr_values;
|
||||||
|
}
|
||||||
|
|
||||||
TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
|
TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
|
||||||
TF_Status* status = TF_NewStatus();
|
TF_Status* status = TF_NewStatus();
|
||||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
@ -1244,8 +1333,7 @@ TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
|
|||||||
TFE_OpAddInput(minOp, axis, status);
|
TFE_OpAddInput(minOp, axis, status);
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
tensorflow::AttrValueMap attr_values;
|
tensorflow::AttrValueMap attr_values = ExtractAttrs(minOp);
|
||||||
minOp->operation.Attrs().FillAttrValueMap(&attr_values);
|
|
||||||
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
|
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
|
||||||
EXPECT_NE(attr_found, attr_values.cend());
|
EXPECT_NE(attr_found, attr_values.cend());
|
||||||
EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
|
EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
|
||||||
@ -1284,8 +1372,7 @@ TEST(CAPI, TestTFE_OpInferSingleTypeInputListAttrs) {
|
|||||||
TFE_OpAddInputList(concatOp, inputs, 2, status);
|
TFE_OpAddInputList(concatOp, inputs, 2, status);
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
tensorflow::AttrValueMap attr_values;
|
tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
|
||||||
concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
|
|
||||||
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
|
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
|
||||||
EXPECT_NE(attr_found, attr_values.cend());
|
EXPECT_NE(attr_found, attr_values.cend());
|
||||||
EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
|
EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
|
||||||
@ -1325,8 +1412,7 @@ TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) {
|
|||||||
TFE_OpAddInputList(assertOp, data, 3, status);
|
TFE_OpAddInputList(assertOp, data, 3, status);
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
tensorflow::AttrValueMap attr_values;
|
tensorflow::AttrValueMap attr_values = ExtractAttrs(assertOp);
|
||||||
assertOp->operation.Attrs().FillAttrValueMap(&attr_values);
|
|
||||||
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
|
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
|
||||||
EXPECT_NE(attr_found, attr_values.cend());
|
EXPECT_NE(attr_found, attr_values.cend());
|
||||||
EXPECT_EQ(attr_found->second.list().type(0), tensorflow::DataType::DT_BOOL);
|
EXPECT_EQ(attr_found->second.list().type(0), tensorflow::DataType::DT_BOOL);
|
||||||
@ -1362,15 +1448,15 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
|
|||||||
TFE_TensorHandle* inputs[] = {input1, input2};
|
TFE_TensorHandle* inputs[] = {input1, input2};
|
||||||
TFE_OpAddInput(concatOp, dim, status);
|
TFE_OpAddInput(concatOp, dim, status);
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
CHECK(concatOp->inference_ctx);
|
CHECK(concatOp->operation->OpDef());
|
||||||
TFE_OpAddInput(concatOp, inputs[0], status);
|
TFE_OpAddInput(concatOp, inputs[0], status);
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
EXPECT_FALSE(concatOp->inference_ctx) << "Inference context is still present";
|
EXPECT_FALSE(concatOp->operation->OpDef())
|
||||||
|
<< "Inference context is still present";
|
||||||
TFE_OpAddInput(concatOp, inputs[1], status);
|
TFE_OpAddInput(concatOp, inputs[1], status);
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
tensorflow::AttrValueMap attr_values;
|
tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
|
||||||
concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
|
|
||||||
EXPECT_EQ(attr_values.find("T"), attr_values.end());
|
EXPECT_EQ(attr_values.find("T"), attr_values.end());
|
||||||
EXPECT_EQ(attr_values.find("N"), attr_values.end());
|
EXPECT_EQ(attr_values.find("N"), attr_values.end());
|
||||||
|
|
||||||
@ -1457,4 +1543,88 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
|
|||||||
TFE_DeleteContext(ctx);
|
TFE_DeleteContext(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, TestTFE_OpGetAttrs) {
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
|
||||||
|
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||||
|
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
||||||
|
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
||||||
|
TFE_OpAttrs attributes;
|
||||||
|
TFE_OpGetAttrs(var_op, &attributes);
|
||||||
|
|
||||||
|
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||||
|
TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
|
||||||
|
TFE_OpAddAttrs(copy_op, &attributes);
|
||||||
|
unsigned char is_list = 0;
|
||||||
|
ASSERT_EQ(TF_ATTR_TYPE,
|
||||||
|
TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
ASSERT_EQ(TF_ATTR_SHAPE,
|
||||||
|
TFE_OpGetAttrType(copy_op, "shape", &is_list, status));
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
tensorflow::AttrValueMap attr_values;
|
||||||
|
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||||
|
copy_op->operation.get());
|
||||||
|
op->Attrs().FillAttrValueMap(&attr_values);
|
||||||
|
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
|
||||||
|
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
TFE_DeleteOp(var_op);
|
||||||
|
TFE_DeleteOp(copy_op);
|
||||||
|
TFE_DeleteContext(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, TestTFE_OpAttrsSerialize) {
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
|
||||||
|
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
||||||
|
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
||||||
|
TFE_OpAttrs attributes;
|
||||||
|
TFE_OpGetAttrs(var_op, &attributes);
|
||||||
|
|
||||||
|
TF_Buffer* serialized_attr_values = TF_NewBuffer();
|
||||||
|
TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
tensorflow::NameAttrList name_and_attrs;
|
||||||
|
ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,
|
||||||
|
serialized_attr_values->length));
|
||||||
|
ASSERT_EQ("VarHandleOp", name_and_attrs.name());
|
||||||
|
ASSERT_EQ(tensorflow::DT_INT64,
|
||||||
|
name_and_attrs.attr().find("dtype")->second.type());
|
||||||
|
TF_DeleteBuffer(serialized_attr_values);
|
||||||
|
|
||||||
|
TFE_Op* second_var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||||
|
|
||||||
|
string serialized_dtype;
|
||||||
|
ASSERT_TRUE(name_and_attrs.attr().find("dtype")->second.SerializeToString(
|
||||||
|
&serialized_dtype));
|
||||||
|
TFE_OpSetAttrValueProto(
|
||||||
|
second_var_op, "dtype",
|
||||||
|
reinterpret_cast<const void*>(serialized_dtype.c_str()),
|
||||||
|
serialized_dtype.length(), status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
tensorflow::AttrValueMap attr_values;
|
||||||
|
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||||
|
second_var_op->operation.get());
|
||||||
|
op->Attrs().FillAttrValueMap(&attr_values);
|
||||||
|
EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type());
|
||||||
|
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
TFE_DeleteOp(var_op);
|
||||||
|
TFE_DeleteOp(second_var_op);
|
||||||
|
TFE_DeleteContext(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -131,6 +131,21 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2() {
|
|||||||
return th;
|
return th;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
|
||||||
|
TFE_Op* op = TFE_NewOp(ctx, "AddV2", status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_OpAddInput(op, a, status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_OpAddInput(op, b, status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
|
||||||
|
|
||||||
|
return op;
|
||||||
|
}
|
||||||
|
|
||||||
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
|
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
|
||||||
TF_Status* status = TF_NewStatus();
|
TF_Status* status = TF_NewStatus();
|
||||||
|
|
||||||
|
@ -42,6 +42,9 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();
|
|||||||
// Return a tensor handle containing a 3x2 matrix of floats
|
// Return a tensor handle containing a 3x2 matrix of floats
|
||||||
TFE_TensorHandle* TestMatrixTensorHandle3X2();
|
TFE_TensorHandle* TestMatrixTensorHandle3X2();
|
||||||
|
|
||||||
|
// Return an add op multiplying `a` by `b`.
|
||||||
|
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
|
||||||
|
|
||||||
// Return a matmul op multiplying `a` by `b`.
|
// Return a matmul op multiplying `a` by `b`.
|
||||||
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
|
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
|
||||||
|
|
||||||
|
398
tensorflow/c/eager/custom_device_test.cc
Normal file
398
tensorflow/c/eager/custom_device_test.cc
Normal file
@ -0,0 +1,398 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// A simple logging device to test custom device registration.
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||||
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct LoggingDevice {
|
||||||
|
tensorflow::string device_name;
|
||||||
|
tensorflow::string underlying_device;
|
||||||
|
// Set to true whenever a TensorHandle is copied onto the device
|
||||||
|
bool* arrived_flag;
|
||||||
|
// Set to true whenever an operation is executed
|
||||||
|
bool* executed_flag;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LoggedTensor {
|
||||||
|
TFE_TensorHandle* tensor;
|
||||||
|
LoggedTensor() = delete;
|
||||||
|
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
|
||||||
|
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
|
||||||
|
};
|
||||||
|
|
||||||
|
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
|
||||||
|
delete reinterpret_cast<LoggedTensor*>(data);
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_TensorHandle* MakeLoggedTensorHandle(
|
||||||
|
TFE_Context* context, const tensorflow::string& logging_device_name,
|
||||||
|
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
|
||||||
|
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
|
||||||
|
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||||
|
for (int i = 0; i < shape.size(); ++i) {
|
||||||
|
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
|
||||||
|
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||||
|
}
|
||||||
|
auto dtype = TFE_TensorHandleDataType(t->tensor);
|
||||||
|
return TFE_NewTensorHandleFromDeviceMemory(
|
||||||
|
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
|
||||||
|
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
|
||||||
|
TFE_TensorHandle* tensor,
|
||||||
|
TF_Status* status, void* device_info) {
|
||||||
|
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||||
|
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
|
||||||
|
tensor, context, dev->underlying_device.c_str(), status);
|
||||||
|
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||||
|
auto dst = std::make_unique<LoggedTensor>(t);
|
||||||
|
*(dev->arrived_flag) = true;
|
||||||
|
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
|
||||||
|
status);
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
|
||||||
|
TFE_TensorHandle* tensor,
|
||||||
|
const char* target_device_name,
|
||||||
|
TF_Status* status,
|
||||||
|
void* device_info) {
|
||||||
|
TF_SetStatus(status, TF_INTERNAL,
|
||||||
|
"Trying to copy a tensor out of a logging device.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
|
||||||
|
TFE_TensorHandle** inputs, const char* operation_name,
|
||||||
|
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||||
|
TFE_TensorHandle** outputs, TF_Status* s,
|
||||||
|
void* device_info) {
|
||||||
|
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||||
|
TFE_Op* op(TFE_NewOp(context, operation_name, s));
|
||||||
|
if (TF_GetCode(s) != TF_OK) return;
|
||||||
|
TFE_OpAddAttrs(op, attributes);
|
||||||
|
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
|
||||||
|
for (int j = 0; j < num_inputs; ++j) {
|
||||||
|
TFE_TensorHandle* input = inputs[j];
|
||||||
|
const char* input_device = TFE_TensorHandleDeviceName(input, s);
|
||||||
|
if (TF_GetCode(s) != TF_OK) return;
|
||||||
|
if (dev->device_name == input_device) {
|
||||||
|
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
|
||||||
|
TFE_TensorHandleDevicePointer(input, s));
|
||||||
|
if (TF_GetCode(s) != TF_OK) return;
|
||||||
|
TFE_OpAddInput(op, t->tensor, s);
|
||||||
|
} else {
|
||||||
|
TFE_OpAddInput(op, input, s);
|
||||||
|
}
|
||||||
|
if (TF_GetCode(s) != TF_OK) return;
|
||||||
|
}
|
||||||
|
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
|
||||||
|
TFE_Execute(op, op_outputs.data(), num_outputs, s);
|
||||||
|
TFE_DeleteOp(op);
|
||||||
|
if (TF_GetCode(s) != TF_OK) return;
|
||||||
|
std::vector<TFE_TensorHandle*> unwrapped_outputs;
|
||||||
|
for (auto* handle : op_outputs) {
|
||||||
|
unwrapped_outputs.push_back(handle);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < *num_outputs; ++i) {
|
||||||
|
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
|
||||||
|
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
|
||||||
|
std::move(logged_tensor), s);
|
||||||
|
}
|
||||||
|
*(dev->executed_flag) = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DeleteLoggingDevice(void* device_info) {
|
||||||
|
delete reinterpret_cast<LoggingDevice*>(device_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
void RegisterLoggingDevice(TFE_Context* context, const char* name,
|
||||||
|
bool* arrived_flag, bool* executed_flag,
|
||||||
|
TF_Status* status) {
|
||||||
|
TFE_CustomDevice custom_device;
|
||||||
|
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
|
||||||
|
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
|
||||||
|
custom_device.delete_device = &DeleteLoggingDevice;
|
||||||
|
custom_device.execute = &LoggingDeviceExecute;
|
||||||
|
LoggingDevice* device = new LoggingDevice;
|
||||||
|
device->arrived_flag = arrived_flag;
|
||||||
|
device->executed_flag = executed_flag;
|
||||||
|
device->device_name = name;
|
||||||
|
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||||
|
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
||||||
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
TFE_Context* context = TFE_NewContext(opts, status.get());
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
bool arrived = false;
|
||||||
|
bool executed = false;
|
||||||
|
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||||
|
RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||||
|
ASSERT_FALSE(arrived);
|
||||||
|
TFE_TensorHandle* hdevice =
|
||||||
|
TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
|
||||||
|
ASSERT_TRUE(arrived);
|
||||||
|
ASSERT_FALSE(executed);
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
|
||||||
|
MatMulOp(context, hcpu, hdevice), TFE_DeleteOp);
|
||||||
|
TFE_OpSetDevice(matmul.get(), name, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
TFE_TensorHandle* retval;
|
||||||
|
int num_retvals = 1;
|
||||||
|
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
ASSERT_TRUE(executed);
|
||||||
|
|
||||||
|
TFE_DeleteTensorHandle(retval);
|
||||||
|
TFE_DeleteTensorHandle(hcpu);
|
||||||
|
TFE_DeleteTensorHandle(hdevice);
|
||||||
|
TFE_DeleteContext(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CUSTOM_DEVICE, ResetOperation) {
|
||||||
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||||
|
TFE_NewContext(opts, status.get()), TFE_DeleteContext);
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
bool arrived = false;
|
||||||
|
bool executed = false;
|
||||||
|
const char* custom_device_name =
|
||||||
|
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||||
|
RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed,
|
||||||
|
status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
|
||||||
|
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> reused_op(
|
||||||
|
TFE_NewOp(context.get(), "Identity", status.get()), TFE_DeleteOp);
|
||||||
|
TFE_OpReset(reused_op.get(), "Identity", custom_device_name, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
|
||||||
|
tensorflow::string(custom_device_name));
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
TFE_OpReset(reused_op.get(), "Identity",
|
||||||
|
"/job:localhost/replica:0/task:0/device:CPU:0", status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
|
||||||
|
tensorflow::string("/job:localhost/replica:0/task:0/device:CPU:0"));
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CUSTOM_DEVICE, MakeVariable) {
|
||||||
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
|
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||||
|
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||||
|
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||||
|
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
bool arrived = false;
|
||||||
|
bool executed = false;
|
||||||
|
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||||
|
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
|
||||||
|
// Create a variable handle placed on the custom device.
|
||||||
|
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||||
|
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||||
|
TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
|
||||||
|
TFE_OpSetAttrString(op.get(), "container", "", 0);
|
||||||
|
TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
|
||||||
|
TFE_OpSetDevice(op.get(), name, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
TFE_TensorHandle* var_handle = nullptr;
|
||||||
|
int num_retvals = 1;
|
||||||
|
executed = false;
|
||||||
|
TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
ASSERT_TRUE(executed);
|
||||||
|
auto handle_cleaner = tensorflow::gtl::MakeCleanup(
|
||||||
|
[var_handle]() { TFE_DeleteTensorHandle(var_handle); });
|
||||||
|
|
||||||
|
// Assign to the variable, copying to the custom device.
|
||||||
|
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
|
||||||
|
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
|
||||||
|
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
|
||||||
|
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||||
|
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||||
|
TFE_OpAddInput(op.get(), one.get(), status.get());
|
||||||
|
TFE_OpSetDevice(op.get(), name, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
executed = false;
|
||||||
|
num_retvals = 0;
|
||||||
|
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
ASSERT_TRUE(executed);
|
||||||
|
|
||||||
|
// Read the variable's value.
|
||||||
|
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
|
||||||
|
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||||
|
TFE_OpSetDevice(op.get(), name, status.get());
|
||||||
|
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
executed = false;
|
||||||
|
num_retvals = 1;
|
||||||
|
TFE_TensorHandle* var_value = nullptr;
|
||||||
|
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
ASSERT_TRUE(executed);
|
||||||
|
auto value_cleaner = tensorflow::gtl::MakeCleanup(
|
||||||
|
[var_value]() { TFE_DeleteTensorHandle(var_value); });
|
||||||
|
ASSERT_EQ(tensorflow::string(name),
|
||||||
|
tensorflow::string(
|
||||||
|
TFE_TensorHandleBackingDeviceName(var_value, status.get())));
|
||||||
|
TFE_TensorHandle* var_value_unpacked =
|
||||||
|
reinterpret_cast<LoggedTensor*>(
|
||||||
|
TFE_TensorHandleDevicePointer(var_value, status.get()))
|
||||||
|
->tensor;
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
|
||||||
|
TFE_TensorHandleResolve(var_value_unpacked, status.get()),
|
||||||
|
TF_DeleteTensor);
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
ASSERT_EQ(111., *static_cast<float*>(TF_TensorData(resolved_value.get())));
|
||||||
|
|
||||||
|
// Free the backing buffer for the variable.
|
||||||
|
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
|
||||||
|
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||||
|
TFE_OpSetDevice(op.get(), name, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
num_retvals = 0;
|
||||||
|
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
|
||||||
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
|
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||||
|
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||||
|
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||||
|
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
bool arrived = false;
|
||||||
|
bool executed = false;
|
||||||
|
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||||
|
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
|
||||||
|
// Create a variable handle placed on the custom device.
|
||||||
|
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||||
|
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||||
|
TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
|
||||||
|
TFE_OpSetAttrString(op.get(), "container", "", 0);
|
||||||
|
TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
|
||||||
|
TFE_OpSetDevice(op.get(), name, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
TFE_TensorHandle* var_handle = nullptr;
|
||||||
|
int num_retvals = 1;
|
||||||
|
executed = false;
|
||||||
|
TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
ASSERT_TRUE(executed);
|
||||||
|
auto handle_cleaner = tensorflow::gtl::MakeCleanup(
|
||||||
|
[var_handle]() { TFE_DeleteTensorHandle(var_handle); });
|
||||||
|
|
||||||
|
// Assign to the variable, copying to the custom device.
|
||||||
|
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
|
||||||
|
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
|
||||||
|
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
|
||||||
|
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||||
|
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||||
|
TFE_OpAddInput(op.get(), one.get(), status.get());
|
||||||
|
TFE_OpSetDevice(op.get(), name, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
executed = false;
|
||||||
|
num_retvals = 0;
|
||||||
|
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
ASSERT_TRUE(executed);
|
||||||
|
|
||||||
|
// Read the variable's value.
|
||||||
|
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
|
||||||
|
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||||
|
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
executed = false;
|
||||||
|
num_retvals = 1;
|
||||||
|
TFE_TensorHandle* var_value = nullptr;
|
||||||
|
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
|
||||||
|
EXPECT_FALSE(TF_GetCode(status.get()) == TF_OK)
|
||||||
|
<< "Execution should fail because the variable is being used on the "
|
||||||
|
"wrong device.";
|
||||||
|
// Free the backing buffer for the variable.
|
||||||
|
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
|
||||||
|
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||||
|
TFE_OpSetDevice(op.get(), name, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
num_retvals = 0;
|
||||||
|
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
|
||||||
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
|
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||||
|
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||||
|
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||||
|
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
bool arrived = false;
|
||||||
|
bool executed = false;
|
||||||
|
RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", &arrived, &executed,
|
||||||
|
status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
|
||||||
|
<< TF_Message(status.get());
|
||||||
|
|
||||||
|
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||||
|
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
|
||||||
|
<< TF_Message(status.get());
|
||||||
|
|
||||||
|
RegisterLoggingDevice(context.get(),
|
||||||
|
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||||
|
&arrived, &executed, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
|
||||||
|
<< TF_Message(status.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
334
tensorflow/c/eager/dlpack.cc
Normal file
334
tensorflow/c/eager/dlpack.cc
Normal file
@ -0,0 +1,334 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/dlpack.h"
|
||||||
|
|
||||||
|
#include "include/dlpack/dlpack.h" // from @dlpack
|
||||||
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_reference.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Managing context for the DLManagedTensor, will manage the lifetime of
|
||||||
|
// DLManagedTensor. When calling DLManagedTensor::deleter, it will notify the
|
||||||
|
// original framework of destruction, and this context will be deleted also.
|
||||||
|
struct TfDlManagedTensorCtx {
|
||||||
|
TensorReference reference;
|
||||||
|
std::vector<int64_t> shape;
|
||||||
|
std::vector<int64_t> strides;
|
||||||
|
DLManagedTensor tensor;
|
||||||
|
|
||||||
|
explicit TfDlManagedTensorCtx(const TensorReference& ref) : reference(ref) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Gets tensor from eager tensor handle.
|
||||||
|
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
||||||
|
if (h == nullptr || !h->handle->IsValid(&status->status)) {
|
||||||
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"The passed in handle is a nullptr");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
tensorflow::TensorHandle* handle =
|
||||||
|
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||||
|
->Handle();
|
||||||
|
|
||||||
|
if (handle->IsRemote()) {
|
||||||
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"DLPack doesn't support remote tensor");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
const tensorflow::Tensor* tensor;
|
||||||
|
status->status = handle->Tensor(&tensor);
|
||||||
|
if (!status->status.ok()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deleter for DLManagedTensor
|
||||||
|
void DLManagedTensorDeleter(DLManagedTensor* arg) {
|
||||||
|
TfDlManagedTensorCtx* owner =
|
||||||
|
static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
|
||||||
|
owner->reference.Unref();
|
||||||
|
delete owner;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts TF_DATAType to DLPack data type.
|
||||||
|
DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
|
||||||
|
DLDataType dtype;
|
||||||
|
dtype.lanes = 1;
|
||||||
|
dtype.bits = TF_DataTypeSize(data_type) * 8;
|
||||||
|
switch (data_type) {
|
||||||
|
case TF_DataType::TF_HALF:
|
||||||
|
case TF_DataType::TF_FLOAT:
|
||||||
|
case TF_DataType::TF_DOUBLE:
|
||||||
|
dtype.code = DLDataTypeCode::kDLFloat;
|
||||||
|
break;
|
||||||
|
case TF_DataType::TF_INT8:
|
||||||
|
case TF_DataType::TF_INT16:
|
||||||
|
case TF_DataType::TF_INT32:
|
||||||
|
case TF_DataType::TF_INT64:
|
||||||
|
dtype.code = DLDataTypeCode::kDLInt;
|
||||||
|
break;
|
||||||
|
case TF_DataType::TF_BOOL:
|
||||||
|
case TF_DataType::TF_UINT8:
|
||||||
|
case TF_DataType::TF_UINT16:
|
||||||
|
case TF_DataType::TF_UINT32:
|
||||||
|
case TF_DataType::TF_UINT64:
|
||||||
|
dtype.code = DLDataTypeCode::kDLUInt;
|
||||||
|
break;
|
||||||
|
case TF_DataType::TF_BFLOAT16:
|
||||||
|
dtype.code = DLDataTypeCode::kDLBfloat;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
|
DataType_Name(static_cast<DataType>(data_type)),
|
||||||
|
" is not supported by dlpack");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return dtype;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gets DLPack's DLContext from eager tensor handle.
|
||||||
|
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
|
||||||
|
DLContext ctx;
|
||||||
|
const char* device_name = h->handle->DeviceName(&status->status);
|
||||||
|
DeviceNameUtils::ParsedName parsed_name;
|
||||||
|
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
|
||||||
|
std::string device_type = parsed_name.type;
|
||||||
|
int device_id = 0;
|
||||||
|
if (parsed_name.has_id) {
|
||||||
|
device_id = parsed_name.id;
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.device_id = device_id;
|
||||||
|
if (device_type == "CPU") {
|
||||||
|
ctx.device_type = DLDeviceType::kDLCPU;
|
||||||
|
} else if (device_type == "GPU") {
|
||||||
|
ctx.device_type = DLDeviceType::kDLGPU;
|
||||||
|
} else {
|
||||||
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Unsupported Device Type for dlpack");
|
||||||
|
}
|
||||||
|
|
||||||
|
return ctx;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts DLContext to TF device name.
|
||||||
|
absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
|
||||||
|
TF_Status* status) {
|
||||||
|
switch (ctx.device_type) {
|
||||||
|
case DLDeviceType::kDLCPU:
|
||||||
|
return "CPU:0";
|
||||||
|
case DLDeviceType::kDLGPU:
|
||||||
|
return absl::StrCat("GPU:", ctx.device_id);
|
||||||
|
default:
|
||||||
|
return absl::nullopt;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts DLPack data type to TF_DATATYPE.
|
||||||
|
Status TfDataTypeFormDlDataType(const DLDataType& dtype,
|
||||||
|
TF_DataType* tf_dtype) {
|
||||||
|
switch (dtype.code) {
|
||||||
|
case DLDataTypeCode::kDLUInt:
|
||||||
|
switch (dtype.bits) {
|
||||||
|
case 8:
|
||||||
|
*tf_dtype = TF_DataType::TF_UINT8;
|
||||||
|
return Status::OK();
|
||||||
|
case 16:
|
||||||
|
*tf_dtype = TF_DataType::TF_UINT16;
|
||||||
|
return Status::OK();
|
||||||
|
case 32:
|
||||||
|
*tf_dtype = TF_DataType::TF_UINT32;
|
||||||
|
return Status::OK();
|
||||||
|
case 64:
|
||||||
|
*tf_dtype = TF_DataType::TF_UINT64;
|
||||||
|
return Status::OK();
|
||||||
|
default:
|
||||||
|
return tensorflow::errors::InvalidArgument("Unsupported UInt bits: ",
|
||||||
|
dtype.bits);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
case DLDataTypeCode::kDLInt:
|
||||||
|
switch (dtype.bits) {
|
||||||
|
case 8:
|
||||||
|
*tf_dtype = TF_DataType::TF_INT8;
|
||||||
|
return Status::OK();
|
||||||
|
case 16:
|
||||||
|
*tf_dtype = TF_DataType::TF_INT16;
|
||||||
|
return Status::OK();
|
||||||
|
case 32:
|
||||||
|
*tf_dtype = TF_DataType::TF_INT32;
|
||||||
|
return Status::OK();
|
||||||
|
case 64:
|
||||||
|
*tf_dtype = TF_DataType::TF_INT64;
|
||||||
|
return Status::OK();
|
||||||
|
default:
|
||||||
|
return tensorflow::errors::InvalidArgument("Unsupported Int bits: ",
|
||||||
|
dtype.bits);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
case DLDataTypeCode::kDLFloat:
|
||||||
|
switch (dtype.bits) {
|
||||||
|
case 16:
|
||||||
|
*tf_dtype = TF_DataType::TF_HALF;
|
||||||
|
return Status::OK();
|
||||||
|
case 32:
|
||||||
|
*tf_dtype = TF_DataType::TF_FLOAT;
|
||||||
|
return Status::OK();
|
||||||
|
case 64:
|
||||||
|
*tf_dtype = TF_DataType::TF_DOUBLE;
|
||||||
|
return Status::OK();
|
||||||
|
default:
|
||||||
|
return tensorflow::errors::InvalidArgument("Unsupported Float bits: ",
|
||||||
|
dtype.bits);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case DLDataTypeCode::kDLBfloat:
|
||||||
|
switch (dtype.bits) {
|
||||||
|
case 16:
|
||||||
|
*tf_dtype = TF_DataType::TF_BFLOAT16;
|
||||||
|
return Status::OK();
|
||||||
|
default:
|
||||||
|
return tensorflow::errors::InvalidArgument(
|
||||||
|
"Unsupported BFloat bits: ", dtype.bits);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return tensorflow::errors::InvalidArgument("Unsupported Type Codes: ",
|
||||||
|
dtype.code);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wraps the deleter function of DLManagedTensor to match the function signature
|
||||||
|
// TFE_NewTensorHandleFromDeviceMemory.
|
||||||
|
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
|
||||||
|
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr);
|
||||||
|
dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks whether the stride array matches the layout of compact, row-majored
|
||||||
|
// data.
|
||||||
|
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
|
||||||
|
int ndim) {
|
||||||
|
if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (int i = ndim - 2; i >= 0; --i) {
|
||||||
|
if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
|
||||||
|
DLManagedTensor* dlMTensor = static_cast<DLManagedTensor*>(dlm_ptr);
|
||||||
|
if (dlMTensor->deleter != nullptr) {
|
||||||
|
dlMTensor->deleter(dlMTensor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
|
||||||
|
const Tensor* tensor = GetTensorFromHandle(h, status);
|
||||||
|
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
|
||||||
|
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
|
||||||
|
|
||||||
|
auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
|
||||||
|
tf_dlm_tensor_ctx->reference = tensor_ref;
|
||||||
|
|
||||||
|
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
|
||||||
|
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
|
||||||
|
dlm_tensor->deleter = &DLManagedTensorDeleter;
|
||||||
|
dlm_tensor->dl_tensor.ctx = GetDlContext(h, status);
|
||||||
|
int ndim = tensor->dims();
|
||||||
|
dlm_tensor->dl_tensor.ndim = ndim;
|
||||||
|
dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
|
||||||
|
dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status);
|
||||||
|
|
||||||
|
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
|
||||||
|
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
|
||||||
|
shape_arr->resize(ndim);
|
||||||
|
stride_arr->resize(ndim, 1);
|
||||||
|
for (int i = 0; i < ndim; i++) {
|
||||||
|
(*shape_arr)[i] = tensor->dim_size(i);
|
||||||
|
}
|
||||||
|
for (int i = ndim - 2; i >= 0; --i) {
|
||||||
|
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
|
||||||
|
// There are two ways to represent compact row-major data
|
||||||
|
// 1) nullptr indicates tensor is compact and row-majored.
|
||||||
|
// 2) fill in the strides array as the real case for compact row-major data.
|
||||||
|
// Here we choose option 2, since some frameworks didn't handle the strides
|
||||||
|
// argument properly.
|
||||||
|
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
|
||||||
|
dlm_tensor->dl_tensor.byte_offset =
|
||||||
|
0; // TF doesn't handle the strides and byte_offsets here
|
||||||
|
return static_cast<void*>(dlm_tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||||
|
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
|
||||||
|
DLTensor* dl_tensor = &dlmt->dl_tensor;
|
||||||
|
absl::optional<std::string> device_name =
|
||||||
|
DeviceNameFromDlContext(dl_tensor->ctx, status);
|
||||||
|
if (!device_name.has_value()) {
|
||||||
|
status->status =
|
||||||
|
tensorflow::errors::InvalidArgument("Unsupported Device Type");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
TF_DataType dtype;
|
||||||
|
Status s = TfDataTypeFormDlDataType(dl_tensor->dtype, &dtype);
|
||||||
|
if (!s.ok()) {
|
||||||
|
status->status = std::move(s);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
int num_dims = dl_tensor->ndim;
|
||||||
|
const int64_t* dims = dl_tensor->shape;
|
||||||
|
void* data = dl_tensor->data;
|
||||||
|
|
||||||
|
size_t total_bytes = dl_tensor->dtype.bits / 8;
|
||||||
|
for (int i = 0; i < num_dims; i++) {
|
||||||
|
total_bytes *= dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dl_tensor->strides != nullptr &&
|
||||||
|
!IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
|
||||||
|
num_dims)) {
|
||||||
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Invalid strides array from DLPack");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
|
||||||
|
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
|
||||||
|
total_bytes, &DeallocatorWrapperFunc, &dlmt, status);
|
||||||
|
|
||||||
|
return handle;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
39
tensorflow/c/eager/dlpack.h
Normal file
39
tensorflow/c/eager/dlpack.h
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EAGER_DLPACK_H_
|
||||||
|
#define TENSORFLOW_C_EAGER_DLPACK_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// PyCapsule name for DLPack Tensor
|
||||||
|
const char* const kDlTensorCapsuleName = "dltensor";
|
||||||
|
|
||||||
|
// Converts eager tensor handle to DLPack (DLManagedTensor*), and return the
|
||||||
|
// void* for further PyCapsule construction.
|
||||||
|
TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
|
||||||
|
TF_Status* status);
|
||||||
|
|
||||||
|
// Converts DLPack (DLManagedTensor*) to eager tensor handle.
|
||||||
|
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
|
||||||
|
TF_Status* status);
|
||||||
|
|
||||||
|
// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule.
|
||||||
|
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EAGER_DLPACK_H_
|
312
tensorflow/c/eager/operation_interface.cc
Normal file
312
tensorflow/c/eager/operation_interface.cc
Normal file
@ -0,0 +1,312 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/operation_interface.h"
|
||||||
|
|
||||||
|
#include "absl/container/fixed_array.h"
|
||||||
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
|
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/execute.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
OperationInterface::OperationInterface(TFE_Context* ctx)
|
||||||
|
: operation_(ctx->context) {}
|
||||||
|
|
||||||
|
const string& OperationInterface::DeviceName() const {
|
||||||
|
absl::variant<Device*, CustomDevice*> variant_device =
|
||||||
|
(operation_.Device() == kVariantDeviceNull)
|
||||||
|
? operation_.EagerContext().HostCPU()
|
||||||
|
: operation_.Device();
|
||||||
|
return absl::visit([](auto* d) -> const string& { return d->name(); },
|
||||||
|
variant_device);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetDeviceName(const char* name) {
|
||||||
|
return operation_.SetDeviceName(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrString(const char* attr_name,
|
||||||
|
const char* data, size_t length) {
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, StringPiece(data, length));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrInt(const char* attr_name, int64_t value) {
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrFloat(const char* attr_name, float value) {
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, value);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrBool(const char* attr_name, bool value) {
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, value);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrType(const char* attr_name,
|
||||||
|
TF_DataType value) {
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, static_cast<DataType>(value));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrShape(const char* attr_name,
|
||||||
|
const int64_t* dims,
|
||||||
|
const int num_dims) {
|
||||||
|
if (num_dims > TensorShape::MaxDimensions()) {
|
||||||
|
return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
|
||||||
|
num_dims,
|
||||||
|
" dimensions which is over the limit of ",
|
||||||
|
TensorShape::MaxDimensions(), ".");
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorShapeProto proto;
|
||||||
|
if (num_dims < 0) {
|
||||||
|
proto.set_unknown_rank(true);
|
||||||
|
} else {
|
||||||
|
for (int d = 0; d < num_dims; ++d) {
|
||||||
|
proto.add_dim()->set_size(dims[d]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, proto);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrFunction(
|
||||||
|
const char* attr_name,
|
||||||
|
const std::unique_ptr<AbstractOperationInterface>& value) {
|
||||||
|
AttrValue attr_value;
|
||||||
|
NameAttrList* func = attr_value.mutable_func();
|
||||||
|
func->set_name(value->Name());
|
||||||
|
OperationInterface* value_operation =
|
||||||
|
tensorflow::down_cast<OperationInterface*>(value.get());
|
||||||
|
value_operation->operation_.Attrs().FillAttrValueMap(func->mutable_attr());
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, attr_value);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrFunctionName(const char* attr_name,
|
||||||
|
const char* data,
|
||||||
|
size_t length) {
|
||||||
|
AttrValue attr_value;
|
||||||
|
NameAttrList* func = attr_value.mutable_func();
|
||||||
|
func->set_name(data, length);
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, attr_value);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrTensor(const char* attr_name,
|
||||||
|
TF_Tensor* tensor) {
|
||||||
|
Tensor t;
|
||||||
|
TF_RETURN_IF_ERROR(TF_TensorToTensor(tensor, &t));
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, t);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrStringList(const char* attr_name,
|
||||||
|
const void* const* values,
|
||||||
|
const size_t* lengths,
|
||||||
|
int num_values) {
|
||||||
|
std::vector<StringPiece> v(num_values);
|
||||||
|
for (int i = 0; i < num_values; ++i) {
|
||||||
|
v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
|
||||||
|
}
|
||||||
|
operation_.MutableAttrs()->Set(attr_name, v);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrFloatList(const char* attr_name,
|
||||||
|
const float* values,
|
||||||
|
int num_values) {
|
||||||
|
operation_.MutableAttrs()->Set(
|
||||||
|
attr_name, gtl::ArraySlice<const float>(values, num_values));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrIntList(const char* attr_name,
|
||||||
|
const int64_t* values,
|
||||||
|
int num_values) {
|
||||||
|
operation_.MutableAttrs()->Set(
|
||||||
|
attr_name, gtl::ArraySlice<const int64>(
|
||||||
|
reinterpret_cast<const int64*>(values), num_values));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrTypeList(const char* attr_name,
|
||||||
|
const TF_DataType* values,
|
||||||
|
int num_values) {
|
||||||
|
operation_.MutableAttrs()->Set(
|
||||||
|
attr_name, gtl::ArraySlice<const DataType>(
|
||||||
|
reinterpret_cast<const DataType*>(values), num_values));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrBoolList(const char* attr_name,
|
||||||
|
const unsigned char* values,
|
||||||
|
int num_values) {
|
||||||
|
std::unique_ptr<bool[]> b(new bool[num_values]);
|
||||||
|
for (int i = 0; i < num_values; ++i) {
|
||||||
|
b[i] = values[i];
|
||||||
|
}
|
||||||
|
operation_.MutableAttrs()->Set(
|
||||||
|
attr_name, gtl::ArraySlice<const bool>(b.get(), num_values));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrShapeList(const char* attr_name,
|
||||||
|
const int64_t** dims,
|
||||||
|
const int* num_dims,
|
||||||
|
int num_values) {
|
||||||
|
std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
|
||||||
|
for (int i = 0; i < num_values; ++i) {
|
||||||
|
const auto num_dims_i = num_dims[i];
|
||||||
|
|
||||||
|
if (num_dims_i > TensorShape::MaxDimensions()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
strings::StrCat("Value specified for `", attr_name, "` has ",
|
||||||
|
num_dims_i, " dimensions which is over the limit of ",
|
||||||
|
TensorShape::MaxDimensions(), "."));
|
||||||
|
}
|
||||||
|
if (num_dims_i < 0) {
|
||||||
|
proto[i].set_unknown_rank(true);
|
||||||
|
} else {
|
||||||
|
const int64_t* dims_i = dims[i];
|
||||||
|
auto proto_i = &proto[i];
|
||||||
|
for (int d = 0; d < num_dims_i; ++d) {
|
||||||
|
proto_i->add_dim()->set_size(dims_i[d]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
operation_.MutableAttrs()->Set(
|
||||||
|
attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetAttrFunctionList(const char* attr_name,
|
||||||
|
const TFE_Op** value,
|
||||||
|
int num_values) {
|
||||||
|
std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
|
||||||
|
for (int i = 0; i < num_values; i++) {
|
||||||
|
auto value_operation =
|
||||||
|
tensorflow::down_cast<OperationInterface*>(value[i]->operation.get());
|
||||||
|
funcs[i].set_name(value_operation->operation_.Name());
|
||||||
|
value_operation->operation_.Attrs().FillAttrValueMap(
|
||||||
|
funcs[i].mutable_attr());
|
||||||
|
}
|
||||||
|
operation_.MutableAttrs()->Set(
|
||||||
|
attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
const OpDef* OperationInterface::GetOpDef(Status* status) {
|
||||||
|
const tensorflow::OpDef* op_def = operation_.OpDef();
|
||||||
|
if (op_def) return op_def;
|
||||||
|
*status = OpDefForOp(Name(), &op_def);
|
||||||
|
return op_def;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::InputLength(const char* input_name, int* length) {
|
||||||
|
Status status;
|
||||||
|
const tensorflow::OpDef* op_def = GetOpDef(&status);
|
||||||
|
if (!status.ok()) {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
AttrValueMap attrs;
|
||||||
|
operation_.Attrs().FillAttrValueMap(&attrs);
|
||||||
|
NameRangeMap name_ranges;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
NameRangesForNode(AttrSlice(&attrs), *op_def, &name_ranges, nullptr));
|
||||||
|
auto iter = name_ranges.find(input_name);
|
||||||
|
if (iter == name_ranges.end()) {
|
||||||
|
return errors::InvalidArgument("Input '", input_name, "' not found");
|
||||||
|
}
|
||||||
|
*length = iter->second.second - iter->second.first;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::OutputLength(const char* output_name, int* length) {
|
||||||
|
Status status;
|
||||||
|
const tensorflow::OpDef* op_def = GetOpDef(&status);
|
||||||
|
if (!status.ok()) {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
AttrValueMap attrs;
|
||||||
|
operation_.Attrs().FillAttrValueMap(&attrs);
|
||||||
|
NameRangeMap name_ranges;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
NameRangesForNode(AttrSlice(&attrs), *op_def, nullptr, &name_ranges));
|
||||||
|
auto iter = name_ranges.find(output_name);
|
||||||
|
if (iter == name_ranges.end()) {
|
||||||
|
return errors::InvalidArgument("Output '", output_name, "' not found");
|
||||||
|
}
|
||||||
|
*length = iter->second.second - iter->second.first;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::AddInput(
|
||||||
|
const std::unique_ptr<AbstractTensorHandleInterface>& input) {
|
||||||
|
TensorHandle* h =
|
||||||
|
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
|
||||||
|
operation_.AddInput(h);
|
||||||
|
return operation_.MaybeInferSingleInputAttrs(h);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::AddInputList(
|
||||||
|
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
|
||||||
|
inputs) {
|
||||||
|
for (auto& input : inputs) {
|
||||||
|
TensorHandle* h =
|
||||||
|
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
|
||||||
|
operation_.AddInput(h);
|
||||||
|
}
|
||||||
|
return operation_.InferInputListAttrs(inputs.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::Execute(
|
||||||
|
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
|
||||||
|
int* num_retvals) {
|
||||||
|
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
EagerExecute(&operation_, handle_retvals.data(), num_retvals));
|
||||||
|
for (int i = 0; i < *num_retvals; ++i) {
|
||||||
|
retvals->at(i).reset(
|
||||||
|
new tensorflow::TensorHandleInterface(handle_retvals[i]));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetCancellationManager(
|
||||||
|
TFE_CancellationManager* cancellation_manager) {
|
||||||
|
operation_.SetCancellationManager(
|
||||||
|
&cancellation_manager->cancellation_manager);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OperationInterface::SetUseXla(bool enable) {
|
||||||
|
operation_.SetUseXla(enable);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
188
tensorflow/c/eager/operation_interface.h
Normal file
188
tensorflow/c/eager/operation_interface.h
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
||||||
|
#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "absl/container/fixed_array.h"
|
||||||
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
|
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||||
|
|
||||||
|
// Abstract interface to an operation.
|
||||||
|
class AbstractOperationInterface {
|
||||||
|
public:
|
||||||
|
virtual ~AbstractOperationInterface() {}
|
||||||
|
|
||||||
|
virtual void Clear() = 0;
|
||||||
|
virtual tensorflow::Status Reset(const char* op,
|
||||||
|
const char* raw_device_name) = 0;
|
||||||
|
|
||||||
|
virtual const tensorflow::string& Name() const = 0;
|
||||||
|
virtual const tensorflow::string& DeviceName() const = 0;
|
||||||
|
virtual tensorflow::Status SetDeviceName(const char* name) = 0;
|
||||||
|
|
||||||
|
virtual tensorflow::Status AddInput(
|
||||||
|
const std::unique_ptr<AbstractTensorHandleInterface>& input) = 0;
|
||||||
|
virtual tensorflow::Status AddInputList(
|
||||||
|
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
|
||||||
|
inputs) = 0;
|
||||||
|
virtual tensorflow::Status Execute(
|
||||||
|
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
|
||||||
|
int* num_retvals) = 0;
|
||||||
|
virtual const tensorflow::OpDef* OpDef() const = 0;
|
||||||
|
|
||||||
|
virtual tensorflow::Status SetAttrString(const char* attr_name,
|
||||||
|
const char* data, size_t length) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrInt(const char* attr_name,
|
||||||
|
int64_t value) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrFloat(const char* attr_name,
|
||||||
|
float value) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrBool(const char* attr_name, bool value) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrType(const char* attr_name,
|
||||||
|
TF_DataType value) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrShape(const char* attr_name,
|
||||||
|
const int64_t* dims,
|
||||||
|
const int num_dims) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrFunction(
|
||||||
|
const char* attr_name,
|
||||||
|
const std::unique_ptr<AbstractOperationInterface>& value) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrFunctionName(const char* attr_name,
|
||||||
|
const char* value,
|
||||||
|
size_t length) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrTensor(const char* attr_name,
|
||||||
|
TF_Tensor* tensor) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrStringList(const char* attr_name,
|
||||||
|
const void* const* values,
|
||||||
|
const size_t* lengths,
|
||||||
|
int num_values) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrFloatList(const char* attr_name,
|
||||||
|
const float* values,
|
||||||
|
int num_values) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrIntList(const char* attr_name,
|
||||||
|
const int64_t* values,
|
||||||
|
int num_values) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrTypeList(const char* attr_name,
|
||||||
|
const TF_DataType* values,
|
||||||
|
int num_values) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrBoolList(const char* attr_name,
|
||||||
|
const unsigned char* values,
|
||||||
|
int num_values) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrShapeList(const char* attr_name,
|
||||||
|
const int64_t** dims,
|
||||||
|
const int* num_dims,
|
||||||
|
int num_values) = 0;
|
||||||
|
virtual tensorflow::Status SetAttrFunctionList(const char* attr_name,
|
||||||
|
const TFE_Op** value,
|
||||||
|
int num_values) = 0;
|
||||||
|
|
||||||
|
virtual tensorflow::Status InputLength(const char* input_name,
|
||||||
|
int* length) = 0;
|
||||||
|
virtual tensorflow::Status OutputLength(const char* output_name,
|
||||||
|
int* length) = 0;
|
||||||
|
|
||||||
|
// Experimental
|
||||||
|
virtual tensorflow::Status SetUseXla(bool enable) {
|
||||||
|
return tensorflow::errors::Unimplemented("SetUseXla not implemented");
|
||||||
|
}
|
||||||
|
virtual tensorflow::Status SetCancellationManager(
|
||||||
|
TFE_CancellationManager* cancellation_manager) {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"SetCancellationManager not implemented");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class OpDef;
|
||||||
|
|
||||||
|
class OperationInterface : public AbstractOperationInterface {
|
||||||
|
public:
|
||||||
|
explicit OperationInterface(TFE_Context* ctx);
|
||||||
|
~OperationInterface() override{};
|
||||||
|
|
||||||
|
void Clear() override { operation_.Clear(); }
|
||||||
|
Status Reset(const char* op, const char* raw_device_name) override {
|
||||||
|
return operation_.Reset(op, raw_device_name, false, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
const string& Name() const override { return operation_.Name(); }
|
||||||
|
const string& DeviceName() const override;
|
||||||
|
Status SetDeviceName(const char* name) override;
|
||||||
|
|
||||||
|
Status AddInput(
|
||||||
|
const std::unique_ptr<AbstractTensorHandleInterface>& input) override;
|
||||||
|
Status AddInputList(
|
||||||
|
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
|
||||||
|
inputs) override;
|
||||||
|
Status Execute(
|
||||||
|
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
|
||||||
|
int* num_retvals) override;
|
||||||
|
const tensorflow::OpDef* OpDef() const override {
|
||||||
|
return operation_.OpDef();
|
||||||
|
};
|
||||||
|
|
||||||
|
Status SetAttrString(const char* attr_name, const char* data,
|
||||||
|
size_t length) override;
|
||||||
|
Status SetAttrInt(const char* attr_name, int64_t value) override;
|
||||||
|
Status SetAttrFloat(const char* attr_name, float value) override;
|
||||||
|
Status SetAttrBool(const char* attr_name, bool value) override;
|
||||||
|
Status SetAttrType(const char* attr_name, TF_DataType value) override;
|
||||||
|
Status SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||||
|
const int num_dims) override;
|
||||||
|
Status SetAttrFunction(
|
||||||
|
const char* attr_name,
|
||||||
|
const std::unique_ptr<AbstractOperationInterface>& value) override;
|
||||||
|
Status SetAttrFunctionName(const char* attr_name, const char* data,
|
||||||
|
size_t length) override;
|
||||||
|
Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) override;
|
||||||
|
Status SetAttrStringList(const char* attr_name, const void* const* values,
|
||||||
|
const size_t* lengths, int num_values) override;
|
||||||
|
Status SetAttrFloatList(const char* attr_name, const float* values,
|
||||||
|
int num_values) override;
|
||||||
|
Status SetAttrIntList(const char* attr_name, const int64_t* values,
|
||||||
|
int num_values) override;
|
||||||
|
Status SetAttrTypeList(const char* attr_name, const TF_DataType* values,
|
||||||
|
int num_values) override;
|
||||||
|
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
|
||||||
|
int num_values) override;
|
||||||
|
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
|
||||||
|
const int* num_dims, int num_values) override;
|
||||||
|
Status SetAttrFunctionList(const char* attr_name, const TFE_Op** value,
|
||||||
|
int num_values) override;
|
||||||
|
|
||||||
|
Status InputLength(const char* input_name, int* length) override;
|
||||||
|
Status OutputLength(const char* output_name, int* length) override;
|
||||||
|
|
||||||
|
Status SetUseXla(bool enable) override;
|
||||||
|
Status SetCancellationManager(
|
||||||
|
TFE_CancellationManager* cancellation_manager) override;
|
||||||
|
|
||||||
|
// TODO(gjn): Remove once TFE_InferShapes is removed
|
||||||
|
const tensorflow::AttrBuilder& Attrs() const { return operation_.Attrs(); }
|
||||||
|
tensorflow::AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); }
|
||||||
|
|
||||||
|
const TensorHandle* GetInput(int i) const { return operation_.Inputs()[i]; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
const tensorflow::OpDef* GetOpDef(Status* status);
|
||||||
|
EagerOperation operation_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
@ -284,7 +284,7 @@ class ForwardAccumulator {
|
|||||||
// Temporarily push or pop transient state for this accumulator.
|
// Temporarily push or pop transient state for this accumulator.
|
||||||
//
|
//
|
||||||
// Allows an accumulator which is currently processing an operation to
|
// Allows an accumulator which is currently processing an operation to
|
||||||
// temporarily reset its state. Without pushing and poping, accumulators
|
// temporarily reset its state. Without pushing and popping, accumulators
|
||||||
// ignore operations executed as a direct result of their own jvp
|
// ignore operations executed as a direct result of their own jvp
|
||||||
// computations.
|
// computations.
|
||||||
void PushState() { call_state_.emplace(nullptr, false); }
|
void PushState() { call_state_.emplace(nullptr, false); }
|
||||||
|
@ -55,6 +55,14 @@ class AbstractTensorHandleInterface {
|
|||||||
|
|
||||||
// Return a copy of the handle.
|
// Return a copy of the handle.
|
||||||
virtual AbstractTensorHandleInterface* Copy() = 0;
|
virtual AbstractTensorHandleInterface* Copy() = 0;
|
||||||
|
|
||||||
|
// Maintain mirror tensors for any implicit copies to local devices. This
|
||||||
|
// setting is offered on a per tensor handle basis to avoid potential memory
|
||||||
|
// over utilization due to holding on to mirrors as well as the original
|
||||||
|
// tensor. Note this setting overrides the context mirroring policy whereby if
|
||||||
|
// the mirroring policy is MIRRORING_NONE, we will still continue to mirror
|
||||||
|
// this tensor.
|
||||||
|
virtual void EnableImplicitMirroring() = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -77,6 +85,8 @@ class TensorHandleInterface : public AbstractTensorHandleInterface {
|
|||||||
|
|
||||||
AbstractTensorHandleInterface* Copy() override;
|
AbstractTensorHandleInterface* Copy() override;
|
||||||
|
|
||||||
|
void EnableImplicitMirroring() override;
|
||||||
|
|
||||||
// TODO(gjn): This is not a very generic interface, but is needed for specific
|
// TODO(gjn): This is not a very generic interface, but is needed for specific
|
||||||
// use cases.
|
// use cases.
|
||||||
TensorHandle* Handle() { return handle_; }
|
TensorHandle* Handle() { return handle_; }
|
||||||
|
@ -18,37 +18,28 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Core TensorFlow depends on this, this will be included in main library
|
|
||||||
cc_library(
|
|
||||||
name = "filesystem_interface_impl",
|
|
||||||
srcs = ["filesystem_interface.cc"],
|
|
||||||
hdrs = ["filesystem_interface.h"],
|
|
||||||
deps = [
|
|
||||||
":modular_filesystem",
|
|
||||||
"//tensorflow/c:tf_file_statistics",
|
|
||||||
"//tensorflow/c:tf_status",
|
|
||||||
"//tensorflow/c:tf_status_internal",
|
|
||||||
"//tensorflow/core:ptr_util",
|
|
||||||
"//tensorflow/core/platform:env",
|
|
||||||
"//tensorflow/core/platform:logging",
|
|
||||||
"//tensorflow/core/platform:strcat",
|
|
||||||
"//tensorflow/core/platform:stringpiece",
|
|
||||||
],
|
|
||||||
alwayslink = 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Core TensorFlow depends on this, will be included in main library
|
# Core TensorFlow depends on this, will be included in main library
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "modular_filesystem",
|
name = "modular_filesystem",
|
||||||
srcs = ["modular_filesystem.cc"],
|
srcs = [
|
||||||
hdrs = ["modular_filesystem.h"],
|
"modular_filesystem.cc",
|
||||||
|
"modular_filesystem_registration.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"modular_filesystem.h",
|
||||||
|
"modular_filesystem_registration.h",
|
||||||
|
],
|
||||||
|
# TODO(mihaimaruseac): Visibility should be more restrictive once we
|
||||||
|
# convert to modular filesystems everywhere
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":filesystem_interface",
|
":filesystem_interface",
|
||||||
"//tensorflow/c:tf_status_helper",
|
"//tensorflow/c:tf_status_helper",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/c:tf_status_internal",
|
||||||
"//tensorflow/core:ptr_util",
|
"//tensorflow/core:ptr_util",
|
||||||
"//tensorflow/core/platform:env",
|
"//tensorflow/core/platform:env",
|
||||||
"//tensorflow/core/platform:strcat",
|
"//tensorflow/core/platform:errors",
|
||||||
|
"//tensorflow/core/platform:status",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -63,16 +54,12 @@ tf_cc_test(
|
|||||||
"notap", # b/139060984, requires implementing modular support for Google filesystem
|
"notap", # b/139060984, requires implementing modular support for Google filesystem
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":filesystem_interface_impl",
|
":modular_filesystem",
|
||||||
"//tensorflow/c:tf_status",
|
|
||||||
"//tensorflow/c:tf_status_internal",
|
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
"//tensorflow/core/lib/io:path",
|
"//tensorflow/core/lib/io:path",
|
||||||
"//tensorflow/core/platform:env",
|
"//tensorflow/core/platform:env",
|
||||||
"//tensorflow/core/platform:error",
|
"//tensorflow/core/platform:error",
|
||||||
"//tensorflow/core/platform:stacktrace_handler",
|
"//tensorflow/core/platform:stacktrace_handler",
|
||||||
"//tensorflow/core/platform:str_util",
|
|
||||||
"//tensorflow/core/platform:strcat",
|
|
||||||
"//tensorflow/core/platform:test",
|
"//tensorflow/core/platform:test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -1,366 +0,0 @@
|
|||||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
|
||||||
|
|
||||||
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
|
|
||||||
#include "tensorflow/c/tf_status_internal.h"
|
|
||||||
#include "tensorflow/core/platform/env.h"
|
|
||||||
#include "tensorflow/core/platform/logging.h"
|
|
||||||
#include "tensorflow/core/platform/strcat.h"
|
|
||||||
#include "tensorflow/core/platform/stringpiece.h"
|
|
||||||
#include "tensorflow/core/util/ptr_util.h"
|
|
||||||
|
|
||||||
/// This translation unit is linked in core TensorFlow and provides the
|
|
||||||
/// functionality needed for plugin registration to check ABI/API compatibility,
|
|
||||||
/// to ensure required methods are present, to ensure plugins are not allowed to
|
|
||||||
/// change functionality after being loaded and to register the filesystems
|
|
||||||
/// provided by a plugin. Consult the header file for more information about
|
|
||||||
/// how this is achieved.
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
// Checks if the plugin and core ABI numbers match, filling in `status`.
|
|
||||||
//
|
|
||||||
// If the numbers don't match, plugin cannot be loaded.
|
|
||||||
static bool CheckABIHelper(int pluginABI, int coreABI, StringPiece where,
|
|
||||||
TF_Status* status) {
|
|
||||||
if (pluginABI != coreABI) {
|
|
||||||
TF_SetStatus(
|
|
||||||
status, TF_FAILED_PRECONDITION,
|
|
||||||
strings::StrCat("Plugin ABI (", pluginABI, ") for ", where,
|
|
||||||
" operations doesn't match expected core ABI (",
|
|
||||||
coreABI, "). Plugin cannot be loaded.")
|
|
||||||
.c_str());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Checks if the plugin and core ABI numbers match, for all operations.
|
|
||||||
//
|
|
||||||
// If the numbers don't match, plugin cannot be loaded.
|
|
||||||
//
|
|
||||||
// Uses the simpler `CheckABIHelper(int, int, StringPiece, TF_Status*)`
|
|
||||||
static bool CheckABI(
|
|
||||||
int plugin_filesystem_ops_ABI,
|
|
||||||
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
|
|
||||||
int plugin_random_access_file_ops_ABI,
|
|
||||||
const TF_WritableFileOps* plugin_writable_file_ops,
|
|
||||||
int plugin_writable_file_ops_ABI,
|
|
||||||
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
|
|
||||||
int plugin_read_only_memory_region_ops_ABI, TF_Status* status) {
|
|
||||||
if (!CheckABIHelper(plugin_filesystem_ops_ABI, TF_FILESYSTEM_OPS_ABI,
|
|
||||||
"filesystem", status))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
if (plugin_random_access_file_ops != nullptr &&
|
|
||||||
!CheckABIHelper(plugin_random_access_file_ops_ABI,
|
|
||||||
TF_RANDOM_ACCESS_FILE_OPS_ABI, "random access file",
|
|
||||||
status))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
if (plugin_writable_file_ops != nullptr &&
|
|
||||||
!CheckABIHelper(plugin_writable_file_ops_ABI, TF_WRITABLE_FILE_OPS_ABI,
|
|
||||||
"writable file", status))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
if (plugin_read_only_memory_region_ops != nullptr &&
|
|
||||||
!CheckABIHelper(plugin_read_only_memory_region_ops_ABI,
|
|
||||||
TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
|
|
||||||
"read only memory region", status))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Checks if the plugin and core API numbers match, logging mismatches.
|
|
||||||
static void CheckAPIHelper(int plugin_API, int core_API, StringPiece where) {
|
|
||||||
if (plugin_API != core_API) {
|
|
||||||
VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
|
|
||||||
<< " operations doesn't match expected core API (" << core_API
|
|
||||||
<< "). Plugin will be loaded but functionality might be missing.";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Checks if the plugin and core API numbers match, for all operations.
|
|
||||||
//
|
|
||||||
// Uses the simpler `CheckAPIHelper(int, int, StringPiece)`.
|
|
||||||
static void CheckAPI(
|
|
||||||
int plugin_filesystem_ops_API,
|
|
||||||
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
|
|
||||||
int plugin_random_access_file_ops_API,
|
|
||||||
const TF_WritableFileOps* plugin_writable_file_ops,
|
|
||||||
int plugin_writable_file_ops_API,
|
|
||||||
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
|
|
||||||
int plugin_read_only_memory_region_ops_API) {
|
|
||||||
CheckAPIHelper(plugin_filesystem_ops_API, TF_FILESYSTEM_OPS_API,
|
|
||||||
"filesystem");
|
|
||||||
|
|
||||||
if (plugin_random_access_file_ops != nullptr)
|
|
||||||
CheckAPIHelper(plugin_random_access_file_ops_API,
|
|
||||||
TF_RANDOM_ACCESS_FILE_OPS_API, "random access file");
|
|
||||||
|
|
||||||
if (plugin_writable_file_ops != nullptr)
|
|
||||||
CheckAPIHelper(plugin_writable_file_ops_API, TF_WRITABLE_FILE_OPS_API,
|
|
||||||
"writable file");
|
|
||||||
|
|
||||||
if (plugin_read_only_memory_region_ops != nullptr)
|
|
||||||
CheckAPIHelper(plugin_read_only_memory_region_ops_API,
|
|
||||||
TF_READ_ONLY_MEMORY_REGION_OPS_API,
|
|
||||||
"read only memory region");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validates the filesystem operations supplied by the plugin.
|
|
||||||
static bool ValidateHelper(const TF_FilesystemOps* ops, TF_Status* status) {
|
|
||||||
if (ops == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Trying to register filesystem without operations");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ops->init == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Trying to register filesystem without `init` operation");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ops->cleanup == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Trying to register filesystem without `cleanup` operation");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validates the random access file operations supplied by the plugin.
|
|
||||||
static bool ValidateHelper(const TF_RandomAccessFileOps* ops,
|
|
||||||
TF_Status* status) {
|
|
||||||
if (ops == nullptr) {
|
|
||||||
// We allow filesystems where files can only be written to (from TF code)
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ops->cleanup == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Trying to register filesystem without `cleanup` operation on "
|
|
||||||
"random access files");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validates the writable file operations supplied by the plugin.
|
|
||||||
static bool ValidateHelper(const TF_WritableFileOps* ops, TF_Status* status) {
|
|
||||||
if (ops == nullptr) {
|
|
||||||
// We allow read-only filesystems
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ops->cleanup == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Trying to register filesystem without `cleanup` operation on "
|
|
||||||
"writable files");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validates the read only memory region operations given by the plugin.
|
|
||||||
static bool ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops,
|
|
||||||
TF_Status* status) {
|
|
||||||
if (ops == nullptr) {
|
|
||||||
// read only memory region support is always optional
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ops->cleanup == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Trying to register filesystem without `cleanup` operation on "
|
|
||||||
"read only memory regions");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ops->data == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Trying to register filesystem without `data` operation on "
|
|
||||||
"read only memory regions");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ops->length == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Trying to register filesystem without `length` operation on "
|
|
||||||
"read only memory regions");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validates the operations supplied by the plugin.
|
|
||||||
//
|
|
||||||
// Uses the 4 simpler `ValidateHelper(const TF_..., TF_Status*)` to validate
|
|
||||||
// each individual function table and then checks that the function table for a
|
|
||||||
// specific file type exists if the plugin offers support for creating that
|
|
||||||
// type of files.
|
|
||||||
static bool Validate(
|
|
||||||
const TF_FilesystemOps* plugin_filesystem_ops,
|
|
||||||
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
|
|
||||||
const TF_WritableFileOps* plugin_writable_file_ops,
|
|
||||||
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
|
|
||||||
TF_Status* status) {
|
|
||||||
if (!ValidateHelper(plugin_filesystem_ops, status)) return false;
|
|
||||||
if (!ValidateHelper(plugin_random_access_file_ops, status)) return false;
|
|
||||||
if (!ValidateHelper(plugin_writable_file_ops, status)) return false;
|
|
||||||
if (!ValidateHelper(plugin_read_only_memory_region_ops, status)) return false;
|
|
||||||
|
|
||||||
if (plugin_filesystem_ops->new_random_access_file != nullptr &&
|
|
||||||
plugin_random_access_file_ops == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Filesystem allows creation of random access files but no "
|
|
||||||
"operations on them have been supplied.");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((plugin_filesystem_ops->new_writable_file != nullptr ||
|
|
||||||
plugin_filesystem_ops->new_appendable_file != nullptr) &&
|
|
||||||
plugin_writable_file_ops == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Filesystem allows creation of writable files but no "
|
|
||||||
"operations on them have been supplied.");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (plugin_filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
|
|
||||||
plugin_read_only_memory_region_ops == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Filesystem allows creation of readonly memory regions but no "
|
|
||||||
"operations on them have been supplied.");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copies a function table from plugin memory space to core memory space.
|
|
||||||
//
|
|
||||||
// This has three benefits:
|
|
||||||
// * allows having newer plugins than the current core TensorFlow: the
|
|
||||||
// additional entries in the plugin's table are just discarded;
|
|
||||||
// * allows having older plugins than the current core TensorFlow (though
|
|
||||||
// we are still warning users): the entries that core TensorFlow expects
|
|
||||||
// but plugins didn't provide will be set to `nullptr` values and core
|
|
||||||
// TensorFlow will know to not call these on behalf of users;
|
|
||||||
// * increased security as plugins will not be able to alter function table
|
|
||||||
// after loading up. Thus, malicious plugins can't alter functionality to
|
|
||||||
// probe for gadgets inside core TensorFlow. We can even protect the area
|
|
||||||
// of memory where the copies reside to not allow any more writes to it
|
|
||||||
// after all copies are created.
|
|
||||||
template <typename T>
|
|
||||||
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
|
|
||||||
size_t plugin_size) {
|
|
||||||
if (plugin_ops == nullptr) return nullptr;
|
|
||||||
|
|
||||||
size_t copy_size = sizeof(T);
|
|
||||||
if (plugin_size < copy_size) {
|
|
||||||
copy_size = plugin_size;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto core_ops = tensorflow::MakeUnique<T>();
|
|
||||||
memcpy(const_cast<T*>(core_ops.get()), plugin_ops, copy_size);
|
|
||||||
return core_ops;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
||||||
void RegisterFilesystemPlugin(
|
|
||||||
int plugin_filesystem_ops_ABI, int plugin_filesystem_ops_API,
|
|
||||||
size_t plugin_filesystem_ops_size, int plugin_random_access_file_ops_ABI,
|
|
||||||
int plugin_random_access_file_ops_API,
|
|
||||||
size_t plugin_random_access_file_ops_size, int plugin_writable_file_ops_ABI,
|
|
||||||
int plugin_writable_file_ops_API, size_t plugin_writable_file_ops_size,
|
|
||||||
int plugin_read_only_memory_region_ops_ABI,
|
|
||||||
int plugin_read_only_memory_region_ops_API,
|
|
||||||
size_t plugin_read_only_memory_region_ops_size, const char* scheme,
|
|
||||||
const TF_FilesystemOps* plugin_filesystem_ops,
|
|
||||||
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
|
|
||||||
const TF_WritableFileOps* plugin_writable_file_ops,
|
|
||||||
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
|
|
||||||
TF_Status* status) {
|
|
||||||
if (scheme == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
|
||||||
"`scheme` argument must not be `nullptr`.");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// ABI numbers must match exactly for plugin to be loaded
|
|
||||||
if (!tensorflow::CheckABI(
|
|
||||||
plugin_filesystem_ops_ABI, plugin_random_access_file_ops,
|
|
||||||
plugin_random_access_file_ops_ABI, plugin_writable_file_ops,
|
|
||||||
plugin_writable_file_ops_ABI, plugin_read_only_memory_region_ops,
|
|
||||||
plugin_read_only_memory_region_ops_ABI, status)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// API numbers should match but mismatch doesn't block plugin load
|
|
||||||
tensorflow::CheckAPI(plugin_filesystem_ops_API, plugin_random_access_file_ops,
|
|
||||||
plugin_random_access_file_ops_API,
|
|
||||||
plugin_writable_file_ops, plugin_writable_file_ops_API,
|
|
||||||
plugin_read_only_memory_region_ops,
|
|
||||||
plugin_read_only_memory_region_ops_API);
|
|
||||||
|
|
||||||
// Plugin can only be loaded if all supplied ops are valid
|
|
||||||
if (!tensorflow::Validate(plugin_filesystem_ops,
|
|
||||||
plugin_random_access_file_ops,
|
|
||||||
plugin_writable_file_ops,
|
|
||||||
plugin_read_only_memory_region_ops, status)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy all the function tables to core TensorFlow memory space
|
|
||||||
auto core_filesystem_ops = tensorflow::CopyToCore<TF_FilesystemOps>(
|
|
||||||
plugin_filesystem_ops, plugin_filesystem_ops_size);
|
|
||||||
auto core_random_access_file_ops =
|
|
||||||
tensorflow::CopyToCore<TF_RandomAccessFileOps>(
|
|
||||||
plugin_random_access_file_ops, plugin_random_access_file_ops_size);
|
|
||||||
auto core_writable_file_ops = tensorflow::CopyToCore<TF_WritableFileOps>(
|
|
||||||
plugin_writable_file_ops, plugin_writable_file_ops_size);
|
|
||||||
auto core_read_only_memory_region_ops =
|
|
||||||
tensorflow::CopyToCore<TF_ReadOnlyMemoryRegionOps>(
|
|
||||||
plugin_read_only_memory_region_ops,
|
|
||||||
plugin_read_only_memory_region_ops_size);
|
|
||||||
|
|
||||||
// Initialize the opaque filesystem structure
|
|
||||||
auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
|
|
||||||
core_filesystem_ops->init(filesystem.get(), status);
|
|
||||||
if (!status->status.ok()) {
|
|
||||||
core_filesystem_ops->cleanup(filesystem.get());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register new filesystem
|
|
||||||
status->status = tensorflow::Env::Default()->RegisterFileSystem(
|
|
||||||
scheme, tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
|
|
||||||
std::move(filesystem), std::move(core_filesystem_ops),
|
|
||||||
std::move(core_random_access_file_ops),
|
|
||||||
std::move(core_writable_file_ops),
|
|
||||||
std::move(core_read_only_memory_region_ops)));
|
|
||||||
}
|
|
@ -56,7 +56,7 @@ extern "C" {
|
|||||||
/// Lifetime: The wrapper data structures are owned by core TensorFlow. The data
|
/// Lifetime: The wrapper data structures are owned by core TensorFlow. The data
|
||||||
/// pointed to by the `void*` members is always owned by the plugin. The plugin
|
/// pointed to by the `void*` members is always owned by the plugin. The plugin
|
||||||
/// will provide functions to call to allocate and deallocate this data (see
|
/// will provide functions to call to allocate and deallocate this data (see
|
||||||
/// next section) and core TensorFlow ensures to call these at the proper time.
|
/// next sections) and core TensorFlow ensures to call these at the proper time.
|
||||||
///
|
///
|
||||||
/// Plugins will never receive a `TF_*` pointer that is `nullptr`. Core
|
/// Plugins will never receive a `TF_*` pointer that is `nullptr`. Core
|
||||||
/// TensorFlow will never touch the `void*` wrapped by these structures, except
|
/// TensorFlow will never touch the `void*` wrapped by these structures, except
|
||||||
@ -529,7 +529,7 @@ typedef struct TF_FilesystemOps {
|
|||||||
/// If `statuses` is not null, plugins must fill each element with detailed
|
/// If `statuses` is not null, plugins must fill each element with detailed
|
||||||
/// status for each file, as if calling `path_exists` on each one. Core
|
/// status for each file, as if calling `path_exists` on each one. Core
|
||||||
/// TensorFlow initializes the `statuses` array and plugins must use
|
/// TensorFlow initializes the `statuses` array and plugins must use
|
||||||
/// `TF_SetStatus` to set each element instead of dirrectly assigning.
|
/// `TF_SetStatus` to set each element instead of directly assigning.
|
||||||
///
|
///
|
||||||
/// DEFAULT IMPLEMENTATION: Checks existence of every file. Needs
|
/// DEFAULT IMPLEMENTATION: Checks existence of every file. Needs
|
||||||
/// `path_exists`.
|
/// `path_exists`.
|
||||||
@ -601,6 +601,10 @@ typedef struct TF_FilesystemOps {
|
|||||||
///
|
///
|
||||||
/// Plugins must not return `nullptr`. Returning empty strings is allowed.
|
/// Plugins must not return `nullptr`. Returning empty strings is allowed.
|
||||||
///
|
///
|
||||||
|
/// The allocation and freeing of memory must happen via the functions sent to
|
||||||
|
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
|
||||||
|
/// structure in Section 4).
|
||||||
|
///
|
||||||
/// This function will be called by core TensorFlow to clean up all path
|
/// This function will be called by core TensorFlow to clean up all path
|
||||||
/// arguments for all other methods in the filesystem API.
|
/// arguments for all other methods in the filesystem API.
|
||||||
///
|
///
|
||||||
@ -618,6 +622,10 @@ typedef struct TF_FilesystemOps {
|
|||||||
/// In case of error, plugins must set `status` to a value different than
|
/// In case of error, plugins must set `status` to a value different than
|
||||||
/// `TF_OK`, free memory allocated for `entries` and return -1.
|
/// `TF_OK`, free memory allocated for `entries` and return -1.
|
||||||
///
|
///
|
||||||
|
/// The allocation and freeing of memory must happen via the functions sent to
|
||||||
|
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
|
||||||
|
/// structure in Section 4).
|
||||||
|
///
|
||||||
/// Plugins:
|
/// Plugins:
|
||||||
/// * Must set `status` to `TF_OK` if all children were returned.
|
/// * Must set `status` to `TF_OK` if all children were returned.
|
||||||
/// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to a
|
/// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to a
|
||||||
@ -654,6 +662,10 @@ typedef struct TF_FilesystemOps {
|
|||||||
/// different than `TF_OK`, free any memory that might have been allocated for
|
/// different than `TF_OK`, free any memory that might have been allocated for
|
||||||
/// `entries` and return -1.
|
/// `entries` and return -1.
|
||||||
///
|
///
|
||||||
|
/// The allocation and freeing of memory must happen via the functions sent to
|
||||||
|
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
|
||||||
|
/// structure in Section 4).
|
||||||
|
///
|
||||||
/// Plugins:
|
/// Plugins:
|
||||||
/// * Must set `status` to `TF_OK` if all matches were returned.
|
/// * Must set `status` to `TF_OK` if all matches were returned.
|
||||||
/// * Might use any other error value for `status` to signal other errors.
|
/// * Might use any other error value for `status` to signal other errors.
|
||||||
@ -736,95 +748,132 @@ constexpr size_t TF_FILESYSTEM_OPS_SIZE = sizeof(TF_FilesystemOps);
|
|||||||
/// SECTION 4. Plugin registration and initialization
|
/// SECTION 4. Plugin registration and initialization
|
||||||
/// ----------------------------------------------------------------------------
|
/// ----------------------------------------------------------------------------
|
||||||
///
|
///
|
||||||
/// In this section we define two functions:
|
/// In this section we define the API used by core TensorFlow to initialize a
|
||||||
/// * `TF_InitPlugin`: must be present in the plugin shared object as it will
|
/// filesystem provided by a plugin. That is, we define the following:
|
||||||
/// be called by core TensorFlow when the filesystem plugin is loaded;
|
/// * `TF_InitPlugin` function: must be present in the plugin shared object as
|
||||||
/// * `RegisterFilesystemPlugin`: it is implemented by core TensorFlow but
|
/// it will be called by core TensorFlow when the filesystem plugin is
|
||||||
/// plugins must call it in their `TF_InitPlugin`, usually using the macro
|
/// loaded;
|
||||||
/// `TF_REGISTER_FILESYSTEM_PLUGIN`.
|
/// * `TF_FilesystemPluginOps` struct: used to transfer information between
|
||||||
|
/// plugins and core TensorFlow about the operations provided and metadata;
|
||||||
|
/// * `TF_FilesystemPluginInfo` struct: similar to the above structure, but
|
||||||
|
/// collects information about all the file schemes that the plugin provides
|
||||||
|
/// support for, as well as about the plugin's memory handling routines;
|
||||||
|
/// * `TF_SetFilesystemVersionMetadata` function: must be called by plugins in
|
||||||
|
/// their `TF_InitPlugin` to record the versioning information the plugins
|
||||||
|
/// are compiled against.
|
||||||
///
|
///
|
||||||
/// The `TF_InitPlugin` function is used by plugins to set up the data
|
/// The `TF_InitPlugin` function is used by plugins to set up the data
|
||||||
/// structures that implement this interface, as presented in Section 2.
|
/// structures that implement this interface, as presented in Section 2. In
|
||||||
///
|
/// order to not have plugin shared objects call back symbols defined in core
|
||||||
/// The `RegisterFilesystemPlugin` is used by core TensorFlow to check that
|
/// TensorFlow, `TF_InitPlugin` has a `TF_FilesystemPluginInfo` argument which
|
||||||
/// plugins satisfy the requirements expected by core TensorFlow, as follows:
|
/// the plugin must fill (using the `TF_SetFilesystemVersionMetadata` for the
|
||||||
/// 1. If ABI numbers don't match we don't load the plugin, else we continue.
|
/// metadata and setting up all the supported operations and the URI schemes
|
||||||
/// 2. If the API numbers are mismatched, we warn the user and continue
|
/// that are supported).
|
||||||
/// loading the plugin.
|
|
||||||
/// 3. If any required operation is missing, we stop loading the plugin.
|
|
||||||
///
|
|
||||||
/// If all these checks succeed, we copy the plugin operations to a different
|
|
||||||
/// memory location so that core TensorFlow has the guarantee that they won't be
|
|
||||||
/// changed by plugins at a later time. Finally, we initialize the opaque
|
|
||||||
/// pointer of `TF_Filesystem` by calling the required `init` function of
|
|
||||||
/// `TF_FilesystemOps` and if that succeeds we register the filesystem.
|
|
||||||
|
|
||||||
// Initializes a TensorFlow plugin.
|
/// This structure incorporates the operations defined in Section 2 and the
|
||||||
//
|
/// metadata defined in section 3, allowing plugins to define different ops
|
||||||
// Must be implemented by the plugin DSO. It is called by TensorFlow runtime.
|
/// for different URI schemes.
|
||||||
//
|
///
|
||||||
// Filesystem plugins can be loaded on demand by users via
|
/// Every URI scheme is of the form "fs" for URIs of form "fs:///path/to/file".
|
||||||
// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
|
/// For local filesystems (i.e., when the URI is "/path/to/file"), the scheme
|
||||||
// paths (although this has a security risk if two plugins register for the
|
/// must be "". The scheme must never be `nullptr`.
|
||||||
// same filesystem and the malicious one loads before the legimitate one -
|
///
|
||||||
// but we consider this to be something that users should care about and
|
/// Every plugin fills this in `TF_InitPlugin`, using the alocator passed as
|
||||||
// manage themselves). In both of these cases, core TensorFlow looks for
|
/// argument to allocate memory. After `TF_InitPlugin` finishes, core
|
||||||
// the `TF_InitPlugin` symbol and calls that function.
|
/// TensorFlow uses the information present in this to initialize filesystems
|
||||||
//
|
/// for the URI schemes that the plugin requests.
|
||||||
// A plugin is loaded only if this `status` is `TF_OK` after the call.
|
///
|
||||||
TF_CAPI_EXPORT extern void TF_InitPlugin(TF_Status* status);
|
/// All pointers defined in this structure point to memory allocated by the DSO
|
||||||
|
/// using an allocator provided by core TensorFlow when calling `TF_InitPlugin`.
|
||||||
|
///
|
||||||
|
/// IMPORTANT: To maintain binary compatibility, the layout of this structure
|
||||||
|
/// must not change! In the unlikely case that a new type of file needs to be
|
||||||
|
/// supported, add the new ops and metadata at the end of the structure.
|
||||||
|
typedef struct TF_FilesystemPluginOps {
|
||||||
|
char* scheme;
|
||||||
|
int filesystem_ops_abi;
|
||||||
|
int filesystem_ops_api;
|
||||||
|
size_t filesystem_ops_size;
|
||||||
|
TF_FilesystemOps* filesystem_ops;
|
||||||
|
int random_access_file_ops_abi;
|
||||||
|
int random_access_file_ops_api;
|
||||||
|
size_t random_access_file_ops_size;
|
||||||
|
TF_RandomAccessFileOps* random_access_file_ops;
|
||||||
|
int writable_file_ops_abi;
|
||||||
|
int writable_file_ops_api;
|
||||||
|
size_t writable_file_ops_size;
|
||||||
|
TF_WritableFileOps* writable_file_ops;
|
||||||
|
int read_only_memory_region_ops_abi;
|
||||||
|
int read_only_memory_region_ops_api;
|
||||||
|
size_t read_only_memory_region_ops_size;
|
||||||
|
TF_ReadOnlyMemoryRegionOps* read_only_memory_region_ops;
|
||||||
|
} TF_FilesystemPluginOps;
|
||||||
|
|
||||||
/// Registers a filesystem plugin so that core TensorFlow can use it.
|
/// This structure gathers together all the operations provided by the plugin.
|
||||||
///
|
///
|
||||||
/// Must be called by the plugin during `TF_InitPlugin`, usually by using the
|
/// Plugins must provide exactly `num_schemes` elements in the `ops` array.
|
||||||
/// convenience `TF_REGISTER_FILESYSTEM_PLUGIN` macro.
|
|
||||||
///
|
///
|
||||||
/// Arguments (grouped by category):
|
/// Since memory that is allocated by the DSO gets transferred to core
|
||||||
/// * `..ABI`: ABI compatibility numbers (see Section 3.).
|
/// TensorFlow, we need to provide a way for the allocation and deallocation to
|
||||||
/// * `..API`: API compatibility numbers (see Section 3.).
|
/// match. This is why this structure also defines `plugin_memory_allocate` and
|
||||||
/// * `..Size`: Sizes of the operation tables (see Section 3.).
|
/// `plugin_memory_free` members.
|
||||||
/// * `scheme`: The URI scheme that plugin is registering filesystems for.
|
|
||||||
/// Must be of the form "fs" for URIs of form "fs:///path/to/file". For
|
|
||||||
/// local filesystems (i.e., when the URI is "/path/to/file"), `scheme`
|
|
||||||
/// must be "". Must never be `nullptr`.
|
|
||||||
/// * `..Ops`: The function tables provided by the plugin. Owned by the
|
|
||||||
/// plugin, but core TensorFlow makes a copy of these.
|
|
||||||
/// * `status`: The output variable for representing success/failure.
|
|
||||||
///
|
///
|
||||||
/// Sets `status` to `TF_OK` if plugin was registered and filesystem operations
|
/// All memory allocated by the plugin that will be owned by core TensorFlow
|
||||||
/// can be invoked from anywhere during TensorFlow's runtime. Any other value of
|
/// must be allocated using the allocator in this structure. Core TensorFlow
|
||||||
/// `status` means that plugin failed to load properly and as such the
|
/// will use the deallocator to free this memory once it no longer needs it.
|
||||||
/// operations it provides cannot be used at all (i.e., core TensorFlow will
|
///
|
||||||
/// never run them, returning early with `TF_UNIMPLEMENTED` or similar error
|
/// IMPORTANT: To maintain binary compatibility, the layout of this structure
|
||||||
/// values).
|
/// must not change! In the unlikely case that new global operations must be
|
||||||
TF_CAPI_EXPORT extern void RegisterFilesystemPlugin(
|
/// provided, add them at the end of the structure.
|
||||||
int pluginFilesystemOpsABI, int pluginFilesystemOpsAPI,
|
typedef struct TF_FilesystemPluginInfo {
|
||||||
size_t pluginFilesystemOpsSize, int pluginRandomAccessFileOpsABI,
|
size_t num_schemes;
|
||||||
int pluginRandomAccessFileOpsAPI, size_t pluginRandomAccessFileOpsSize,
|
TF_FilesystemPluginOps* ops;
|
||||||
int pluginWritableFileOpsABI, int pluginWritableFileOpsAPI,
|
void* (*plugin_memory_allocate)(size_t size);
|
||||||
size_t pluginWritableFileOpsSize, int pluginReadOnlyMemoryRegionOpsABI,
|
void (*plugin_memory_free)(void* ptr);
|
||||||
int pluginReadOnlyMemoryRegionOpsAPI,
|
} TF_FilesystemPluginInfo;
|
||||||
size_t pluginReadOnlyMemoryRegionOpsSize, const char* scheme,
|
|
||||||
const TF_FilesystemOps* pluginFilesystemOps,
|
|
||||||
const TF_RandomAccessFileOps* pluginRandomAccessFileOps,
|
|
||||||
const TF_WritableFileOps* pluginWritableFileOps,
|
|
||||||
const TF_ReadOnlyMemoryRegionOps* pluginReadOnlyMemoryRegionOps,
|
|
||||||
TF_Status* status);
|
|
||||||
|
|
||||||
/// This macro is just a convenience wrapper around `RegisterFilesystemPlugin`.
|
/// Convenience function for setting the versioning metadata.
|
||||||
/// Plugins should prefer using this macro instead of a direct call.
|
///
|
||||||
#define TF_REGISTER_FILESYSTEM_PLUGIN( \
|
/// The argument is guaranteed to not be `nullptr`.
|
||||||
scheme, pluginFilesystemOps, pluginRandomAccessFileOps, \
|
///
|
||||||
pluginWritableFileOps, pluginReadOnlyMemoryRegionOps, status) \
|
/// We want this to be defined in the plugin's memory space and we guarantee
|
||||||
RegisterFilesystemPlugin( \
|
/// that core TensorFlow will never call this.
|
||||||
TF_FILESYSTEM_OPS_ABI, TF_FILESYSTEM_OPS_API, TF_FILESYSTEM_OPS_SIZE, \
|
static inline void TF_SetFilesystemVersionMetadata(
|
||||||
TF_RANDOM_ACCESS_FILE_OPS_ABI, TF_RANDOM_ACCESS_FILE_OPS_API, \
|
TF_FilesystemPluginOps* ops) {
|
||||||
TF_RANDOM_ACCESS_FILE_OPS_SIZE, TF_WRITABLE_FILE_OPS_ABI, \
|
ops->filesystem_ops_abi = TF_FILESYSTEM_OPS_ABI;
|
||||||
TF_WRITABLE_FILE_OPS_API, TF_WRITABLE_FILE_OPS_SIZE, \
|
ops->filesystem_ops_api = TF_FILESYSTEM_OPS_API;
|
||||||
TF_READ_ONLY_MEMORY_REGION_OPS_ABI, TF_READ_ONLY_MEMORY_REGION_OPS_API, \
|
ops->filesystem_ops_size = TF_FILESYSTEM_OPS_SIZE;
|
||||||
TF_READ_ONLY_MEMORY_REGION_OPS_SIZE, scheme, pluginFilesystemOps, \
|
ops->random_access_file_ops_abi = TF_RANDOM_ACCESS_FILE_OPS_ABI;
|
||||||
pluginRandomAccessFileOps, pluginWritableFileOps, \
|
ops->random_access_file_ops_api = TF_RANDOM_ACCESS_FILE_OPS_API;
|
||||||
pluginReadOnlyMemoryRegionOps, status)
|
ops->random_access_file_ops_size = TF_RANDOM_ACCESS_FILE_OPS_SIZE;
|
||||||
|
ops->writable_file_ops_abi = TF_WRITABLE_FILE_OPS_ABI;
|
||||||
|
ops->writable_file_ops_api = TF_WRITABLE_FILE_OPS_API;
|
||||||
|
ops->writable_file_ops_size = TF_WRITABLE_FILE_OPS_SIZE;
|
||||||
|
ops->read_only_memory_region_ops_abi = TF_READ_ONLY_MEMORY_REGION_OPS_ABI;
|
||||||
|
ops->read_only_memory_region_ops_api = TF_READ_ONLY_MEMORY_REGION_OPS_API;
|
||||||
|
ops->read_only_memory_region_ops_size = TF_READ_ONLY_MEMORY_REGION_OPS_SIZE;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Initializes a TensorFlow plugin.
|
||||||
|
///
|
||||||
|
/// Must be implemented by the plugin DSO. It is called by TensorFlow runtime.
|
||||||
|
///
|
||||||
|
/// Filesystem plugins can be loaded on demand by users via
|
||||||
|
/// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
|
||||||
|
/// paths (although this has a security risk if two plugins register for the
|
||||||
|
/// same filesystem and the malicious one loads before the legimitate one -
|
||||||
|
/// but we consider this to be something that users should care about and
|
||||||
|
/// manage themselves). In both of these cases, core TensorFlow looks for
|
||||||
|
/// the `TF_InitPlugin` symbol and calls this function.
|
||||||
|
///
|
||||||
|
/// For every filesystem URI scheme that this plugin supports, the plugin must
|
||||||
|
/// add one `TF_FilesystemPluginInfo` entry in `plugin_info->ops` and call
|
||||||
|
/// `TF_SetFilesystemVersionMetadata` for that entry.
|
||||||
|
///
|
||||||
|
/// Plugins must also initialize `plugin_info->plugin_memory_allocate` and
|
||||||
|
/// `plugin_info->plugin_memory_free` to ensure memory allocated by plugin is
|
||||||
|
/// freed in a compatible way.
|
||||||
|
TF_CAPI_EXPORT extern void TF_InitPlugin(TF_FilesystemPluginInfo* plugin_info);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} // end extern "C"
|
} // end extern "C"
|
||||||
|
@ -18,11 +18,10 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/file_system_helper.h"
|
#include "tensorflow/core/platform/file_system_helper.h"
|
||||||
#include "tensorflow/core/platform/strcat.h"
|
|
||||||
#include "tensorflow/core/util/ptr_util.h"
|
#include "tensorflow/core/util/ptr_util.h"
|
||||||
|
|
||||||
// TODO(mihaimaruseac): After all filesystems are converted, all calls to
|
// TODO(mihaimaruseac): After all filesystems are converted, all calls to
|
||||||
@ -165,16 +164,18 @@ Status ModularFileSystem::GetChildren(const std::string& dir,
|
|||||||
|
|
||||||
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
|
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
|
||||||
std::string translated_name = TranslateName(dir);
|
std::string translated_name = TranslateName(dir);
|
||||||
char** children;
|
// Note that `children` is allocated by the plugin and freed by core
|
||||||
|
// TensorFlow, so we need to use `plugin_memory_free_` here.
|
||||||
|
char** children = nullptr;
|
||||||
const int num_children =
|
const int num_children =
|
||||||
ops_->get_children(filesystem_.get(), translated_name.c_str(), &children,
|
ops_->get_children(filesystem_.get(), translated_name.c_str(), &children,
|
||||||
plugin_status.get());
|
plugin_status.get());
|
||||||
if (num_children >= 0) {
|
if (num_children >= 0) {
|
||||||
for (int i = 0; i < num_children; i++) {
|
for (int i = 0; i < num_children; i++) {
|
||||||
result->push_back(std::string(children[i]));
|
result->push_back(std::string(children[i]));
|
||||||
free(children[i]);
|
plugin_memory_free_(children[i]);
|
||||||
}
|
}
|
||||||
free(children);
|
plugin_memory_free_(children);
|
||||||
}
|
}
|
||||||
|
|
||||||
return StatusFromTF_Status(plugin_status.get());
|
return StatusFromTF_Status(plugin_status.get());
|
||||||
@ -186,15 +187,17 @@ Status ModularFileSystem::GetMatchingPaths(const std::string& pattern,
|
|||||||
return internal::GetMatchingPaths(this, Env::Default(), pattern, result);
|
return internal::GetMatchingPaths(this, Env::Default(), pattern, result);
|
||||||
|
|
||||||
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
|
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
|
||||||
char** matches;
|
// Note that `matches` is allocated by the plugin and freed by core
|
||||||
|
// TensorFlow, so we need to use `plugin_memory_free_` here.
|
||||||
|
char** matches = nullptr;
|
||||||
const int num_matches = ops_->get_matching_paths(
|
const int num_matches = ops_->get_matching_paths(
|
||||||
filesystem_.get(), pattern.c_str(), &matches, plugin_status.get());
|
filesystem_.get(), pattern.c_str(), &matches, plugin_status.get());
|
||||||
if (num_matches >= 0) {
|
if (num_matches >= 0) {
|
||||||
for (int i = 0; i < num_matches; i++) {
|
for (int i = 0; i < num_matches; i++) {
|
||||||
result->push_back(std::string(matches[i]));
|
result->push_back(std::string(matches[i]));
|
||||||
free(matches[i]);
|
plugin_memory_free_(matches[i]);
|
||||||
}
|
}
|
||||||
free(matches);
|
plugin_memory_free_(matches);
|
||||||
}
|
}
|
||||||
|
|
||||||
return StatusFromTF_Status(plugin_status.get());
|
return StatusFromTF_Status(plugin_status.get());
|
||||||
@ -358,7 +361,8 @@ std::string ModularFileSystem::TranslateName(const std::string& name) const {
|
|||||||
CHECK(p != nullptr) << "TranslateName(" << name << ") returned nullptr";
|
CHECK(p != nullptr) << "TranslateName(" << name << ") returned nullptr";
|
||||||
|
|
||||||
std::string ret(p);
|
std::string ret(p);
|
||||||
free(p);
|
// Since `p` is allocated by plugin, free it using plugin's method.
|
||||||
|
plugin_memory_free_(p);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -435,4 +439,26 @@ Status ModularWritableFile::Tell(int64* position) {
|
|||||||
return StatusFromTF_Status(plugin_status.get());
|
return StatusFromTF_Status(plugin_status.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status RegisterFilesystemPlugin(const std::string& dso_path) {
|
||||||
|
// Step 1: Load plugin
|
||||||
|
Env* env = Env::Default();
|
||||||
|
void* dso_handle;
|
||||||
|
TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle));
|
||||||
|
|
||||||
|
// Step 2: Load symbol for `TF_InitPlugin`
|
||||||
|
void* dso_symbol;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol));
|
||||||
|
|
||||||
|
// Step 3: Call `TF_InitPlugin`
|
||||||
|
TF_FilesystemPluginInfo info;
|
||||||
|
memset(&info, 0, sizeof(info));
|
||||||
|
auto TF_InitPlugin =
|
||||||
|
reinterpret_cast<int (*)(TF_FilesystemPluginInfo*)>(dso_symbol);
|
||||||
|
TF_InitPlugin(&info);
|
||||||
|
|
||||||
|
// Step 4: Do the actual registration
|
||||||
|
return filesystem_registration::RegisterFilesystemPluginImpl(&info);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -32,7 +32,7 @@ namespace tensorflow {
|
|||||||
// TODO(b/143949615): After all filesystems are converted, this file will be
|
// TODO(b/143949615): After all filesystems are converted, this file will be
|
||||||
// moved to core/platform, and this class can become a singleton and replace the
|
// moved to core/platform, and this class can become a singleton and replace the
|
||||||
// need for `Env::Default()`. At that time, we might decide to remove the need
|
// need for `Env::Default()`. At that time, we might decide to remove the need
|
||||||
// for `Env::Default()` altoghether, but that's a different project, not in
|
// for `Env::Default()` altogether, but that's a different project, not in
|
||||||
// scope for now. I'm just mentioning this here as that transition will mean
|
// scope for now. I'm just mentioning this here as that transition will mean
|
||||||
// removal of the registration part from `Env` and adding it here instead: we
|
// removal of the registration part from `Env` and adding it here instead: we
|
||||||
// will need tables to hold for each scheme the function tables that implement
|
// will need tables to hold for each scheme the function tables that implement
|
||||||
@ -46,12 +46,16 @@ class ModularFileSystem final : public FileSystem {
|
|||||||
std::unique_ptr<const TF_RandomAccessFileOps> random_access_file_ops,
|
std::unique_ptr<const TF_RandomAccessFileOps> random_access_file_ops,
|
||||||
std::unique_ptr<const TF_WritableFileOps> writable_file_ops,
|
std::unique_ptr<const TF_WritableFileOps> writable_file_ops,
|
||||||
std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
|
std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
|
||||||
read_only_memory_region_ops)
|
read_only_memory_region_ops,
|
||||||
|
std::function<void*(size_t)> plugin_memory_allocate,
|
||||||
|
std::function<void(void*)> plugin_memory_free)
|
||||||
: filesystem_(std::move(filesystem)),
|
: filesystem_(std::move(filesystem)),
|
||||||
ops_(std::move(filesystem_ops)),
|
ops_(std::move(filesystem_ops)),
|
||||||
random_access_file_ops_(std::move(random_access_file_ops)),
|
random_access_file_ops_(std::move(random_access_file_ops)),
|
||||||
writable_file_ops_(std::move(writable_file_ops)),
|
writable_file_ops_(std::move(writable_file_ops)),
|
||||||
read_only_memory_region_ops_(std::move(read_only_memory_region_ops)) {}
|
read_only_memory_region_ops_(std::move(read_only_memory_region_ops)),
|
||||||
|
plugin_memory_allocate_(std::move(plugin_memory_allocate)),
|
||||||
|
plugin_memory_free_(std::move(plugin_memory_free)) {}
|
||||||
|
|
||||||
~ModularFileSystem() override { ops_->cleanup(filesystem_.get()); }
|
~ModularFileSystem() override { ops_->cleanup(filesystem_.get()); }
|
||||||
|
|
||||||
@ -93,6 +97,8 @@ class ModularFileSystem final : public FileSystem {
|
|||||||
std::unique_ptr<const TF_WritableFileOps> writable_file_ops_;
|
std::unique_ptr<const TF_WritableFileOps> writable_file_ops_;
|
||||||
std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
|
std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
|
||||||
read_only_memory_region_ops_;
|
read_only_memory_region_ops_;
|
||||||
|
std::function<void*(size_t)> plugin_memory_allocate_;
|
||||||
|
std::function<void(void*)> plugin_memory_free_;
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(ModularFileSystem);
|
TF_DISALLOW_COPY_AND_ASSIGN(ModularFileSystem);
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -156,6 +162,9 @@ class ModularReadOnlyMemoryRegion final : public ReadOnlyMemoryRegion {
|
|||||||
TF_DISALLOW_COPY_AND_ASSIGN(ModularReadOnlyMemoryRegion);
|
TF_DISALLOW_COPY_AND_ASSIGN(ModularReadOnlyMemoryRegion);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Registers a filesystem plugin so that core TensorFlow can use it.
|
||||||
|
Status RegisterFilesystemPlugin(const std::string& dso_path);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_H_
|
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_H_
|
||||||
|
@ -0,0 +1,327 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
|
||||||
|
#include "tensorflow/c/tf_status_internal.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/util/ptr_util.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Checks that all schemes provided by a plugin are valid.
|
||||||
|
// TODO(mihaimaruseac): More validation could be done here, based on supported
|
||||||
|
// charset, maximum length, etc. Punting it for later.
|
||||||
|
static Status ValidateScheme(const char* scheme) {
|
||||||
|
if (scheme == nullptr)
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Attempted to register filesystem with `nullptr` URI scheme");
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks if the plugin and core ABI numbers match.
|
||||||
|
//
|
||||||
|
// If the numbers don't match, plugin cannot be loaded.
|
||||||
|
static Status CheckABI(int pluginABI, int coreABI, StringPiece where) {
|
||||||
|
if (pluginABI != coreABI)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
strings::StrCat("Plugin ABI (", pluginABI, ") for ", where,
|
||||||
|
" operations doesn't match expected core ABI (",
|
||||||
|
coreABI, "). Plugin cannot be loaded."));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks if the plugin and core ABI numbers match, for all operations.
|
||||||
|
//
|
||||||
|
// If the numbers don't match, plugin cannot be loaded.
|
||||||
|
//
|
||||||
|
// Uses the simpler `CheckABI(int, int, StringPiece)`.
|
||||||
|
static Status ValidateABI(const TF_FilesystemPluginOps* ops) {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
CheckABI(ops->filesystem_ops_abi, TF_FILESYSTEM_OPS_ABI, "filesystem"));
|
||||||
|
|
||||||
|
if (ops->random_access_file_ops != nullptr)
|
||||||
|
TF_RETURN_IF_ERROR(CheckABI(ops->random_access_file_ops_abi,
|
||||||
|
TF_RANDOM_ACCESS_FILE_OPS_ABI,
|
||||||
|
"random access file"));
|
||||||
|
|
||||||
|
if (ops->writable_file_ops != nullptr)
|
||||||
|
TF_RETURN_IF_ERROR(CheckABI(ops->writable_file_ops_abi,
|
||||||
|
TF_WRITABLE_FILE_OPS_ABI, "writable file"));
|
||||||
|
|
||||||
|
if (ops->read_only_memory_region_ops != nullptr)
|
||||||
|
TF_RETURN_IF_ERROR(CheckABI(ops->read_only_memory_region_ops_abi,
|
||||||
|
TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
|
||||||
|
"read only memory region"));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks if the plugin and core API numbers match, logging mismatches.
|
||||||
|
static void CheckAPI(int plugin_API, int core_API, StringPiece where) {
|
||||||
|
if (plugin_API != core_API) {
|
||||||
|
VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
|
||||||
|
<< " operations doesn't match expected core API (" << core_API
|
||||||
|
<< "). Plugin will be loaded but functionality might be missing.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks if the plugin and core API numbers match, for all operations.
|
||||||
|
//
|
||||||
|
// Uses the simpler `CheckAPIHelper(int, int, StringPiece)`.
|
||||||
|
static void ValidateAPI(const TF_FilesystemPluginOps* ops) {
|
||||||
|
CheckAPI(ops->filesystem_ops_api, TF_FILESYSTEM_OPS_API, "filesystem");
|
||||||
|
|
||||||
|
if (ops->random_access_file_ops != nullptr)
|
||||||
|
CheckAPI(ops->random_access_file_ops_api, TF_RANDOM_ACCESS_FILE_OPS_API,
|
||||||
|
"random access file");
|
||||||
|
|
||||||
|
if (ops->writable_file_ops != nullptr)
|
||||||
|
CheckAPI(ops->writable_file_ops_api, TF_WRITABLE_FILE_OPS_API,
|
||||||
|
"writable file");
|
||||||
|
|
||||||
|
if (ops->read_only_memory_region_ops != nullptr)
|
||||||
|
CheckAPI(ops->read_only_memory_region_ops_api,
|
||||||
|
TF_READ_ONLY_MEMORY_REGION_OPS_API, "read only memory region");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validates the filesystem operations supplied by the plugin.
|
||||||
|
static Status ValidateHelper(const TF_FilesystemOps* ops) {
|
||||||
|
if (ops == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Trying to register filesystem without operations");
|
||||||
|
|
||||||
|
if (ops->init == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Trying to register filesystem without `init` operation");
|
||||||
|
|
||||||
|
if (ops->cleanup == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Trying to register filesystem without `cleanup` operation");
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validates the random access file operations supplied by the plugin.
|
||||||
|
static Status ValidateHelper(const TF_RandomAccessFileOps* ops) {
|
||||||
|
if (ops == nullptr) {
|
||||||
|
// We allow filesystems where files can only be written to (from TF code)
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ops->cleanup == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Trying to register filesystem without `cleanup` operation on random "
|
||||||
|
"access files");
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validates the writable file operations supplied by the plugin.
|
||||||
|
static Status ValidateHelper(const TF_WritableFileOps* ops) {
|
||||||
|
if (ops == nullptr) {
|
||||||
|
// We allow read-only filesystems
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ops->cleanup == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Trying to register filesystem without `cleanup` operation on writable "
|
||||||
|
"files");
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validates the read only memory region operations given by the plugin.
|
||||||
|
static Status ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops) {
|
||||||
|
if (ops == nullptr) {
|
||||||
|
// read only memory region support is always optional
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ops->cleanup == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Trying to register filesystem without `cleanup` operation on read "
|
||||||
|
"only memory regions");
|
||||||
|
|
||||||
|
if (ops->data == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Trying to register filesystem without `data` operation on read only "
|
||||||
|
"memory regions");
|
||||||
|
|
||||||
|
if (ops->length == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Trying to register filesystem without `length` operation on read only "
|
||||||
|
"memory regions");
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validates the operations supplied by the plugin.
|
||||||
|
//
|
||||||
|
// Uses the 4 simpler `ValidateHelper(const TF_...*)` to validate each
|
||||||
|
// individual function table and then checks that the function table for a
|
||||||
|
// specific file type exists if the plugin offers support for creating that
|
||||||
|
// type of files.
|
||||||
|
static Status ValidateOperations(const TF_FilesystemPluginOps* ops) {
|
||||||
|
TF_RETURN_IF_ERROR(ValidateHelper(ops->filesystem_ops));
|
||||||
|
TF_RETURN_IF_ERROR(ValidateHelper(ops->random_access_file_ops));
|
||||||
|
TF_RETURN_IF_ERROR(ValidateHelper(ops->writable_file_ops));
|
||||||
|
TF_RETURN_IF_ERROR(ValidateHelper(ops->read_only_memory_region_ops));
|
||||||
|
|
||||||
|
if (ops->filesystem_ops->new_random_access_file != nullptr &&
|
||||||
|
ops->random_access_file_ops == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Filesystem allows creation of random access files but no "
|
||||||
|
"operations on them have been supplied.");
|
||||||
|
|
||||||
|
if ((ops->filesystem_ops->new_writable_file != nullptr ||
|
||||||
|
ops->filesystem_ops->new_appendable_file != nullptr) &&
|
||||||
|
ops->writable_file_ops == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Filesystem allows creation of writable files but no "
|
||||||
|
"operations on them have been supplied.");
|
||||||
|
|
||||||
|
if (ops->filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
|
||||||
|
ops->read_only_memory_region_ops == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Filesystem allows creation of readonly memory regions but no "
|
||||||
|
"operations on them have been supplied.");
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copies a function table from plugin memory space to core memory space.
|
||||||
|
//
|
||||||
|
// This has three benefits:
|
||||||
|
// * allows having newer plugins than the current core TensorFlow: the
|
||||||
|
// additional entries in the plugin's table are just discarded;
|
||||||
|
// * allows having older plugins than the current core TensorFlow (though
|
||||||
|
// we are still warning users): the entries that core TensorFlow expects
|
||||||
|
// but plugins didn't provide will be set to `nullptr` values and core
|
||||||
|
// TensorFlow will know to not call these on behalf of users;
|
||||||
|
// * increased security as plugins will not be able to alter function table
|
||||||
|
// after loading up. Thus, malicious plugins can't alter functionality to
|
||||||
|
// probe for gadgets inside core TensorFlow. We can even protect the area
|
||||||
|
// of memory where the copies reside to not allow any more writes to it
|
||||||
|
// after all copies are created.
|
||||||
|
template <typename T>
|
||||||
|
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
|
||||||
|
size_t plugin_size) {
|
||||||
|
if (plugin_ops == nullptr) return nullptr;
|
||||||
|
|
||||||
|
size_t copy_size = std::min(plugin_size, sizeof(T));
|
||||||
|
auto core_ops = tensorflow::MakeUnique<T>();
|
||||||
|
memset(core_ops.get(), 0, sizeof(T));
|
||||||
|
memcpy(core_ops.get(), plugin_ops, copy_size);
|
||||||
|
return core_ops;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registers one filesystem from the plugin.
|
||||||
|
//
|
||||||
|
// Must be called only with `index` a valid index in `info->ops`.
|
||||||
|
static Status RegisterFileSystem(const TF_FilesystemPluginInfo* info,
|
||||||
|
int index) {
|
||||||
|
// Step 1: Copy all the function tables to core TensorFlow memory space
|
||||||
|
auto core_filesystem_ops = CopyToCore<TF_FilesystemOps>(
|
||||||
|
info->ops[index].filesystem_ops, info->ops[index].filesystem_ops_size);
|
||||||
|
auto core_random_access_file_ops = CopyToCore<TF_RandomAccessFileOps>(
|
||||||
|
info->ops[index].random_access_file_ops,
|
||||||
|
info->ops[index].random_access_file_ops_size);
|
||||||
|
auto core_writable_file_ops =
|
||||||
|
CopyToCore<TF_WritableFileOps>(info->ops[index].writable_file_ops,
|
||||||
|
info->ops[index].writable_file_ops_size);
|
||||||
|
auto core_read_only_memory_region_ops =
|
||||||
|
CopyToCore<TF_ReadOnlyMemoryRegionOps>(
|
||||||
|
info->ops[index].read_only_memory_region_ops,
|
||||||
|
info->ops[index].read_only_memory_region_ops_size);
|
||||||
|
|
||||||
|
// Step 2: Initialize the opaque filesystem structure
|
||||||
|
auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
|
||||||
|
TF_Status* c_status = TF_NewStatus();
|
||||||
|
Status status = Status::OK();
|
||||||
|
core_filesystem_ops->init(filesystem.get(), c_status);
|
||||||
|
status = Status(c_status->status);
|
||||||
|
TF_DeleteStatus(c_status);
|
||||||
|
if (!status.ok()) return status;
|
||||||
|
|
||||||
|
// Step 3: Actual registration
|
||||||
|
return Env::Default()->RegisterFileSystem(
|
||||||
|
info->ops[index].scheme,
|
||||||
|
tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
|
||||||
|
std::move(filesystem), std::move(core_filesystem_ops),
|
||||||
|
std::move(core_random_access_file_ops),
|
||||||
|
std::move(core_writable_file_ops),
|
||||||
|
std::move(core_read_only_memory_region_ops),
|
||||||
|
info->plugin_memory_allocate, info->plugin_memory_free));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registers filesystem at `index`, if plugin is providing valid information.
|
||||||
|
//
|
||||||
|
// Extracted to a separate function so that pointers inside `info` are freed
|
||||||
|
// by the caller regardless of whether validation/registration failed or not.
|
||||||
|
//
|
||||||
|
// Must be called only with `index` a valid index in `info->ops`.
|
||||||
|
static Status ValidateAndRegisterFilesystems(
|
||||||
|
const TF_FilesystemPluginInfo* info, int index) {
|
||||||
|
TF_RETURN_IF_ERROR(ValidateScheme(info->ops[index].scheme));
|
||||||
|
TF_RETURN_IF_ERROR(ValidateABI(&info->ops[index]));
|
||||||
|
ValidateAPI(&info->ops[index]); // we just warn on API number mismatch
|
||||||
|
TF_RETURN_IF_ERROR(ValidateOperations(&info->ops[index]));
|
||||||
|
TF_RETURN_IF_ERROR(RegisterFileSystem(info, index));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensures that the plugin provides the required memory management operations.
|
||||||
|
static Status ValidatePluginMemoryRoutines(
|
||||||
|
const TF_FilesystemPluginInfo* info) {
|
||||||
|
if (info->plugin_memory_allocate == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Cannot load filesystem plugin which does not provide "
|
||||||
|
"`plugin_memory_allocate`");
|
||||||
|
|
||||||
|
if (info->plugin_memory_free == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Cannot load filesystem plugin which does not provide "
|
||||||
|
"`plugin_memory_free`");
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace filesystem_registration {
|
||||||
|
|
||||||
|
Status RegisterFilesystemPluginImpl(const TF_FilesystemPluginInfo* info) {
|
||||||
|
TF_RETURN_IF_ERROR(ValidatePluginMemoryRoutines(info));
|
||||||
|
|
||||||
|
// Validate and register all filesystems
|
||||||
|
// Try to register as many filesystems as possible.
|
||||||
|
// Free memory once we no longer need it
|
||||||
|
Status status;
|
||||||
|
for (int i = 0; i < info->num_schemes; i++) {
|
||||||
|
status.Update(ValidateAndRegisterFilesystems(info, i));
|
||||||
|
info->plugin_memory_free(info->ops[i].scheme);
|
||||||
|
info->plugin_memory_free(info->ops[i].filesystem_ops);
|
||||||
|
info->plugin_memory_free(info->ops[i].random_access_file_ops);
|
||||||
|
info->plugin_memory_free(info->ops[i].writable_file_ops);
|
||||||
|
info->plugin_memory_free(info->ops[i].read_only_memory_region_ops);
|
||||||
|
}
|
||||||
|
info->plugin_memory_free(info->ops);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace filesystem_registration
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -0,0 +1,33 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||||
|
#include "tensorflow/core/platform/status.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace filesystem_registration {
|
||||||
|
|
||||||
|
// Implementation for filesystem registration
|
||||||
|
//
|
||||||
|
// Don't call this directly. Instead call `RegisterFilesystemPlugin`.
|
||||||
|
// Exposed only for static registration of local filesystems.
|
||||||
|
Status RegisterFilesystemPluginImpl(const TF_FilesystemPluginInfo* info);
|
||||||
|
|
||||||
|
} // namespace filesystem_registration
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
|
File diff suppressed because it is too large
Load Diff
@ -19,6 +19,7 @@ tf_cc_shared_object(
|
|||||||
cc_library(
|
cc_library(
|
||||||
name = "posix_filesystem_impl",
|
name = "posix_filesystem_impl",
|
||||||
srcs = ["posix_filesystem.cc"],
|
srcs = ["posix_filesystem.cc"],
|
||||||
|
hdrs = ["posix_filesystem.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":posix_filesystem_helper",
|
":posix_filesystem_helper",
|
||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
@ -26,6 +27,20 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Since building pip package and API tests require a filesystem, we provide a
|
||||||
|
# static registration target that they should link against.
|
||||||
|
cc_library(
|
||||||
|
name = "posix_filesystem_static",
|
||||||
|
srcs = ["posix_filesystem_static.cc"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":posix_filesystem_impl",
|
||||||
|
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||||
|
"//tensorflow/c/experimental/filesystem:modular_filesystem",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
# Library implementing helper functionality, so that the above only contains
|
# Library implementing helper functionality, so that the above only contains
|
||||||
# the API implementation for modular filesystems.
|
# the API implementation for modular filesystems.
|
||||||
cc_library(
|
cc_library(
|
||||||
|
@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h"
|
||||||
|
|
||||||
#include <dirent.h>
|
#include <dirent.h>
|
||||||
#include <errno.h>
|
#include <errno.h>
|
||||||
#include <fcntl.h>
|
#include <fcntl.h>
|
||||||
@ -24,15 +26,15 @@ limitations under the License.
|
|||||||
#include <sys/stat.h>
|
#include <sys/stat.h>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
|
||||||
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h"
|
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
|
||||||
// Implementation of a filesystem for POSIX environments.
|
// Implementation of a filesystem for POSIX environments.
|
||||||
// This filesystem will support `file://` and empty (local) URI schemes.
|
// This filesystem will support `file://` and empty (local) URI schemes.
|
||||||
|
|
||||||
|
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
|
||||||
|
static void plugin_memory_free(void* ptr) { free(ptr); }
|
||||||
|
|
||||||
// SECTION 1. Implementation for `TF_RandomAccessFile`
|
// SECTION 1. Implementation for `TF_RandomAccessFile`
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
namespace tf_random_access_file {
|
namespace tf_random_access_file {
|
||||||
@ -45,7 +47,9 @@ typedef struct PosixFile {
|
|||||||
static void Cleanup(TF_RandomAccessFile* file) {
|
static void Cleanup(TF_RandomAccessFile* file) {
|
||||||
auto posix_file = static_cast<PosixFile*>(file->plugin_file);
|
auto posix_file = static_cast<PosixFile*>(file->plugin_file);
|
||||||
close(posix_file->fd);
|
close(posix_file->fd);
|
||||||
free(const_cast<char*>(posix_file->filename));
|
// This would be safe to free using `free` directly as it is only opaque.
|
||||||
|
// However, it is better to be consistent everywhere.
|
||||||
|
plugin_memory_free(const_cast<char*>(posix_file->filename));
|
||||||
delete posix_file;
|
delete posix_file;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -100,7 +104,7 @@ typedef struct PosixFile {
|
|||||||
|
|
||||||
static void Cleanup(TF_WritableFile* file) {
|
static void Cleanup(TF_WritableFile* file) {
|
||||||
auto posix_file = static_cast<PosixFile*>(file->plugin_file);
|
auto posix_file = static_cast<PosixFile*>(file->plugin_file);
|
||||||
free(const_cast<char*>(posix_file->filename));
|
plugin_memory_free(const_cast<char*>(posix_file->filename));
|
||||||
delete posix_file;
|
delete posix_file;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -383,12 +387,13 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
|
|||||||
if (num_entries < 0) {
|
if (num_entries < 0) {
|
||||||
TF_SetStatusFromIOError(status, errno, path);
|
TF_SetStatusFromIOError(status, errno, path);
|
||||||
} else {
|
} else {
|
||||||
*entries = static_cast<char**>(calloc(num_entries, sizeof((*entries)[0])));
|
*entries = static_cast<char**>(
|
||||||
|
plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
|
||||||
for (int i = 0; i < num_entries; i++) {
|
for (int i = 0; i < num_entries; i++) {
|
||||||
(*entries)[i] = strdup(dir_entries[i]->d_name);
|
(*entries)[i] = strdup(dir_entries[i]->d_name);
|
||||||
free(dir_entries[i]);
|
plugin_memory_free(dir_entries[i]);
|
||||||
}
|
}
|
||||||
free(dir_entries);
|
plugin_memory_free(dir_entries);
|
||||||
}
|
}
|
||||||
|
|
||||||
return num_entries;
|
return num_entries;
|
||||||
@ -396,48 +401,59 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
|
|||||||
|
|
||||||
} // namespace tf_posix_filesystem
|
} // namespace tf_posix_filesystem
|
||||||
|
|
||||||
void TF_InitPlugin(TF_Status* status) {
|
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
|
||||||
TF_RandomAccessFileOps random_access_file_ops = {
|
const char* uri) {
|
||||||
tf_random_access_file::Cleanup,
|
TF_SetFilesystemVersionMetadata(ops);
|
||||||
tf_random_access_file::Read,
|
ops->scheme = strdup(uri);
|
||||||
};
|
|
||||||
TF_WritableFileOps writable_file_ops = {
|
|
||||||
tf_writable_file::Cleanup, tf_writable_file::Append,
|
|
||||||
tf_writable_file::Tell, tf_writable_file::Flush,
|
|
||||||
tf_writable_file::Sync, tf_writable_file::Close,
|
|
||||||
};
|
|
||||||
TF_ReadOnlyMemoryRegionOps read_only_memory_region_ops = {
|
|
||||||
tf_read_only_memory_region::Cleanup,
|
|
||||||
tf_read_only_memory_region::Data,
|
|
||||||
tf_read_only_memory_region::Length,
|
|
||||||
};
|
|
||||||
TF_FilesystemOps filesystem_ops = {
|
|
||||||
tf_posix_filesystem::Init,
|
|
||||||
tf_posix_filesystem::Cleanup,
|
|
||||||
tf_posix_filesystem::NewRandomAccessFile,
|
|
||||||
tf_posix_filesystem::NewWritableFile,
|
|
||||||
tf_posix_filesystem::NewAppendableFile,
|
|
||||||
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile,
|
|
||||||
tf_posix_filesystem::CreateDir,
|
|
||||||
/*recursively_create_dir=*/nullptr,
|
|
||||||
tf_posix_filesystem::DeleteFile,
|
|
||||||
tf_posix_filesystem::DeleteDir,
|
|
||||||
/*delete_recursively=*/nullptr,
|
|
||||||
tf_posix_filesystem::RenameFile,
|
|
||||||
tf_posix_filesystem::CopyFile,
|
|
||||||
tf_posix_filesystem::PathExists,
|
|
||||||
/*paths_exist=*/nullptr,
|
|
||||||
tf_posix_filesystem::Stat,
|
|
||||||
/*is_directory=*/nullptr,
|
|
||||||
/*get_file_size=*/nullptr,
|
|
||||||
/*translate_name=*/nullptr,
|
|
||||||
tf_posix_filesystem::GetChildren,
|
|
||||||
/*get_matching_paths=*/nullptr,
|
|
||||||
/*flush_caches=*/nullptr,
|
|
||||||
};
|
|
||||||
|
|
||||||
for (const char* scheme : {"", "file"})
|
ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
|
||||||
TF_REGISTER_FILESYSTEM_PLUGIN(scheme, &filesystem_ops,
|
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
|
||||||
&random_access_file_ops, &writable_file_ops,
|
ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
|
||||||
&read_only_memory_region_ops, status);
|
ops->random_access_file_ops->read = tf_random_access_file::Read;
|
||||||
|
|
||||||
|
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
|
||||||
|
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
|
||||||
|
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
|
||||||
|
ops->writable_file_ops->append = tf_writable_file::Append;
|
||||||
|
ops->writable_file_ops->tell = tf_writable_file::Tell;
|
||||||
|
ops->writable_file_ops->flush = tf_writable_file::Flush;
|
||||||
|
ops->writable_file_ops->sync = tf_writable_file::Sync;
|
||||||
|
ops->writable_file_ops->close = tf_writable_file::Close;
|
||||||
|
|
||||||
|
ops->read_only_memory_region_ops = static_cast<TF_ReadOnlyMemoryRegionOps*>(
|
||||||
|
plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
|
||||||
|
ops->read_only_memory_region_ops->cleanup =
|
||||||
|
tf_read_only_memory_region::Cleanup;
|
||||||
|
ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data;
|
||||||
|
ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length;
|
||||||
|
|
||||||
|
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
|
||||||
|
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
|
||||||
|
ops->filesystem_ops->init = tf_posix_filesystem::Init;
|
||||||
|
ops->filesystem_ops->cleanup = tf_posix_filesystem::Cleanup;
|
||||||
|
ops->filesystem_ops->new_random_access_file =
|
||||||
|
tf_posix_filesystem::NewRandomAccessFile;
|
||||||
|
ops->filesystem_ops->new_writable_file = tf_posix_filesystem::NewWritableFile;
|
||||||
|
ops->filesystem_ops->new_appendable_file =
|
||||||
|
tf_posix_filesystem::NewAppendableFile;
|
||||||
|
ops->filesystem_ops->new_read_only_memory_region_from_file =
|
||||||
|
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile;
|
||||||
|
ops->filesystem_ops->create_dir = tf_posix_filesystem::CreateDir;
|
||||||
|
ops->filesystem_ops->delete_file = tf_posix_filesystem::DeleteFile;
|
||||||
|
ops->filesystem_ops->delete_dir = tf_posix_filesystem::DeleteDir;
|
||||||
|
ops->filesystem_ops->rename_file = tf_posix_filesystem::RenameFile;
|
||||||
|
ops->filesystem_ops->copy_file = tf_posix_filesystem::CopyFile;
|
||||||
|
ops->filesystem_ops->path_exists = tf_posix_filesystem::PathExists;
|
||||||
|
ops->filesystem_ops->stat = tf_posix_filesystem::Stat;
|
||||||
|
ops->filesystem_ops->get_children = tf_posix_filesystem::GetChildren;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
|
||||||
|
info->plugin_memory_allocate = plugin_memory_allocate;
|
||||||
|
info->plugin_memory_free = plugin_memory_free;
|
||||||
|
info->num_schemes = 2;
|
||||||
|
info->ops = static_cast<TF_FilesystemPluginOps*>(
|
||||||
|
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
|
||||||
|
ProvideFilesystemSupportFor(&info->ops[0], "");
|
||||||
|
ProvideFilesystemSupportFor(&info->ops[1], "file");
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,31 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||||
|
|
||||||
|
// Initialize the POSIX filesystem.
|
||||||
|
//
|
||||||
|
// In general, the `TF_InitPlugin` symbol doesn't need to be exposed in a header
|
||||||
|
// file, since the plugin registration will look for the symbol in the DSO file
|
||||||
|
// that provides the filesystem functionality. However, the POSIX filesystem
|
||||||
|
// needs to be statically registered in some tests and utilities for building
|
||||||
|
// the API files at the time of creating the pip package. Hence, we need to
|
||||||
|
// expose this function so that this filesystem can be statically registered
|
||||||
|
// when needed.
|
||||||
|
void TF_InitPlugin(TF_FilesystemPluginInfo* info);
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_
|
@ -44,7 +44,7 @@ int TransferFileContents(const char* src, const char* dst, mode_t mode,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Both files have been opened, do the transfer.
|
// Both files have been opened, do the transfer.
|
||||||
// Since errno would be overriden by `close` below, save it here.
|
// Since errno would be overridden by `close` below, save it here.
|
||||||
int error_code = 0;
|
int error_code = 0;
|
||||||
if (CopyFileContents(dst_fd, src_fd, size) < 0) error_code = errno;
|
if (CopyFileContents(dst_fd, src_fd, size) < 0) error_code = errno;
|
||||||
|
|
||||||
|
@ -0,0 +1,38 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||||
|
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
|
||||||
|
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Register the POSIX filesystems statically.
|
||||||
|
// Return value will be unused
|
||||||
|
bool StaticallyRegisterLocalFilesystems() {
|
||||||
|
TF_FilesystemPluginInfo info;
|
||||||
|
TF_InitPlugin(&info);
|
||||||
|
Status status = filesystem_registration::RegisterFilesystemPluginImpl(&info);
|
||||||
|
if (!status.ok()) {
|
||||||
|
VLOG(0) << "Static POSIX filesystem could not be registered: " << status;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform the actual registration
|
||||||
|
static bool unused = StaticallyRegisterLocalFilesystems();
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
36
tensorflow/c/experimental/filesystem/plugins/windows/BUILD
Normal file
36
tensorflow/c/experimental/filesystem/plugins/windows/BUILD
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
# Experimental windows filesystem plugin.
|
||||||
|
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
|
||||||
|
|
||||||
|
package(
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filesystem implementation for Windows environment
|
||||||
|
tf_cc_shared_object(
|
||||||
|
name = "windows_filesystem.dll",
|
||||||
|
framework_so = [],
|
||||||
|
linkstatic = False,
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
"nobuilder",
|
||||||
|
"notap",
|
||||||
|
],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [":windows_filesystem_impl"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# The real implementation of the filesystem.
|
||||||
|
cc_library(
|
||||||
|
name = "windows_filesystem_impl",
|
||||||
|
srcs = ["windows_filesystem.cc"],
|
||||||
|
copts = get_win_copts(),
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
"nobuilder",
|
||||||
|
"notap",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/c:tf_status",
|
||||||
|
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||||
|
],
|
||||||
|
)
|
@ -0,0 +1,73 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||||
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
|
||||||
|
// Implementation of a filesystem for POSIX environments.
|
||||||
|
// This filesystem will support `file://` and empty (local) URI schemes.
|
||||||
|
|
||||||
|
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
|
||||||
|
static void plugin_memory_free(void* ptr) { free(ptr); }
|
||||||
|
|
||||||
|
// SECTION 1. Implementation for `TF_RandomAccessFile`
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
namespace tf_random_access_file {
|
||||||
|
|
||||||
|
// TODO(mihaimaruseac): Implement later
|
||||||
|
|
||||||
|
} // namespace tf_random_access_file
|
||||||
|
|
||||||
|
// SECTION 2. Implementation for `TF_WritableFile`
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
namespace tf_writable_file {
|
||||||
|
|
||||||
|
// TODO(mihaimaruseac): Implement later
|
||||||
|
|
||||||
|
} // namespace tf_writable_file
|
||||||
|
|
||||||
|
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
namespace tf_read_only_memory_region {
|
||||||
|
|
||||||
|
// TODO(mihaimaruseac): Implement later
|
||||||
|
|
||||||
|
} // namespace tf_read_only_memory_region
|
||||||
|
|
||||||
|
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
namespace tf_windows_filesystem {
|
||||||
|
|
||||||
|
// TODO(mihaimaruseac): Implement later
|
||||||
|
|
||||||
|
} // namespace tf_windows_filesystem
|
||||||
|
|
||||||
|
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
|
||||||
|
const char* uri) {
|
||||||
|
TF_SetFilesystemVersionMetadata(ops);
|
||||||
|
ops->scheme = strdup(uri);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
|
||||||
|
info->plugin_memory_allocate = plugin_memory_allocate;
|
||||||
|
info->plugin_memory_free = plugin_memory_free;
|
||||||
|
info->num_schemes = 2;
|
||||||
|
info->ops = static_cast<TF_FilesystemPluginOps*>(
|
||||||
|
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
|
||||||
|
ProvideFilesystemSupportFor(&info->ops[0], "");
|
||||||
|
ProvideFilesystemSupportFor(&info->ops[1], "file");
|
||||||
|
}
|
@ -27,14 +27,10 @@ namespace {
|
|||||||
|
|
||||||
class DummyDevice : public DeviceBase {
|
class DummyDevice : public DeviceBase {
|
||||||
public:
|
public:
|
||||||
DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
|
explicit DummyDevice(Env* env) : DeviceBase(env) {}
|
||||||
bool RequiresRecordingAccessedTensors() const override { return save_; }
|
|
||||||
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
|
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
|
||||||
return cpu_allocator();
|
return cpu_allocator();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
bool save_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
void TestBitcastOp(Tensor* input_tensor, DataType out_type,
|
void TestBitcastOp(Tensor* input_tensor, DataType out_type,
|
||||||
@ -61,7 +57,7 @@ void TestBitcastOp(Tensor* input_tensor, DataType out_type,
|
|||||||
ASSERT_TRUE(status.ok()) << status.ToString();
|
ASSERT_TRUE(status.ok()) << status.ToString();
|
||||||
|
|
||||||
OpKernelContext::Params params;
|
OpKernelContext::Params params;
|
||||||
DummyDevice dummy_device(nullptr, false);
|
DummyDevice dummy_device(nullptr);
|
||||||
params.device = &dummy_device;
|
params.device = &dummy_device;
|
||||||
params.op_kernel = kernel.get();
|
params.op_kernel = kernel.get();
|
||||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
|
@ -18,19 +18,40 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/c/kernels.h"
|
#include "tensorflow/c/kernels.h"
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "absl/container/inlined_vector.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
|
#include "tensorflow/core/framework/allocator.h"
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
|
#include "tensorflow/core/framework/device_base.h"
|
||||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def_builder.h"
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/status.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
struct MyCustomKernel {
|
struct MyCustomKernel {
|
||||||
bool created;
|
bool created;
|
||||||
@ -134,14 +155,10 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
|
|||||||
|
|
||||||
class DummyDevice : public DeviceBase {
|
class DummyDevice : public DeviceBase {
|
||||||
public:
|
public:
|
||||||
DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
|
explicit DummyDevice(Env* env) : DeviceBase(env) {}
|
||||||
bool RequiresRecordingAccessedTensors() const override { return save_; }
|
|
||||||
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
|
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
|
||||||
return cpu_allocator();
|
return cpu_allocator();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
bool save_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST(TestKernel, TestInputAndOutputCount) {
|
TEST(TestKernel, TestInputAndOutputCount) {
|
||||||
@ -202,7 +219,7 @@ TEST(TestKernel, TestInputAndOutputCount) {
|
|||||||
|
|
||||||
{
|
{
|
||||||
OpKernelContext::Params p;
|
OpKernelContext::Params p;
|
||||||
DummyDevice dummy_device(nullptr, false);
|
DummyDevice dummy_device(nullptr);
|
||||||
p.device = &dummy_device;
|
p.device = &dummy_device;
|
||||||
p.step_id = 43;
|
p.step_id = 43;
|
||||||
|
|
||||||
|
@ -133,7 +133,7 @@ TEST(OpsTest, TestShapeInference_VectorizeFunction) {
|
|||||||
|
|
||||||
TEST(OpsTest, AttributeAccessors) {
|
TEST(OpsTest, AttributeAccessors) {
|
||||||
TF_OpDefinitionBuilder* builder =
|
TF_OpDefinitionBuilder* builder =
|
||||||
TF_NewOpDefinitionBuilder("AttributeAccesorsOp");
|
TF_NewOpDefinitionBuilder("AttributeAccessorsOp");
|
||||||
TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2");
|
TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2");
|
||||||
TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\"");
|
TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\"");
|
||||||
TF_OpDefinitionBuilderSetIsCommutative(builder, true);
|
TF_OpDefinitionBuilderSetIsCommutative(builder, true);
|
||||||
@ -151,7 +151,7 @@ TEST(OpsTest, AttributeAccessors) {
|
|||||||
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
|
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
|
||||||
bool found = false;
|
bool found = false;
|
||||||
for (const auto& op : op_list.op()) {
|
for (const auto& op : op_list.op()) {
|
||||||
if (op.name() == "AttributeAccesorsOp") {
|
if (op.name() == "AttributeAccessorsOp") {
|
||||||
ASSERT_TRUE(op.is_commutative());
|
ASSERT_TRUE(op.is_commutative());
|
||||||
ASSERT_TRUE(op.is_aggregate());
|
ASSERT_TRUE(op.is_aggregate());
|
||||||
ASSERT_TRUE(op.allows_uninitialized_input());
|
ASSERT_TRUE(op.allows_uninitialized_input());
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/tf_tensor.h"
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
@ -64,25 +65,41 @@ void deallocate_buffer(void* data, size_t len, void* arg) {
|
|||||||
}
|
}
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
TF_Tensor* CreateTensor(TF_ManagedBuffer* buf, TF_DataType dtype,
|
||||||
|
const int64_t* dims, int num_dims, size_t len) {
|
||||||
|
std::vector<tensorflow::int64> dimvec(num_dims);
|
||||||
|
for (int i = 0; i < num_dims; ++i) {
|
||||||
|
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(gjn): Make the choice of interface a compile-time configuration.
|
||||||
|
tensorflow::TensorInterface ret(
|
||||||
|
Tensor(static_cast<tensorflow::DataType>(dtype),
|
||||||
|
tensorflow::TensorShape(dimvec), buf));
|
||||||
|
buf->Unref();
|
||||||
|
size_t elem_size = TF_DataTypeSize(dtype);
|
||||||
|
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
|
TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
|
||||||
int num_dims, size_t len) {
|
int num_dims, size_t len) {
|
||||||
void* data = tensorflow::allocate_tensor("TF_AllocateTensor", len,
|
void* data = tensorflow::allocate_tensor("TF_AllocateTensor", len,
|
||||||
tensorflow::cpu_allocator());
|
tensorflow::cpu_allocator());
|
||||||
return TF_NewTensor(dtype, dims, num_dims, data, len,
|
TF_ManagedBuffer* buf =
|
||||||
tensorflow::deallocate_buffer,
|
new TF_ManagedBuffer(data, len, tensorflow::deallocate_buffer,
|
||||||
tensorflow::cpu_allocator());
|
tensorflow::cpu_allocator(), /*owns_memory=*/true);
|
||||||
|
return CreateTensor(buf, dtype, dims, num_dims, len);
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
|
TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
|
||||||
void* data, size_t len,
|
void* data, size_t len,
|
||||||
void (*deallocator)(void* data, size_t len, void* arg),
|
void (*deallocator)(void* data, size_t len, void* arg),
|
||||||
void* deallocator_arg) {
|
void* deallocator_arg) {
|
||||||
std::vector<tensorflow::int64> dimvec(num_dims);
|
|
||||||
for (int i = 0; i < num_dims; ++i) {
|
|
||||||
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_ManagedBuffer* buf = nullptr;
|
TF_ManagedBuffer* buf = nullptr;
|
||||||
if (dtype != TF_STRING && dtype != TF_RESOURCE &&
|
if (dtype != TF_STRING && dtype != TF_RESOURCE &&
|
||||||
tensorflow::DataTypeCanUseMemcpy(
|
tensorflow::DataTypeCanUseMemcpy(
|
||||||
@ -97,24 +114,17 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
|
|||||||
// Other types have the same representation, so copy only if it is safe to
|
// Other types have the same representation, so copy only if it is safe to
|
||||||
// do so.
|
// do so.
|
||||||
buf = new TF_ManagedBuffer(tensorflow::allocate_tensor("TF_NewTensor", len),
|
buf = new TF_ManagedBuffer(tensorflow::allocate_tensor("TF_NewTensor", len),
|
||||||
len, tensorflow::deallocate_buffer, nullptr);
|
len, tensorflow::deallocate_buffer, nullptr,
|
||||||
|
/*owns_memory=*/true);
|
||||||
std::memcpy(buf->data(), data, len);
|
std::memcpy(buf->data(), data, len);
|
||||||
// Free the original buffer.
|
// Free the original buffer.
|
||||||
deallocator(data, len, deallocator_arg);
|
deallocator(data, len, deallocator_arg);
|
||||||
} else {
|
} else {
|
||||||
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
|
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
|
||||||
|
/*owns_memory=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(gjn): Make the choice of interface a compile-time configuration.
|
return CreateTensor(buf, dtype, dims, num_dims, len);
|
||||||
tensorflow::TensorInterface ret(
|
|
||||||
Tensor(static_cast<tensorflow::DataType>(dtype),
|
|
||||||
tensorflow::TensorShape(dimvec), buf));
|
|
||||||
buf->Unref();
|
|
||||||
size_t elem_size = TF_DataTypeSize(dtype);
|
|
||||||
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {
|
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {
|
||||||
@ -383,7 +393,7 @@ Status TensorInterface::ToTensor(Tensor* dst) const {
|
|||||||
if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
|
if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
|
||||||
string(static_cast<const char*>(Data()), ByteSize()))) {
|
string(static_cast<const char*>(Data()), ByteSize()))) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"Malformed TF_RESOUCE tensor: unable to parse resource handle");
|
"Malformed TF_RESOURCE tensor: unable to parse resource handle");
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -58,9 +58,9 @@ extern "C" {
|
|||||||
// start_offset: array[uint64]
|
// start_offset: array[uint64]
|
||||||
// data: byte[...]
|
// data: byte[...]
|
||||||
//
|
//
|
||||||
// The string length (as a varint), followed by the contents of the string
|
// The string length (as a varint, start_offset[i + 1] - start_offset[i]),
|
||||||
// is encoded at data[start_offset[i]]]. TF_StringEncode and TF_StringDecode
|
// followed by the contents of the string is encoded at data[start_offset[i]].
|
||||||
// facilitate this encoding.
|
// TF_StringEncode and TF_StringDecode facilitate this encoding.
|
||||||
|
|
||||||
typedef struct TF_Tensor TF_Tensor;
|
typedef struct TF_Tensor TF_Tensor;
|
||||||
|
|
||||||
|
@ -38,11 +38,12 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer {
|
|||||||
public:
|
public:
|
||||||
TF_ManagedBuffer(void* data, size_t len,
|
TF_ManagedBuffer(void* data, size_t len,
|
||||||
void (*deallocator)(void* data, size_t len, void* arg),
|
void (*deallocator)(void* data, size_t len, void* arg),
|
||||||
void* deallocator_arg)
|
void* deallocator_arg, bool owns_memory)
|
||||||
: TensorBuffer(data),
|
: TensorBuffer(data),
|
||||||
len_(len),
|
len_(len),
|
||||||
deallocator_(deallocator),
|
deallocator_(deallocator),
|
||||||
deallocator_arg_(deallocator_arg) {}
|
deallocator_arg_(deallocator_arg),
|
||||||
|
owns_memory_(owns_memory) {}
|
||||||
|
|
||||||
~TF_ManagedBuffer() override {
|
~TF_ManagedBuffer() override {
|
||||||
(*deallocator_)(data(), len_, deallocator_arg_);
|
(*deallocator_)(data(), len_, deallocator_arg_);
|
||||||
@ -57,13 +58,13 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer {
|
|||||||
proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
|
proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prevents input forwarding from mutating this buffer.
|
bool OwnsMemory() const override { return owns_memory_; }
|
||||||
bool OwnsMemory() const override { return false; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const size_t len_;
|
const size_t len_;
|
||||||
void (*const deallocator_)(void* data, size_t len, void* arg);
|
void (*const deallocator_)(void* data, size_t len, void* arg);
|
||||||
void* const deallocator_arg_;
|
void* const deallocator_arg_;
|
||||||
|
bool owns_memory_;
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -41,6 +41,16 @@ filegroup(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "pywrap_required_hdrs",
|
||||||
|
srcs = [
|
||||||
|
"training/coordinator.h",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/python:__pkg__",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "gradients",
|
name = "gradients",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -622,6 +632,7 @@ tf_gen_op_wrappers_cc(
|
|||||||
"tpu_configuration_ops",
|
"tpu_configuration_ops",
|
||||||
"tpu_cross_replica_ops",
|
"tpu_cross_replica_ops",
|
||||||
"tpu_embedding_ops",
|
"tpu_embedding_ops",
|
||||||
|
"tpu_embedding_load_retrieve_ops",
|
||||||
"tpu_functional_ops",
|
"tpu_functional_ops",
|
||||||
"tpu_heartbeat_ops",
|
"tpu_heartbeat_ops",
|
||||||
"tpu_host_compute_ops",
|
"tpu_host_compute_ops",
|
||||||
|
@ -41,7 +41,7 @@ class ClientSession::Impl {
|
|||||||
std::shared_ptr<Graph> graph_;
|
std::shared_ptr<Graph> graph_;
|
||||||
|
|
||||||
mutable mutex mu_;
|
mutable mutex mu_;
|
||||||
mutable int last_num_graph_nodes_ GUARDED_BY(mu_) = 0;
|
mutable int last_num_graph_nodes_ TF_GUARDED_BY(mu_) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
ClientSession::ClientSession(const Scope& scope, const string& target)
|
ClientSession::ClientSession(const Scope& scope, const string& target)
|
||||||
|
@ -96,7 +96,7 @@ class SymbolicGradientBuilder {
|
|||||||
// Used to identify nodes at which to stop backprop.
|
// Used to identify nodes at which to stop backprop.
|
||||||
std::unordered_set<int> GetStopBackpropNodes(
|
std::unordered_set<int> GetStopBackpropNodes(
|
||||||
const std::vector<bool>& reachable_nodes,
|
const std::vector<bool>& reachable_nodes,
|
||||||
const std::unordered_set<int>& output_nodes);
|
const std::unordered_set<int>& output_nodes) const;
|
||||||
|
|
||||||
const Scope& scope_;
|
const Scope& scope_;
|
||||||
const ops::GradOpRegistry* registry_;
|
const ops::GradOpRegistry* registry_;
|
||||||
@ -190,7 +190,7 @@ std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
|
|||||||
|
|
||||||
std::unordered_set<int> SymbolicGradientBuilder::GetStopBackpropNodes(
|
std::unordered_set<int> SymbolicGradientBuilder::GetStopBackpropNodes(
|
||||||
const std::vector<bool>& reachable_nodes,
|
const std::vector<bool>& reachable_nodes,
|
||||||
const std::unordered_set<int>& output_nodes) {
|
const std::unordered_set<int>& output_nodes) const {
|
||||||
// Output nodes that get transitively consumed by other `outputs_` are stored
|
// Output nodes that get transitively consumed by other `outputs_` are stored
|
||||||
// in `internal_outputs`.
|
// in `internal_outputs`.
|
||||||
std::unordered_set<int> internal_outputs;
|
std::unordered_set<int> internal_outputs;
|
||||||
@ -346,8 +346,8 @@ Status SymbolicGradientBuilder::SumGradients(const Output& src, Output* grad) {
|
|||||||
"Unable to find backprop list for node.id ", src.node()->name());
|
"Unable to find backprop list for node.id ", src.node()->name());
|
||||||
}
|
}
|
||||||
const auto& grads = iter->second;
|
const auto& grads = iter->second;
|
||||||
// Filter any backproped 'NoGradient' Outputs from 'grads' (if needed).
|
// Filter any backpropped 'NoGradient' Outputs from 'grads' (if needed).
|
||||||
// Return any valid backproped gradients that remain after filtering,
|
// Return any valid backpropped gradients that remain after filtering,
|
||||||
// or 'NoGradient' otherwise.
|
// or 'NoGradient' otherwise.
|
||||||
std::vector<Output> grads_to_keep;
|
std::vector<Output> grads_to_keep;
|
||||||
for (const Output& o : grads) {
|
for (const Output& o : grads) {
|
||||||
@ -519,17 +519,17 @@ Status SymbolicGradientBuilder::AddGradients() {
|
|||||||
// Backprop along the in edges.
|
// Backprop along the in edges.
|
||||||
// TODO(andydavis) Find cleaner way to map each grad output returned by
|
// TODO(andydavis) Find cleaner way to map each grad output returned by
|
||||||
// gradient function to the src node/output to which it should be
|
// gradient function to the src node/output to which it should be
|
||||||
// backproped. Maybe grad functions can return a vector of Output pairs to
|
// backpropped. Maybe grad functions can return a vector of Output pairs to
|
||||||
// make this association explicit.
|
// make this association explicit.
|
||||||
size_t dx_index = 0;
|
|
||||||
for (const Edge* e : n->in_edges()) {
|
for (const Edge* e : n->in_edges()) {
|
||||||
if (e->IsControlEdge()) continue;
|
if (e->IsControlEdge()) continue;
|
||||||
if (dx_index == dx.size()) {
|
int dx_index = e->dst_input();
|
||||||
|
if (dx_index >= dx.size()) {
|
||||||
return errors::Internal(
|
return errors::Internal(
|
||||||
"Invalid gradient output index: ", dx_index, " size: ", dx.size());
|
"Invalid gradient output index: ", dx_index, " size: ", dx.size());
|
||||||
}
|
}
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
BackpropAlongEdge(dx[dx_index++], {e->src(), e->src_output()}));
|
BackpropAlongEdge(dx[dx_index], {e->src(), e->src_output()}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -503,6 +503,42 @@ TEST_F(GradientsTest, MultiOutputNodeDependentOutputs) {
|
|||||||
EXPECT_EQ(grad_result[0].flat<float>()(0), 17610.0f);
|
EXPECT_EQ(grad_result[0].flat<float>()(0), 17610.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(GradientsTest, AddSymbolicGradientsTest) {
|
||||||
|
Scope scope = Scope::NewRootScope();
|
||||||
|
for (int cnt = 0; cnt < 100; ++cnt) {
|
||||||
|
int N = 5 + rand() % 10;
|
||||||
|
// Construct forward graph.
|
||||||
|
OutputList inputs;
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
auto a = Const(scope, i, {1});
|
||||||
|
inputs.push_back(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto pack = Stack(scope, inputs);
|
||||||
|
TF_ASSERT_OK(scope.status());
|
||||||
|
|
||||||
|
// Construct grad inputs.
|
||||||
|
OutputList output_grads;
|
||||||
|
Tensor ts(DT_INT32, {N, 1});
|
||||||
|
auto v = ts.matrix<int32>();
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
v(i, 0) = i;
|
||||||
|
}
|
||||||
|
auto dy = Const(scope, ts);
|
||||||
|
output_grads.push_back(dy);
|
||||||
|
// Call AddSymbolicGradients.
|
||||||
|
std::vector<Output> grad_outputs;
|
||||||
|
TF_ASSERT_OK(AddSymbolicGradients(scope, {pack.output}, inputs,
|
||||||
|
output_grads, &grad_outputs));
|
||||||
|
ClientSession session((scope));
|
||||||
|
std::vector<Tensor> in_grad;
|
||||||
|
TF_ASSERT_OK(session.Run(grad_outputs, &in_grad));
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
test::ExpectTensorEqual<int>(in_grad[i], test::AsTensor<int>({i}, {1}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// StopGradientSingleOutputMultiEdgeTest tests combinations of valid and
|
// StopGradientSingleOutputMultiEdgeTest tests combinations of valid and
|
||||||
// 'NoGradient' (induced by StopGradient op) returned along multiple edges from
|
// 'NoGradient' (induced by StopGradient op) returned along multiple edges from
|
||||||
// a single nodes output.
|
// a single nodes output.
|
||||||
|
@ -64,7 +64,7 @@ bool IsZero(const Scope& scope, const Output& grad) {
|
|||||||
// Multiply after broadcasting vec to match dimensions of mat.
|
// Multiply after broadcasting vec to match dimensions of mat.
|
||||||
// Args:
|
// Args:
|
||||||
// vec: A 1-D tensor of dimension [D0]
|
// vec: A 1-D tensor of dimension [D0]
|
||||||
// mat: A 2-D tensor of dimesnion [D0, D1]
|
// mat: A 2-D tensor of dimension [D0, D1]
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// A tensor of dimension [D0, D1], the result fo vec * mat.
|
// A tensor of dimension [D0, D1], the result fo vec * mat.
|
||||||
|
@ -68,6 +68,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core:testlib",
|
"//tensorflow/core:testlib",
|
||||||
|
"//tensorflow/core/platform:resource_loader",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -224,3 +225,15 @@ filegroup(
|
|||||||
"testdata/VarsAndArithmeticObjectGraph/**",
|
"testdata/VarsAndArithmeticObjectGraph/**",
|
||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
exports_files(
|
||||||
|
glob([
|
||||||
|
"testdata/half_plus_two_pbtxt/**",
|
||||||
|
"testdata/half_plus_two_main_op/**",
|
||||||
|
"testdata/half_plus_two/**",
|
||||||
|
"testdata/half_plus_two_v2/**",
|
||||||
|
"testdata/x_plus_y_v2_debuginfo/**",
|
||||||
|
"testdata/CyclicModule/**",
|
||||||
|
"testdata/VarsAndArithmeticObjectGraph/**",
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
@ -21,15 +21,22 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
#include "tensorflow/core/platform/path.h"
|
||||||
|
#include "tensorflow/core/platform/resource_loader.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
constexpr char kTestDataPbTxt[] =
|
string TestDataPbTxt() {
|
||||||
"cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
|
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
|
||||||
constexpr char kTestDataSharded[] =
|
"half_plus_two_pbtxt", "00000123");
|
||||||
"cc/saved_model/testdata/half_plus_two/00000123";
|
}
|
||||||
|
|
||||||
|
string TestDataSharded() {
|
||||||
|
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
|
||||||
|
"half_plus_two", "00000123");
|
||||||
|
}
|
||||||
|
|
||||||
class ReaderTest : public ::testing::Test {
|
class ReaderTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
@ -49,8 +56,7 @@ class ReaderTest : public ::testing::Test {
|
|||||||
TEST_F(ReaderTest, TagMatch) {
|
TEST_F(ReaderTest, TagMatch) {
|
||||||
MetaGraphDef meta_graph_def;
|
MetaGraphDef meta_graph_def;
|
||||||
|
|
||||||
const string export_dir =
|
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
|
||||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
|
||||||
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||||
&meta_graph_def));
|
&meta_graph_def));
|
||||||
CheckMetaGraphDef(meta_graph_def);
|
CheckMetaGraphDef(meta_graph_def);
|
||||||
@ -59,8 +65,7 @@ TEST_F(ReaderTest, TagMatch) {
|
|||||||
TEST_F(ReaderTest, NoTagMatch) {
|
TEST_F(ReaderTest, NoTagMatch) {
|
||||||
MetaGraphDef meta_graph_def;
|
MetaGraphDef meta_graph_def;
|
||||||
|
|
||||||
const string export_dir =
|
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
|
||||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
|
||||||
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
|
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
|
||||||
&meta_graph_def);
|
&meta_graph_def);
|
||||||
EXPECT_FALSE(st.ok());
|
EXPECT_FALSE(st.ok());
|
||||||
@ -73,8 +78,7 @@ TEST_F(ReaderTest, NoTagMatch) {
|
|||||||
TEST_F(ReaderTest, NoTagMatchMultiple) {
|
TEST_F(ReaderTest, NoTagMatchMultiple) {
|
||||||
MetaGraphDef meta_graph_def;
|
MetaGraphDef meta_graph_def;
|
||||||
|
|
||||||
const string export_dir =
|
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
|
||||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
|
||||||
Status st = ReadMetaGraphDefFromSavedModel(
|
Status st = ReadMetaGraphDefFromSavedModel(
|
||||||
export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
|
export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
|
||||||
EXPECT_FALSE(st.ok());
|
EXPECT_FALSE(st.ok());
|
||||||
@ -87,8 +91,7 @@ TEST_F(ReaderTest, NoTagMatchMultiple) {
|
|||||||
TEST_F(ReaderTest, PbtxtFormat) {
|
TEST_F(ReaderTest, PbtxtFormat) {
|
||||||
MetaGraphDef meta_graph_def;
|
MetaGraphDef meta_graph_def;
|
||||||
|
|
||||||
const string export_dir =
|
const string export_dir = GetDataDependencyFilepath(TestDataPbTxt());
|
||||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt);
|
|
||||||
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||||
&meta_graph_def));
|
&meta_graph_def));
|
||||||
CheckMetaGraphDef(meta_graph_def);
|
CheckMetaGraphDef(meta_graph_def);
|
||||||
@ -97,8 +100,7 @@ TEST_F(ReaderTest, PbtxtFormat) {
|
|||||||
TEST_F(ReaderTest, InvalidExportPath) {
|
TEST_F(ReaderTest, InvalidExportPath) {
|
||||||
MetaGraphDef meta_graph_def;
|
MetaGraphDef meta_graph_def;
|
||||||
|
|
||||||
const string export_dir =
|
const string export_dir = GetDataDependencyFilepath("missing-path");
|
||||||
io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
|
|
||||||
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||||
&meta_graph_def);
|
&meta_graph_def);
|
||||||
EXPECT_FALSE(st.ok());
|
EXPECT_FALSE(st.ok());
|
||||||
|
@ -114,14 +114,14 @@ class Coordinator {
|
|||||||
condition_variable wait_for_stop_;
|
condition_variable wait_for_stop_;
|
||||||
|
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
bool should_stop_ GUARDED_BY(mu_);
|
bool should_stop_ TF_GUARDED_BY(mu_);
|
||||||
|
|
||||||
mutex status_lock_;
|
mutex status_lock_;
|
||||||
Status status_ GUARDED_BY(status_lock_);
|
Status status_ TF_GUARDED_BY(status_lock_);
|
||||||
|
|
||||||
mutable mutex runners_lock_;
|
mutable mutex runners_lock_;
|
||||||
std::vector<std::unique_ptr<RunnerInterface>> runners_
|
std::vector<std::unique_ptr<RunnerInterface>> runners_
|
||||||
GUARDED_BY(runners_lock_);
|
TF_GUARDED_BY(runners_lock_);
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(Coordinator);
|
TF_DISALLOW_COPY_AND_ASSIGN(Coordinator);
|
||||||
};
|
};
|
||||||
|
@ -119,8 +119,8 @@ class QueueRunner : public RunnerInterface {
|
|||||||
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
int runs_ = 0;
|
int runs_ = 0;
|
||||||
Status status_ GUARDED_BY(mu_);
|
Status status_ TF_GUARDED_BY(mu_);
|
||||||
Status enqueue_status_ GUARDED_BY(mu_);
|
Status enqueue_status_ TF_GUARDED_BY(mu_);
|
||||||
std::unique_ptr<BlockingCounter> counter_;
|
std::unique_ptr<BlockingCounter> counter_;
|
||||||
|
|
||||||
Coordinator* coord_;
|
Coordinator* coord_;
|
||||||
@ -131,7 +131,7 @@ class QueueRunner : public RunnerInterface {
|
|||||||
std::vector<std::function<void(Status)>> callbacks_;
|
std::vector<std::function<void(Status)>> callbacks_;
|
||||||
|
|
||||||
mutable std::unique_ptr<mutex> cg_mu_;
|
mutable std::unique_ptr<mutex> cg_mu_;
|
||||||
std::unique_ptr<CostGraphDef> cost_graph_ GUARDED_BY(cg_mu_);
|
std::unique_ptr<CostGraphDef> cost_graph_ TF_GUARDED_BY(cg_mu_);
|
||||||
RunOptions run_options_;
|
RunOptions run_options_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -20,9 +20,11 @@ from __future__ import print_function as _print_function
|
|||||||
|
|
||||||
import logging as _logging
|
import logging as _logging
|
||||||
import os as _os
|
import os as _os
|
||||||
|
import six as _six
|
||||||
import sys as _sys
|
import sys as _sys
|
||||||
|
|
||||||
from tensorflow.python.tools import module_util as _module_util
|
from tensorflow.python.tools import module_util as _module_util
|
||||||
|
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
|
||||||
|
|
||||||
# pylint: disable=g-bad-import-order
|
# pylint: disable=g-bad-import-order
|
||||||
|
|
||||||
@ -36,20 +38,19 @@ try:
|
|||||||
from tensorboard.summary._tf import summary
|
from tensorboard.summary._tf import summary
|
||||||
_current_module.__path__ = (
|
_current_module.__path__ = (
|
||||||
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
|
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
|
||||||
# Make sure we get the correct summary module with lazy loading
|
|
||||||
setattr(_current_module, "summary", summary)
|
setattr(_current_module, "summary", summary)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
_logging.warning(
|
_logging.warning(
|
||||||
"Limited tf.compat.v2.summary API due to missing TensorBoard "
|
"Limited tf.compat.v2.summary API due to missing TensorBoard "
|
||||||
"installation.")
|
"installation.")
|
||||||
|
|
||||||
try:
|
# Lazy-load estimator.
|
||||||
from tensorflow_estimator.python.estimator.api._v2 import estimator
|
_estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator"
|
||||||
_current_module.__path__ = (
|
estimator = _LazyLoader("estimator", globals(), _estimator_module)
|
||||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
|
||||||
setattr(_current_module, "estimator", estimator)
|
if _module_dir:
|
||||||
except ImportError:
|
_current_module.__path__ = [_module_dir] + _current_module.__path__
|
||||||
pass
|
setattr(_current_module, "estimator", estimator)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from tensorflow.python.keras.api._v2 import keras
|
from tensorflow.python.keras.api._v2 import keras
|
||||||
@ -59,6 +60,13 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Explicitly import lazy-loaded modules to support autocompletion.
|
||||||
|
# pylint: disable=g-import-not-at-top
|
||||||
|
if not _six.PY2:
|
||||||
|
import typing as _typing
|
||||||
|
if _typing.TYPE_CHECKING:
|
||||||
|
from tensorflow_estimator.python.estimator.api._v2 import estimator
|
||||||
|
# pylint: enable=g-import-not-at-top
|
||||||
|
|
||||||
# We would like the following to work for fully enabling 2.0 in a 1.0 install:
|
# We would like the following to work for fully enabling 2.0 in a 1.0 install:
|
||||||
#
|
#
|
||||||
|
@ -20,8 +20,10 @@ from __future__ import print_function as _print_function
|
|||||||
|
|
||||||
import os as _os
|
import os as _os
|
||||||
import sys as _sys
|
import sys as _sys
|
||||||
|
import six as _six
|
||||||
|
|
||||||
from tensorflow.python.tools import module_util as _module_util
|
from tensorflow.python.tools import module_util as _module_util
|
||||||
|
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
|
||||||
|
|
||||||
# pylint: disable=g-bad-import-order
|
# pylint: disable=g-bad-import-order
|
||||||
|
|
||||||
@ -31,13 +33,14 @@ from tensorflow.python.tools import module_util as _module_util
|
|||||||
|
|
||||||
# Hook external TensorFlow modules.
|
# Hook external TensorFlow modules.
|
||||||
_current_module = _sys.modules[__name__]
|
_current_module = _sys.modules[__name__]
|
||||||
try:
|
|
||||||
from tensorflow_estimator.python.estimator.api._v1 import estimator
|
# Lazy-load estimator.
|
||||||
_current_module.__path__ = (
|
_estimator_module = "tensorflow_estimator.python.estimator.api._v1.estimator"
|
||||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
estimator = _LazyLoader("estimator", globals(), _estimator_module)
|
||||||
setattr(_current_module, "estimator", estimator)
|
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
|
||||||
except ImportError:
|
if _module_dir:
|
||||||
pass
|
_current_module.__path__ = [_module_dir] + _current_module.__path__
|
||||||
|
setattr(_current_module, "estimator", estimator)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from tensorflow.python.keras.api._v1 import keras
|
from tensorflow.python.keras.api._v1 import keras
|
||||||
@ -47,6 +50,14 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Explicitly import lazy-loaded modules to support autocompletion.
|
||||||
|
# pylint: disable=g-import-not-at-top
|
||||||
|
if not _six.PY2:
|
||||||
|
import typing as _typing
|
||||||
|
if _typing.TYPE_CHECKING:
|
||||||
|
from tensorflow_estimator.python.estimator.api._v1 import estimator
|
||||||
|
# pylint: enable=g-import-not-at-top
|
||||||
|
|
||||||
|
|
||||||
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
|
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
|
||||||
_current_module.app.flags = flags # pylint: disable=undefined-variable
|
_current_module.app.flags = flags # pylint: disable=undefined-variable
|
||||||
|
@ -1,11 +1,5 @@
|
|||||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
|
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
|
||||||
|
|
||||||
# buildifier: disable=same-origin-load
|
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper")
|
|
||||||
|
|
||||||
# buildifier: disable=same-origin-load
|
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
|
|
||||||
load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available")
|
load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available")
|
||||||
|
|
||||||
package(
|
package(
|
||||||
@ -39,9 +33,11 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":aot_only_var_handle_op",
|
":aot_only_var_handle_op",
|
||||||
":embedded_protocol_buffers",
|
":embedded_protocol_buffers",
|
||||||
|
"@com_google_absl//absl/base",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
|
"//tensorflow/compiler/mlir/lite/quantization/xla:quantize",
|
||||||
"//tensorflow/compiler/tf2xla",
|
"//tensorflow/compiler/tf2xla",
|
||||||
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
|
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
|
||||||
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
|
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
|
||||||
@ -69,36 +65,7 @@ cc_library(
|
|||||||
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
||||||
"@llvm-project//llvm:target",
|
"@llvm-project//llvm:target",
|
||||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||||
] + if_llvm_aarch64_available([
|
"//tensorflow/core:regexp_internal",
|
||||||
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
|
|
||||||
]),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Necessary for the pywrap inclusion below.
|
|
||||||
tf_pybind_cc_library_wrapper(
|
|
||||||
name = "tfcompile_headers_lib",
|
|
||||||
deps = [
|
|
||||||
":tfcompile_lib",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_python_pybind_extension(
|
|
||||||
name = "_pywrap_tfcompile",
|
|
||||||
srcs = ["tfcompile_wrapper.cc"],
|
|
||||||
features = ["-layering_check"],
|
|
||||||
module_name = "_pywrap_tfcompile",
|
|
||||||
visibility = ["//tensorflow/python:__pkg__"],
|
|
||||||
deps = [
|
|
||||||
":tfcompile_headers_lib",
|
|
||||||
"@pybind11",
|
|
||||||
"//third_party/python_runtime:headers",
|
|
||||||
"//tensorflow/python:pybind11_lib",
|
|
||||||
"//tensorflow/python:pybind11_status",
|
|
||||||
# These headers cannot be brought in via cc_header_only_library
|
|
||||||
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
|
|
||||||
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
|
||||||
"@llvm-project//llvm:target",
|
|
||||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
|
||||||
] + if_llvm_aarch64_available([
|
] + if_llvm_aarch64_available([
|
||||||
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
|
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
|
||||||
]),
|
]),
|
||||||
@ -119,6 +86,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/platform:resource_loader",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@llvm-project//llvm:support", # fixdeps: keep
|
"@llvm-project//llvm:support", # fixdeps: keep
|
||||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||||
@ -131,6 +99,19 @@ tf_cc_binary(
|
|||||||
deps = [":tfcompile_main"],
|
deps = [":tfcompile_main"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "llvm_targets",
|
||||||
|
visibility = ["//tensorflow/python:__pkg__"],
|
||||||
|
deps = [
|
||||||
|
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
|
||||||
|
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
||||||
|
"@llvm-project//llvm:target",
|
||||||
|
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||||
|
] + if_llvm_aarch64_available([
|
||||||
|
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tfcompile_main",
|
name = "tfcompile_main",
|
||||||
srcs = ["tfcompile_main.cc"],
|
srcs = ["tfcompile_main.cc"],
|
||||||
@ -156,54 +137,108 @@ cc_library(
|
|||||||
# tfcompile.bzl correctly handles usage from outside of the package that it is
|
# tfcompile.bzl correctly handles usage from outside of the package that it is
|
||||||
# defined in.
|
# defined in.
|
||||||
|
|
||||||
# A simple test of tf_library from a text protobuf, mostly to enable the
|
# A simple test of tf_library from a text protobuf, to enable benchmark_test.
|
||||||
# benchmark_test.
|
# This test uses an incompleted graph with a node that is not defined. The
|
||||||
|
# compilation works because the undefined node is a feed node.
|
||||||
tf_library(
|
tf_library(
|
||||||
name = "test_graph_tfadd",
|
name = "test_graph_tfadd",
|
||||||
testonly = 1,
|
testonly = 1,
|
||||||
config = "test_graph_tfadd.config.pbtxt",
|
config = "test_graph_tfadd.config.pbtxt",
|
||||||
cpp_class = "AddComp",
|
cpp_class = "AddComp",
|
||||||
graph = "test_graph_tfadd.pbtxt",
|
graph = "test_graph_tfadd.pbtxt",
|
||||||
|
mlir_components = "None",
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_library(
|
||||||
|
name = "test_graph_tfadd_mlir_bridge",
|
||||||
|
testonly = 1,
|
||||||
|
config = "test_graph_tfadd.config.pbtxt",
|
||||||
|
cpp_class = "AddComp",
|
||||||
|
graph = "test_graph_tfadd.pbtxt",
|
||||||
|
mlir_components = "Bridge",
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# A test of tf_library that includes a graph with an unknown op, but where
|
# A test of tf_library that includes a graph with an unknown op, but where
|
||||||
# the compilation works because the unknown op is not needed for the fetches.
|
# the compilation works because the node with the unknown op is not needed
|
||||||
|
# for the fetches.
|
||||||
tf_library(
|
tf_library(
|
||||||
name = "test_graph_tfunknownop",
|
name = "test_graph_tfunknownop",
|
||||||
testonly = 1,
|
testonly = 1,
|
||||||
config = "test_graph_tfunknownop.config.pbtxt",
|
config = "test_graph_tfunknownop.config.pbtxt",
|
||||||
cpp_class = "UnknownOpAddComp",
|
cpp_class = "UnknownOpAddComp",
|
||||||
graph = "test_graph_tfunknownop.pbtxt",
|
graph = "test_graph_tfunknownop.pbtxt",
|
||||||
|
mlir_components = "None",
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_library(
|
||||||
|
name = "test_graph_tfunknownop_mlir_bridge",
|
||||||
|
testonly = 1,
|
||||||
|
config = "test_graph_tfunknownop.config.pbtxt",
|
||||||
|
cpp_class = "UnknownOpAddComp",
|
||||||
|
graph = "test_graph_tfunknownop.pbtxt",
|
||||||
|
mlir_components = "Bridge",
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# A test of tf_library that includes a graph with an unknown op, but where
|
# A test of tf_library that includes a graph with an unknown op, but where
|
||||||
# the compilation works because the op between the unknown op and the
|
# the compilation works because the node with the unknown op is only used as
|
||||||
# fetches is a feed.
|
# an input of a feed node.
|
||||||
tf_library(
|
tf_library(
|
||||||
name = "test_graph_tfunknownop2",
|
name = "test_graph_tfunknownop2",
|
||||||
testonly = 1,
|
testonly = 1,
|
||||||
config = "test_graph_tfunknownop2.config.pbtxt",
|
config = "test_graph_tfunknownop2.config.pbtxt",
|
||||||
cpp_class = "UnknownOpAddComp",
|
cpp_class = "UnknownOpAddComp",
|
||||||
graph = "test_graph_tfunknownop.pbtxt",
|
graph = "test_graph_tfunknownop.pbtxt",
|
||||||
|
mlir_components = "None",
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_library(
|
||||||
|
name = "test_graph_tfunknownop2_mlir_bridge",
|
||||||
|
testonly = 1,
|
||||||
|
config = "test_graph_tfunknownop2.config.pbtxt",
|
||||||
|
cpp_class = "UnknownOpAddComp",
|
||||||
|
graph = "test_graph_tfunknownop.pbtxt",
|
||||||
|
mlir_components = "Bridge",
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# A test of tf_library that includes a graph with an unknown op, but where
|
# A test of tf_library that includes a graph with an unknown op, but where
|
||||||
# the compilation works because the unknown op is fed.
|
# the compilation works because the node with the unknown op is a feed node.
|
||||||
tf_library(
|
tf_library(
|
||||||
name = "test_graph_tfunknownop3",
|
name = "test_graph_tfunknownop3",
|
||||||
testonly = 1,
|
testonly = 1,
|
||||||
config = "test_graph_tfunknownop3.config.pbtxt",
|
config = "test_graph_tfunknownop3.config.pbtxt",
|
||||||
cpp_class = "UnknownOpAddComp",
|
cpp_class = "UnknownOpAddComp",
|
||||||
graph = "test_graph_tfunknownop.pbtxt",
|
graph = "test_graph_tfunknownop.pbtxt",
|
||||||
|
mlir_components = "None",
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_library(
|
||||||
|
name = "test_graph_tfunknownop3_mlir_bridge",
|
||||||
|
testonly = 1,
|
||||||
|
config = "test_graph_tfunknownop3.config.pbtxt",
|
||||||
|
cpp_class = "UnknownOpAddComp",
|
||||||
|
graph = "test_graph_tfunknownop.pbtxt",
|
||||||
|
mlir_components = "Bridge",
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
],
|
],
|
||||||
@ -283,9 +318,13 @@ test_suite(
|
|||||||
tests = [
|
tests = [
|
||||||
":benchmark_test",
|
":benchmark_test",
|
||||||
":codegen_test",
|
":codegen_test",
|
||||||
|
":test_graph_tfadd_mlir_bridge_test",
|
||||||
":test_graph_tfadd_test",
|
":test_graph_tfadd_test",
|
||||||
|
":test_graph_tfunknownop2_mlir_bridge_test",
|
||||||
":test_graph_tfunknownop2_test",
|
":test_graph_tfunknownop2_test",
|
||||||
|
":test_graph_tfunknownop3_mlir_bridge_test",
|
||||||
":test_graph_tfunknownop3_test",
|
":test_graph_tfunknownop3_test",
|
||||||
|
":test_graph_tfunknownop_mlir_bridge_test",
|
||||||
":test_graph_tfunknownop_test",
|
":test_graph_tfunknownop_test",
|
||||||
"//tensorflow/compiler/aot/tests:all_tests",
|
"//tensorflow/compiler/aot/tests:all_tests",
|
||||||
],
|
],
|
||||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||||||
#include "absl/strings/str_split.h"
|
#include "absl/strings/str_split.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
|
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
||||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||||
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
|
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
|
||||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||||
@ -169,7 +170,9 @@ Status GenArgMethods(const tf2xla::Config& config,
|
|||||||
const xla::ProgramShapeProto& ps,
|
const xla::ProgramShapeProto& ps,
|
||||||
const CompileResult& compile_result, string* methods) {
|
const CompileResult& compile_result, string* methods) {
|
||||||
size_t num_args = ps.parameters_size();
|
size_t num_args = ps.parameters_size();
|
||||||
if (config.feed_size() + config.variable_size() != num_args) {
|
// feed_size() + variable_size() is the maximum number of args as an
|
||||||
|
// implementation may not create an argument for an unused variable.
|
||||||
|
if (config.feed_size() + config.variable_size() < num_args) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"mismatch between feed_size(", config.feed_size(), ")+variable_size(",
|
"mismatch between feed_size(", config.feed_size(), ")+variable_size(",
|
||||||
config.variable_size(), ") and num_args(", num_args, ")");
|
config.variable_size(), ") and num_args(", num_args, ")");
|
||||||
@ -288,8 +291,8 @@ Status GenVariableMethods(const tf2xla::Config& config,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generates code implementing {Arg,Result}Names(), where T is one of
|
// Generates code implementing {Arg,Result}Names(), where T is one of
|
||||||
// tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string
|
// tf2xla::{Feed,Fetch,Variable}. Each feed or fetch name results in a C-style
|
||||||
// literal in the array, with nullptr terminating the array.
|
// string literal in the array, with nullptr terminating the array.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
string GenNameToIndexCode(const T& entries, bool generate) {
|
string GenNameToIndexCode(const T& entries, bool generate) {
|
||||||
// No need for a static array if we're not supposed to generate the data.
|
// No need for a static array if we're not supposed to generate the data.
|
||||||
@ -419,6 +422,16 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
|
|||||||
// Generate metadata.
|
// Generate metadata.
|
||||||
const string arg_names_code =
|
const string arg_names_code =
|
||||||
GenNameToIndexCode(config.feed(), opts.gen_name_to_index);
|
GenNameToIndexCode(config.feed(), opts.gen_name_to_index);
|
||||||
|
|
||||||
|
auto variable_copy = config.variable();
|
||||||
|
for (auto& var : variable_copy) {
|
||||||
|
if (var.name().empty()) {
|
||||||
|
var.set_name(var.node_name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const string variable_names_code =
|
||||||
|
GenNameToIndexCode(variable_copy, opts.gen_name_to_index);
|
||||||
|
|
||||||
const string result_names_code =
|
const string result_names_code =
|
||||||
GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
|
GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
|
||||||
const string include_xla_data_proto =
|
const string include_xla_data_proto =
|
||||||
@ -457,8 +470,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
|
|||||||
|
|
||||||
{{INCLUDE_XLA_DATA_PROTO}}
|
{{INCLUDE_XLA_DATA_PROTO}}
|
||||||
{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}
|
{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}
|
||||||
#include "{{TF_HEADER_ROOT}}/compiler/tf2xla/xla_compiled_cpu_function.h"
|
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
|
||||||
#include "{{TF_HEADER_ROOT}}/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
namespace Eigen { struct ThreadPoolDevice; }
|
namespace Eigen { struct ThreadPoolDevice; }
|
||||||
namespace xla { class ExecutableRunOptions; }
|
namespace xla { class ExecutableRunOptions; }
|
||||||
@ -507,6 +520,9 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
// Number of input arguments for the compiled computation.
|
// Number of input arguments for the compiled computation.
|
||||||
static constexpr size_t kNumArgs = {{ARG_NUM}};
|
static constexpr size_t kNumArgs = {{ARG_NUM}};
|
||||||
|
|
||||||
|
// Number of variables for the compiled computation.
|
||||||
|
static constexpr size_t kNumVariables = {{VARIABLE_NUM}};
|
||||||
|
|
||||||
// Byte size of each argument buffer. There are kNumArgs entries.
|
// Byte size of each argument buffer. There are kNumArgs entries.
|
||||||
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
|
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
|
||||||
return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
|
return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
|
||||||
@ -522,8 +538,10 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
set_static_data_num_buffers(data, kNumBuffers);
|
set_static_data_num_buffers(data, kNumBuffers);
|
||||||
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
|
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
|
||||||
set_static_data_num_args(data, kNumArgs);
|
set_static_data_num_args(data, kNumArgs);
|
||||||
|
set_static_data_num_variables(data, kNumVariables);
|
||||||
set_static_data_result_index(data, kResultIndex);
|
set_static_data_result_index(data, kResultIndex);
|
||||||
set_static_data_arg_names(data, StaticArgNames());
|
set_static_data_arg_names(data, StaticArgNames());
|
||||||
|
set_static_data_variable_names(data, StaticVariableNames());
|
||||||
set_static_data_result_names(data, StaticResultNames());
|
set_static_data_result_names(data, StaticResultNames());
|
||||||
set_static_data_program_shape(data, StaticProgramShape());
|
set_static_data_program_shape(data, StaticProgramShape());
|
||||||
set_static_data_hlo_profile_printer_data(
|
set_static_data_hlo_profile_printer_data(
|
||||||
@ -626,6 +644,9 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
// Array of names of each positional argument, terminated by nullptr.
|
// Array of names of each positional argument, terminated by nullptr.
|
||||||
static const char** StaticArgNames() {{ARG_NAMES_CODE}}
|
static const char** StaticArgNames() {{ARG_NAMES_CODE}}
|
||||||
|
|
||||||
|
// Array of names of each positional variable, terminated by nullptr.
|
||||||
|
static const char** StaticVariableNames() {{VARIABLE_NAMES_CODE}}
|
||||||
|
|
||||||
// Array of names of each positional result, terminated by nullptr.
|
// Array of names of each positional result, terminated by nullptr.
|
||||||
static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
|
static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
|
||||||
|
|
||||||
@ -654,12 +675,12 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
{"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)},
|
{"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)},
|
||||||
{"{{ARG_NAMES_CODE}}", arg_names_code},
|
{"{{ARG_NAMES_CODE}}", arg_names_code},
|
||||||
{"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())},
|
{"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())},
|
||||||
|
{"{{VARIABLE_NUM}}", absl::StrCat(config.variable_size())},
|
||||||
{"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
|
{"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
|
||||||
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
|
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
|
||||||
{"{{CLASS}}", opts.class_name},
|
{"{{CLASS}}", opts.class_name},
|
||||||
{"{{DECLS_FROM_OBJ_FILE}}",
|
{"{{DECLS_FROM_OBJ_FILE}}",
|
||||||
absl::StrJoin(metadata_result.header_variable_decls, "\n")},
|
absl::StrJoin(metadata_result.header_variable_decls, "\n")},
|
||||||
{"{{TF_HEADER_ROOT}}", compile_result.tensorflow_header_root},
|
|
||||||
{"{{ENTRY}}", compile_result.entry_point},
|
{"{{ENTRY}}", compile_result.entry_point},
|
||||||
{"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
|
{"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
|
||||||
metadata_result.hlo_profile_printer_data_access_shim},
|
metadata_result.hlo_profile_printer_data_access_shim},
|
||||||
@ -674,6 +695,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
|
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
|
||||||
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
|
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
|
||||||
metadata_result.program_shape_access_shim},
|
metadata_result.program_shape_access_shim},
|
||||||
|
{"{{VARIABLE_NAMES_CODE}}", variable_names_code},
|
||||||
{"{{RESULT_INDEX}}", absl::StrCat(result_index)},
|
{"{{RESULT_INDEX}}", absl::StrCat(result_index)},
|
||||||
{"{{RESULT_NAMES_CODE}}", result_names_code},
|
{"{{RESULT_NAMES_CODE}}", result_names_code},
|
||||||
{"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)},
|
{"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)},
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/aot/codegen.h"
|
#include "tensorflow/compiler/aot/codegen.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -29,6 +30,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/resource_loader.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -139,14 +141,23 @@ TEST_F(ParseCppClassTest, ParseFail) {
|
|||||||
|
|
||||||
static void CompareWithGoldenFile(
|
static void CompareWithGoldenFile(
|
||||||
const string& tensorflow_relative_golden_file_name,
|
const string& tensorflow_relative_golden_file_name,
|
||||||
const string& expected_contents) {
|
const string& expected_contents, bool ignore_cr) {
|
||||||
|
// Get rid of all CR characters, we may be running under windows.
|
||||||
|
string sanitized_expected_contents(expected_contents);
|
||||||
|
if (ignore_cr) {
|
||||||
|
sanitized_expected_contents.erase(
|
||||||
|
std::remove(sanitized_expected_contents.begin(),
|
||||||
|
sanitized_expected_contents.end(), '\r'),
|
||||||
|
sanitized_expected_contents.end());
|
||||||
|
}
|
||||||
|
|
||||||
// To update the golden file, flip update_golden to true and run the
|
// To update the golden file, flip update_golden to true and run the
|
||||||
// following:
|
// following:
|
||||||
// bazel test --test_strategy=local \
|
// bazel test --test_strategy=local \
|
||||||
// third_party/tensorflow/compiler/aot:codegen_test
|
// "third_party/tensorflow/compiler/aot:codegen_test"
|
||||||
const bool update_golden = false;
|
const bool update_golden = false;
|
||||||
const string golden_file_name = io::JoinPath(
|
string golden_file_name =
|
||||||
testing::TensorFlowSrcRoot(), tensorflow_relative_golden_file_name);
|
GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
|
||||||
|
|
||||||
if (update_golden) {
|
if (update_golden) {
|
||||||
TF_EXPECT_OK(
|
TF_EXPECT_OK(
|
||||||
@ -156,6 +167,11 @@ static void CompareWithGoldenFile(
|
|||||||
string golden_file_contents;
|
string golden_file_contents;
|
||||||
TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
|
TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
|
||||||
&golden_file_contents));
|
&golden_file_contents));
|
||||||
|
if (ignore_cr) {
|
||||||
|
golden_file_contents.erase(std::remove(golden_file_contents.begin(),
|
||||||
|
golden_file_contents.end(), '\r'),
|
||||||
|
golden_file_contents.end());
|
||||||
|
}
|
||||||
EXPECT_EQ(golden_file_contents, expected_contents);
|
EXPECT_EQ(golden_file_contents, expected_contents);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -197,15 +213,20 @@ TEST(CodegenTest, Golden) {
|
|||||||
variable3->mutable_shape()->add_dim()->set_size(5);
|
variable3->mutable_shape()->add_dim()->set_size(5);
|
||||||
variable3->set_type(DT_INT32);
|
variable3->set_type(DT_INT32);
|
||||||
CompileResult compile_result;
|
CompileResult compile_result;
|
||||||
compile_result.tensorflow_header_root = "third_party/tensorflow";
|
|
||||||
compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
|
compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
|
||||||
{},
|
{},
|
||||||
{BufferInfo::MakeTempBuffer(1),
|
{BufferInfo::MakeTempBuffer(1),
|
||||||
BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
|
BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
|
||||||
BufferInfo::MakeTempBuffer(2),
|
BufferInfo::MakeTempBuffer(1),
|
||||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
|
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
|
||||||
BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
|
BufferInfo::MakeTempBuffer(1),
|
||||||
5, {}));
|
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/2),
|
||||||
|
BufferInfo::MakeTempBuffer(1),
|
||||||
|
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/3),
|
||||||
|
BufferInfo::MakeTempBuffer(1),
|
||||||
|
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/4),
|
||||||
|
BufferInfo::MakeTempBuffer(1), BufferInfo::MakeTempBuffer(120)},
|
||||||
|
11, {}));
|
||||||
compile_result.program_shape =
|
compile_result.program_shape =
|
||||||
xla::ShapeUtil::MakeProgramShape(
|
xla::ShapeUtil::MakeProgramShape(
|
||||||
{
|
{
|
||||||
@ -230,14 +251,18 @@ TEST(CodegenTest, Golden) {
|
|||||||
// The other fields in metadata_result are tested as part of the generated
|
// The other fields in metadata_result are tested as part of the generated
|
||||||
// header test.
|
// header test.
|
||||||
|
|
||||||
CompareWithGoldenFile("compiler/aot/codegen_test_o.golden",
|
// This specific golden test checks a binary file. It can potentially run into
|
||||||
metadata_result.object_file_data);
|
// issues due to ABIs not being stable, but has not so far.
|
||||||
|
// If we see any ABI issues, we should reconsider this specific test case.
|
||||||
|
CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_o.golden",
|
||||||
|
metadata_result.object_file_data, false);
|
||||||
|
|
||||||
string header;
|
string header;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
GenerateHeader(opts, config, compile_result, metadata_result, &header));
|
GenerateHeader(opts, config, compile_result, metadata_result, &header));
|
||||||
|
|
||||||
CompareWithGoldenFile("compiler/aot/codegen_test_h.golden", header);
|
CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_h.golden", header,
|
||||||
|
true);
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tfcompile
|
} // namespace tfcompile
|
||||||
|
@ -55,14 +55,17 @@ namespace bar {
|
|||||||
// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5])
|
// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5])
|
||||||
//
|
//
|
||||||
// Memory stats:
|
// Memory stats:
|
||||||
// arg bytes total: 104
|
// arg bytes total: 392
|
||||||
// arg bytes aligned: 192
|
// arg bytes aligned: 576
|
||||||
// temp bytes total: 126
|
// temp bytes total: 126
|
||||||
// temp bytes aligned: 320
|
// temp bytes aligned: 512
|
||||||
class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||||
public:
|
public:
|
||||||
// Number of input arguments for the compiled computation.
|
// Number of input arguments for the compiled computation.
|
||||||
static constexpr size_t kNumArgs = 2;
|
static constexpr size_t kNumArgs = 5;
|
||||||
|
|
||||||
|
// Number of variables for the compiled computation.
|
||||||
|
static constexpr size_t kNumVariables = 3;
|
||||||
|
|
||||||
// Byte size of each argument buffer. There are kNumArgs entries.
|
// Byte size of each argument buffer. There are kNumArgs entries.
|
||||||
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
|
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
|
||||||
@ -79,8 +82,10 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
set_static_data_num_buffers(data, kNumBuffers);
|
set_static_data_num_buffers(data, kNumBuffers);
|
||||||
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
|
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
|
||||||
set_static_data_num_args(data, kNumArgs);
|
set_static_data_num_args(data, kNumArgs);
|
||||||
|
set_static_data_num_variables(data, kNumVariables);
|
||||||
set_static_data_result_index(data, kResultIndex);
|
set_static_data_result_index(data, kResultIndex);
|
||||||
set_static_data_arg_names(data, StaticArgNames());
|
set_static_data_arg_names(data, StaticArgNames());
|
||||||
|
set_static_data_variable_names(data, StaticVariableNames());
|
||||||
set_static_data_result_names(data, StaticResultNames());
|
set_static_data_result_names(data, StaticResultNames());
|
||||||
set_static_data_program_shape(data, StaticProgramShape());
|
set_static_data_program_shape(data, StaticProgramShape());
|
||||||
set_static_data_hlo_profile_printer_data(
|
set_static_data_hlo_profile_printer_data(
|
||||||
@ -295,16 +300,22 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
// Number of buffers for the compiled computation.
|
// Number of buffers for the compiled computation.
|
||||||
static constexpr size_t kNumBuffers = 6;
|
static constexpr size_t kNumBuffers = 12;
|
||||||
|
|
||||||
static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() {
|
static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() {
|
||||||
static const ::xla::cpu_function_runtime::BufferInfo
|
static const ::xla::cpu_function_runtime::BufferInfo
|
||||||
kBufferInfos[kNumBuffers] = {
|
kBufferInfos[kNumBuffers] = {
|
||||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||||
::xla::cpu_function_runtime::BufferInfo({34ULL, 0ULL}),
|
::xla::cpu_function_runtime::BufferInfo({34ULL, 0ULL}),
|
||||||
::xla::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}),
|
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||||
::xla::cpu_function_runtime::BufferInfo({386ULL, 1ULL}),
|
::xla::cpu_function_runtime::BufferInfo({386ULL, 1ULL}),
|
||||||
::xla::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}),
|
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||||
|
::xla::cpu_function_runtime::BufferInfo({386ULL, 2ULL}),
|
||||||
|
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||||
|
::xla::cpu_function_runtime::BufferInfo({386ULL, 3ULL}),
|
||||||
|
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||||
|
::xla::cpu_function_runtime::BufferInfo({386ULL, 4ULL}),
|
||||||
|
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||||
::xla::cpu_function_runtime::BufferInfo({481ULL, ~0ULL})
|
::xla::cpu_function_runtime::BufferInfo({481ULL, ~0ULL})
|
||||||
};
|
};
|
||||||
return kBufferInfos;
|
return kBufferInfos;
|
||||||
@ -312,13 +323,13 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
|
|
||||||
static const ::tensorflow::int32* ArgIndexToBufferIndex() {
|
static const ::tensorflow::int32* ArgIndexToBufferIndex() {
|
||||||
static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
|
static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
|
||||||
1, 3
|
1, 3, 5, 7, 9
|
||||||
};
|
};
|
||||||
return kArgIndexToBufferIndex;
|
return kArgIndexToBufferIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
// The 0-based index of the result tuple in the temporary buffers.
|
// The 0-based index of the result tuple in the temporary buffers.
|
||||||
static constexpr size_t kResultIndex = 5;
|
static constexpr size_t kResultIndex = 11;
|
||||||
|
|
||||||
// Array of names of each positional argument, terminated by nullptr.
|
// Array of names of each positional argument, terminated by nullptr.
|
||||||
static const char** StaticArgNames() {
|
static const char** StaticArgNames() {
|
||||||
@ -326,6 +337,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
return kNames;
|
return kNames;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Array of names of each positional variable, terminated by nullptr.
|
||||||
|
static const char** StaticVariableNames() {
|
||||||
|
static const char* kNames[] = {"myvar_readonly", "myvar", "myvar2", nullptr};
|
||||||
|
return kNames;
|
||||||
|
}
|
||||||
|
|
||||||
// Array of names of each positional result, terminated by nullptr.
|
// Array of names of each positional result, terminated by nullptr.
|
||||||
static const char** StaticResultNames() {
|
static const char** StaticResultNames() {
|
||||||
static const char* kNames[] = {"myfetch", nullptr};
|
static const char* kNames[] = {"myfetch", nullptr};
|
||||||
|
@ -20,9 +20,11 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/base/call_once.h"
|
||||||
#include "llvm-c/Target.h"
|
#include "llvm-c/Target.h"
|
||||||
#include "tensorflow/compiler/aot/codegen.h"
|
#include "tensorflow/compiler/aot/codegen.h"
|
||||||
#include "tensorflow/compiler/aot/flags.h"
|
#include "tensorflow/compiler/aot/flags.h"
|
||||||
|
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
|
||||||
#include "tensorflow/compiler/tf2xla/tf2xla.h"
|
#include "tensorflow/compiler/tf2xla/tf2xla.h"
|
||||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||||
@ -38,6 +40,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/strings/proto_serialization.h"
|
#include "tensorflow/core/lib/strings/proto_serialization.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/regexp.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -85,7 +88,6 @@ Status CompileXla(xla::CompileOnlyClient* client,
|
|||||||
xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
|
xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
|
||||||
std::move(aot_or.ValueOrDie().back()));
|
std::move(aot_or.ValueOrDie().back()));
|
||||||
compile_result->entry_point = aot_opts.entry_point_name();
|
compile_result->entry_point = aot_opts.entry_point_name();
|
||||||
compile_result->tensorflow_header_root = aot_opts.tensorflow_header_root();
|
|
||||||
compile_result->pointer_size =
|
compile_result->pointer_size =
|
||||||
xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
|
xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -105,14 +107,17 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
|
|||||||
.ValueOrDie();
|
.ValueOrDie();
|
||||||
xla::XlaComputation computation;
|
xla::XlaComputation computation;
|
||||||
if (flags.mlir_components == "Bridge") {
|
if (flags.mlir_components == "Bridge") {
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(ConvertGraphDefToXlaViaMlir(
|
||||||
ConvertGraphDefToXlaViaMlir(graph_def, config, &computation));
|
graph_def, config, &computation, flags.debug_info,
|
||||||
} else {
|
flags.debug_info_path_begin_marker));
|
||||||
if (!flags.mlir_components.empty()) {
|
} else if (flags.mlir_components.empty() || flags.mlir_components == "None") {
|
||||||
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
|
|
||||||
}
|
|
||||||
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
|
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
|
||||||
client, &computation));
|
client, &computation));
|
||||||
|
} else {
|
||||||
|
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
|
||||||
|
}
|
||||||
|
if (flags.quantize) {
|
||||||
|
TF_RETURN_IF_ERROR(mlir::xla_hlo::XlaQuantize(config, &computation));
|
||||||
}
|
}
|
||||||
if (!flags.out_session_module.empty()) {
|
if (!flags.out_session_module.empty()) {
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
|
||||||
@ -130,8 +135,7 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
|
|||||||
xla::cpu::CpuAotCompilationOptions aot_opts(
|
xla::cpu::CpuAotCompilationOptions aot_opts(
|
||||||
flags.target_triple, flags.target_cpu, flags.target_features,
|
flags.target_triple, flags.target_cpu, flags.target_features,
|
||||||
flags.entry_point,
|
flags.entry_point,
|
||||||
xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic,
|
xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);
|
||||||
flags.tensorflow_header_root);
|
|
||||||
|
|
||||||
return CompileXla(client, computation, aot_opts, compile_result);
|
return CompileXla(client, computation, aot_opts, compile_result);
|
||||||
}
|
}
|
||||||
@ -144,7 +148,7 @@ static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::once_flag targets_init;
|
static absl::once_flag targets_init;
|
||||||
|
|
||||||
static void InitializeTargets() {
|
static void InitializeTargets() {
|
||||||
// Initialize all LLVM targets so we can cross compile.
|
// Initialize all LLVM targets so we can cross compile.
|
||||||
@ -168,8 +172,25 @@ static void InitializeTargets() {
|
|||||||
LLVMInitializeX86AsmPrinter();
|
LLVMInitializeX86AsmPrinter();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Replaces {{tag.type tag.name}} in the error message with tag_name.
|
||||||
|
// TODO(bixia): We currently only handlge tag.type == "node".
|
||||||
|
//
|
||||||
|
// In the error message, a graph node is represented as {{tag.type, tag.name}},
|
||||||
|
// to allow a Python debugger to insert source information about the graph node.
|
||||||
|
// For example, a Python add expression may be represented as
|
||||||
|
// {{node, x_y_sum}} = Add(x, y) in the error message. See routine interpolate
|
||||||
|
// in tensorflow/python/framework/error_interpolation.py for more detail.
|
||||||
|
static std::string InterpolateErrorMessage(std::string message) {
|
||||||
|
// See _NAME_REGEX in tensorflow/python/framework/error_interpolation.py
|
||||||
|
// Change "prefix {{node tag.name}} suffix" to "prefix tag.name suffix".
|
||||||
|
static LazyRE2 pattern{"(.*){{node (.*)}}(.*)"};
|
||||||
|
RE2::GlobalReplace(&message, *pattern, "\\1\\2\\3");
|
||||||
|
|
||||||
|
return message;
|
||||||
|
}
|
||||||
|
|
||||||
Status Main(const MainFlags& flags) {
|
Status Main(const MainFlags& flags) {
|
||||||
std::call_once(targets_init, &InitializeTargets);
|
absl::call_once(targets_init, &InitializeTargets);
|
||||||
|
|
||||||
// Process config.
|
// Process config.
|
||||||
tf2xla::Config config;
|
tf2xla::Config config;
|
||||||
@ -194,8 +215,13 @@ Status Main(const MainFlags& flags) {
|
|||||||
GraphDef graph_def;
|
GraphDef graph_def;
|
||||||
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
|
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
|
||||||
CompileResult compile_result;
|
CompileResult compile_result;
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
CompileGraph(std::move(graph_def), config, flags, &compile_result));
|
Status status =
|
||||||
|
CompileGraph(std::move(graph_def), config, flags, &compile_result);
|
||||||
|
if (!status.ok()) {
|
||||||
|
return Status(status.code(),
|
||||||
|
InterpolateErrorMessage(status.error_message()));
|
||||||
|
}
|
||||||
|
|
||||||
// Write output files.
|
// Write output files.
|
||||||
Env* env = Env::Default();
|
Env* env = Env::Default();
|
||||||
|
@ -35,7 +35,6 @@ struct CompileResult {
|
|||||||
std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
|
std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
|
||||||
xla::ProgramShapeProto program_shape; // Static shape of args and results.
|
xla::ProgramShapeProto program_shape; // Static shape of args and results.
|
||||||
string entry_point; // Name of generated function.
|
string entry_point; // Name of generated function.
|
||||||
string tensorflow_header_root; // Prefix for tensorflow headers.
|
|
||||||
int pointer_size = 0; // Size of a pointer in bytes.
|
int pointer_size = 0; // Size of a pointer in bytes.
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -24,6 +24,13 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
|
|||||||
"Input GraphDef file. If the file ends in '.pbtxt' it is expected to "
|
"Input GraphDef file. If the file ends in '.pbtxt' it is expected to "
|
||||||
"be in the human-readable proto text format, otherwise it is expected "
|
"be in the human-readable proto text format, otherwise it is expected "
|
||||||
"to be in the proto binary format."},
|
"to be in the proto binary format."},
|
||||||
|
{"debug_info", &flags->debug_info,
|
||||||
|
"Graph debug info file. If the file ends in '.pbtxt' it is expected to "
|
||||||
|
"be in the human-readable proto text format, otherwise it is expected "
|
||||||
|
"to be in the proto binary format."},
|
||||||
|
{"debug_info_path_begin_marker", &flags->debug_info_path_begin_marker,
|
||||||
|
"If not none, only keep the file path in the debug information after the"
|
||||||
|
" marker. The default value is empty"},
|
||||||
{"config", &flags->config,
|
{"config", &flags->config,
|
||||||
"Input file containing Config proto. If the file ends in '.pbtxt' it "
|
"Input file containing Config proto. If the file ends in '.pbtxt' it "
|
||||||
"is expected to be in the human-readable proto text format, otherwise "
|
"is expected to be in the human-readable proto text format, otherwise "
|
||||||
@ -70,12 +77,12 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
|
|||||||
"Output session module proto."},
|
"Output session module proto."},
|
||||||
{"mlir_components", &flags->mlir_components,
|
{"mlir_components", &flags->mlir_components,
|
||||||
"The MLIR components to enable. Currently only Bridge is supported."},
|
"The MLIR components to enable. Currently only Bridge is supported."},
|
||||||
|
{"quantize", &flags->quantize,
|
||||||
|
"If set, quantization will be applied before HLO code generation."},
|
||||||
{"gen_name_to_index", &flags->gen_name_to_index,
|
{"gen_name_to_index", &flags->gen_name_to_index,
|
||||||
"Generate name-to-index data for Lookup{Arg,Result}Index methods."},
|
"Generate name-to-index data for Lookup{Arg,Result}Index methods."},
|
||||||
{"gen_program_shape", &flags->gen_program_shape,
|
{"gen_program_shape", &flags->gen_program_shape,
|
||||||
"Generate program shape data for the ProgramShape method."},
|
"Generate program shape data for the ProgramShape method."},
|
||||||
{"tensorflow_header_root", &flags->tensorflow_header_root,
|
|
||||||
"Root directory of tensorflow headers."},
|
|
||||||
};
|
};
|
||||||
flag_list->insert(flag_list->end(), tmp.begin(), tmp.end());
|
flag_list->insert(flag_list->end(), tmp.begin(), tmp.end());
|
||||||
}
|
}
|
||||||
|
@ -28,6 +28,8 @@ namespace tfcompile {
|
|||||||
|
|
||||||
struct MainFlags {
|
struct MainFlags {
|
||||||
string graph;
|
string graph;
|
||||||
|
string debug_info;
|
||||||
|
string debug_info_path_begin_marker;
|
||||||
string config;
|
string config;
|
||||||
bool dump_fetch_nodes = false;
|
bool dump_fetch_nodes = false;
|
||||||
string target_triple;
|
string target_triple;
|
||||||
@ -40,7 +42,7 @@ struct MainFlags {
|
|||||||
string out_header;
|
string out_header;
|
||||||
string out_session_module;
|
string out_session_module;
|
||||||
string mlir_components;
|
string mlir_components;
|
||||||
string tensorflow_header_root;
|
bool quantize = false;
|
||||||
|
|
||||||
// C++ codegen options
|
// C++ codegen options
|
||||||
bool gen_name_to_index = false;
|
bool gen_name_to_index = false;
|
||||||
|
@ -1,11 +1,37 @@
|
|||||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||||
|
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = ["//visibility:private"],
|
default_visibility = ["//visibility:private"],
|
||||||
licenses = ["notice"], # Apache 2.0
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
glob_lit_tests(
|
||||||
|
data = [":filecheck_test_utilities"],
|
||||||
|
driver = "@llvm-project//mlir:run_lit.sh",
|
||||||
|
tags_override = {
|
||||||
|
"test_error_message.lit.pbtxt": ["no_oss"], # TODO(b/150957738): to be fixed on oss.
|
||||||
|
},
|
||||||
|
test_file_exts = ["lit.pbtxt"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Bundle together all of the test utilities that are used by tests.
|
||||||
|
filegroup(
|
||||||
|
name = "filecheck_test_utilities",
|
||||||
|
testonly = True,
|
||||||
|
srcs = [
|
||||||
|
"test_error_message.lit.pbtxt.config.pbtxt",
|
||||||
|
"test_error_message.lit.pbtxt.debug.pbtxt",
|
||||||
|
"test_error_message.lit.pbtxt.fake_py.debug",
|
||||||
|
],
|
||||||
|
data = [
|
||||||
|
"//tensorflow/compiler/aot:tfcompile",
|
||||||
|
"@llvm-project//llvm:FileCheck",
|
||||||
|
"@llvm-project//llvm:not",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# We disable some tfcompile tests in the open source build with the
|
# We disable some tfcompile tests in the open source build with the
|
||||||
# "manual" tag to avoid making our OSS users build LLVM twice
|
# "manual" tag to avoid making our OSS users build LLVM twice
|
||||||
# (once for host and once for target).
|
# (once for host and once for target).
|
||||||
@ -25,6 +51,7 @@ test_suite(
|
|||||||
":test_graph_tfmatmulandadd_test",
|
":test_graph_tfmatmulandadd_test",
|
||||||
":test_graph_tfsplits_test",
|
":test_graph_tfsplits_test",
|
||||||
":test_graph_tftop_k_test",
|
":test_graph_tftop_k_test",
|
||||||
|
":test_graph_tfvariable_readonly_test",
|
||||||
":test_graph_tfvariable_sequential_updates_test",
|
":test_graph_tfvariable_sequential_updates_test",
|
||||||
":test_graph_tfvariable_test",
|
":test_graph_tfvariable_test",
|
||||||
":tfcompile_test",
|
":tfcompile_test",
|
||||||
@ -59,6 +86,7 @@ genrule(
|
|||||||
testonly = 1,
|
testonly = 1,
|
||||||
outs = [
|
outs = [
|
||||||
"test_graph_tfadd.pb",
|
"test_graph_tfadd.pb",
|
||||||
|
"test_debuginfo_tfadd.pb",
|
||||||
"test_graph_tfadd_with_ckpt.ckpt",
|
"test_graph_tfadd_with_ckpt.ckpt",
|
||||||
"test_graph_tfadd_with_ckpt.pb",
|
"test_graph_tfadd_with_ckpt.pb",
|
||||||
"test_graph_tfadd_with_ckpt_saver.ckpt",
|
"test_graph_tfadd_with_ckpt_saver.ckpt",
|
||||||
@ -73,6 +101,7 @@ genrule(
|
|||||||
"test_graph_tfsplits.pb",
|
"test_graph_tfsplits.pb",
|
||||||
"test_graph_tftop_k.pb",
|
"test_graph_tftop_k.pb",
|
||||||
"test_graph_tfvariable.pb",
|
"test_graph_tfvariable.pb",
|
||||||
|
"test_graph_tfvariable_readonly.pb",
|
||||||
"test_graph_tfvariable_sequential_updates.pb",
|
"test_graph_tfvariable_sequential_updates.pb",
|
||||||
],
|
],
|
||||||
# Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any
|
# Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any
|
||||||
@ -96,6 +125,7 @@ tf_library(
|
|||||||
# compile but the others in this directory succeed, you may need to
|
# compile but the others in this directory succeed, you may need to
|
||||||
# expand the "required by all tf_library targets" list in tfcompile.bzl.
|
# expand the "required by all tf_library targets" list in tfcompile.bzl.
|
||||||
include_standard_runtime_deps = False,
|
include_standard_runtime_deps = False,
|
||||||
|
mlir_components = "None",
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
],
|
],
|
||||||
@ -108,6 +138,7 @@ tf_library(
|
|||||||
cpp_class = "AddWithCkptComp",
|
cpp_class = "AddWithCkptComp",
|
||||||
freeze_checkpoint = "test_graph_tfadd_with_ckpt.ckpt",
|
freeze_checkpoint = "test_graph_tfadd_with_ckpt.ckpt",
|
||||||
graph = "test_graph_tfadd_with_ckpt.pb",
|
graph = "test_graph_tfadd_with_ckpt.pb",
|
||||||
|
mlir_components = "None",
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
],
|
],
|
||||||
@ -121,6 +152,7 @@ tf_library(
|
|||||||
freeze_checkpoint = "test_graph_tfadd_with_ckpt_saver.ckpt",
|
freeze_checkpoint = "test_graph_tfadd_with_ckpt_saver.ckpt",
|
||||||
freeze_saver = "test_graph_tfadd_with_ckpt_saver.saver",
|
freeze_saver = "test_graph_tfadd_with_ckpt_saver.saver",
|
||||||
graph = "test_graph_tfadd_with_ckpt_saver.pb",
|
graph = "test_graph_tfadd_with_ckpt_saver.pb",
|
||||||
|
mlir_components = "None",
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
],
|
],
|
||||||
@ -132,6 +164,7 @@ tf_library(
|
|||||||
config = "test_graph_tfassert_eq.config.pbtxt",
|
config = "test_graph_tfassert_eq.config.pbtxt",
|
||||||
cpp_class = "AssertComp",
|
cpp_class = "AssertComp",
|
||||||
graph = "test_graph_tfassert_eq.pb",
|
graph = "test_graph_tfassert_eq.pb",
|
||||||
|
mlir_components = "None",
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
],
|
],
|
||||||
@ -143,6 +176,7 @@ tf_library(
|
|||||||
config = "test_graph_tfcond.config.pbtxt",
|
config = "test_graph_tfcond.config.pbtxt",
|
||||||
cpp_class = "CondComp",
|
cpp_class = "CondComp",
|
||||||
graph = "test_graph_tfcond.pb",
|
graph = "test_graph_tfcond.pb",
|
||||||
|
mlir_components = "None",
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
],
|
],
|
||||||
@ -154,6 +188,7 @@ tf_library(
|
|||||||
config = "test_graph_tffunction.config.pbtxt",
|
config = "test_graph_tffunction.config.pbtxt",
|
||||||
cpp_class = "FunctionComp",
|
cpp_class = "FunctionComp",
|
||||||
graph = "test_graph_tffunction.pb",
|
graph = "test_graph_tffunction.pb",
|
||||||
|
mlir_components = "None",
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
],
|
],
|
||||||
@ -165,6 +200,7 @@ tf_library(
|
|||||||
config = "test_graph_tfgather.config.pbtxt",
|
config = "test_graph_tfgather.config.pbtxt",
|
||||||
cpp_class = "GatherComp",
|
cpp_class = "GatherComp",
|
||||||
graph = "test_graph_tfgather.pb",
|
graph = "test_graph_tfgather.pb",
|
||||||
|
mlir_components = "None",
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
],
|
],
|
||||||
@ -176,6 +212,7 @@ tf_library(
|
|||||||
config = "test_graph_tfmatmul.config.pbtxt",
|
config = "test_graph_tfmatmul.config.pbtxt",
|
||||||
cpp_class = "foo::bar::MatMulComp",
|
cpp_class = "foo::bar::MatMulComp",
|
||||||
graph = "test_graph_tfmatmul.pb",
|
graph = "test_graph_tfmatmul.pb",
|
||||||
|
mlir_components = "None",
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
],
|
],
|
||||||
@ -187,6 +224,7 @@ tf_library(
|
|||||||
config = "test_graph_tfmatmulandadd.config.pbtxt",
|
config = "test_graph_tfmatmulandadd.config.pbtxt",
|
||||||
cpp_class = "::foo::bar::MatMulAndAddComp",
|
cpp_class = "::foo::bar::MatMulAndAddComp",
|
||||||
graph = "test_graph_tfmatmulandadd.pb",
|
graph = "test_graph_tfmatmulandadd.pb",
|
||||||
|
mlir_components = "None",
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
],
|
],
|
||||||
@ -200,6 +238,7 @@ tf_library(
|
|||||||
cpp_class = "MatMulAndAddCompWithProfiling",
|
cpp_class = "MatMulAndAddCompWithProfiling",
|
||||||
enable_xla_hlo_profiling = True,
|
enable_xla_hlo_profiling = True,
|
||||||
graph = "test_graph_tfmatmulandadd.pb",
|
graph = "test_graph_tfmatmulandadd.pb",
|
||||||
|
mlir_components = "None",
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
],
|
],
|
||||||
@ -211,6 +250,7 @@ tf_library(
|
|||||||
config = "test_graph_tfsplits.config.pbtxt",
|
config = "test_graph_tfsplits.config.pbtxt",
|
||||||
cpp_class = "SplitsComp",
|
cpp_class = "SplitsComp",
|
||||||
graph = "test_graph_tfsplits.pb",
|
graph = "test_graph_tfsplits.pb",
|
||||||
|
mlir_components = "None",
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
],
|
],
|
||||||
@ -222,6 +262,7 @@ tf_library(
|
|||||||
config = "test_graph_tftop_k.config.pbtxt",
|
config = "test_graph_tftop_k.config.pbtxt",
|
||||||
cpp_class = "TopKComp",
|
cpp_class = "TopKComp",
|
||||||
graph = "test_graph_tftop_k.pb",
|
graph = "test_graph_tftop_k.pb",
|
||||||
|
mlir_components = "None",
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
],
|
],
|
||||||
@ -233,6 +274,19 @@ tf_library(
|
|||||||
config = "test_graph_tfvariable.config.pbtxt",
|
config = "test_graph_tfvariable.config.pbtxt",
|
||||||
cpp_class = "VariableComp",
|
cpp_class = "VariableComp",
|
||||||
graph = "test_graph_tfvariable.pb",
|
graph = "test_graph_tfvariable.pb",
|
||||||
|
mlir_components = "None",
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_library(
|
||||||
|
name = "test_graph_tfvariable_readonly",
|
||||||
|
testonly = 1,
|
||||||
|
config = "test_graph_tfvariable_readonly.config.pbtxt",
|
||||||
|
cpp_class = "VariableReadonlyComp",
|
||||||
|
graph = "test_graph_tfvariable_readonly.pb",
|
||||||
|
mlir_components = "None",
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
],
|
],
|
||||||
@ -244,6 +298,7 @@ tf_library(
|
|||||||
config = "test_graph_tfvariable_sequential_updates.config.pbtxt",
|
config = "test_graph_tfvariable_sequential_updates.config.pbtxt",
|
||||||
cpp_class = "VariableSequentialUpdatesComp",
|
cpp_class = "VariableSequentialUpdatesComp",
|
||||||
graph = "test_graph_tfvariable_sequential_updates.pb",
|
graph = "test_graph_tfvariable_sequential_updates.pb",
|
||||||
|
mlir_components = "None",
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
],
|
],
|
||||||
@ -269,6 +324,7 @@ tf_cc_test(
|
|||||||
":test_graph_tfsplits",
|
":test_graph_tfsplits",
|
||||||
":test_graph_tftop_k",
|
":test_graph_tftop_k",
|
||||||
":test_graph_tfvariable",
|
":test_graph_tfvariable",
|
||||||
|
":test_graph_tfvariable_readonly",
|
||||||
":test_graph_tfvariable_sequential_updates",
|
":test_graph_tfvariable_sequential_updates",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:test",
|
"//tensorflow/compiler/xla:test",
|
||||||
@ -288,6 +344,7 @@ tf_library(
|
|||||||
testonly = 1,
|
testonly = 1,
|
||||||
config = "test_graph_tfadd.config.pbtxt",
|
config = "test_graph_tfadd.config.pbtxt",
|
||||||
cpp_class = "AddComp",
|
cpp_class = "AddComp",
|
||||||
|
debug_info = "test_debuginfo_tfadd.pb",
|
||||||
graph = "test_graph_tfadd.pb",
|
graph = "test_graph_tfadd.pb",
|
||||||
include_standard_runtime_deps = False,
|
include_standard_runtime_deps = False,
|
||||||
mlir_components = "Bridge",
|
mlir_components = "Bridge",
|
||||||
@ -335,6 +392,18 @@ tf_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_library(
|
||||||
|
name = "test_graph_tffunction_mlir_bridge",
|
||||||
|
testonly = 1,
|
||||||
|
config = "test_graph_tffunction.config.pbtxt",
|
||||||
|
cpp_class = "FunctionComp",
|
||||||
|
graph = "test_graph_tffunction.pb",
|
||||||
|
mlir_components = "Bridge",
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_library(
|
tf_library(
|
||||||
name = "test_graph_tfassert_eq_mlir_bridge",
|
name = "test_graph_tfassert_eq_mlir_bridge",
|
||||||
testonly = 1,
|
testonly = 1,
|
||||||
@ -421,6 +490,42 @@ tf_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_library(
|
||||||
|
name = "test_graph_tfvariable_readonly_mlir_bridge",
|
||||||
|
testonly = 1,
|
||||||
|
config = "test_graph_tfvariable_readonly.config.pbtxt",
|
||||||
|
cpp_class = "VariableReadonlyComp",
|
||||||
|
graph = "test_graph_tfvariable_readonly.pb",
|
||||||
|
mlir_components = "Bridge",
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_library(
|
||||||
|
name = "test_graph_tfvariable_mlir_bridge",
|
||||||
|
testonly = 1,
|
||||||
|
config = "test_graph_tfvariable.config.pbtxt",
|
||||||
|
cpp_class = "VariableComp",
|
||||||
|
graph = "test_graph_tfvariable.pb",
|
||||||
|
mlir_components = "Bridge",
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_library(
|
||||||
|
name = "test_graph_tfvariable_sequential_updates_mlir_bridge",
|
||||||
|
testonly = 1,
|
||||||
|
config = "test_graph_tfvariable_sequential_updates.config.pbtxt",
|
||||||
|
cpp_class = "VariableSequentialUpdatesComp",
|
||||||
|
graph = "test_graph_tfvariable_sequential_updates.pb",
|
||||||
|
mlir_components = "Bridge",
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "tfcompile_test_mlir_bridge",
|
name = "tfcompile_test_mlir_bridge",
|
||||||
srcs = ["tfcompile_test.cc"],
|
srcs = ["tfcompile_test.cc"],
|
||||||
@ -434,12 +539,16 @@ tf_cc_test(
|
|||||||
":test_graph_tfadd_with_ckpt_saver_mlir_bridge",
|
":test_graph_tfadd_with_ckpt_saver_mlir_bridge",
|
||||||
":test_graph_tfassert_eq_mlir_bridge",
|
":test_graph_tfassert_eq_mlir_bridge",
|
||||||
":test_graph_tfcond_mlir_bridge",
|
":test_graph_tfcond_mlir_bridge",
|
||||||
|
":test_graph_tffunction_mlir_bridge",
|
||||||
":test_graph_tfgather_mlir_bridge",
|
":test_graph_tfgather_mlir_bridge",
|
||||||
":test_graph_tfmatmul_mlir_bridge",
|
":test_graph_tfmatmul_mlir_bridge",
|
||||||
":test_graph_tfmatmulandadd_mlir_bridge",
|
":test_graph_tfmatmulandadd_mlir_bridge",
|
||||||
":test_graph_tfmatmulandadd_with_profiling_mlir_bridge",
|
":test_graph_tfmatmulandadd_with_profiling_mlir_bridge",
|
||||||
":test_graph_tfsplits_mlir_bridge",
|
":test_graph_tfsplits_mlir_bridge",
|
||||||
":test_graph_tftop_k_mlir_bridge",
|
":test_graph_tftop_k_mlir_bridge",
|
||||||
|
":test_graph_tfvariable_mlir_bridge",
|
||||||
|
":test_graph_tfvariable_readonly_mlir_bridge",
|
||||||
|
":test_graph_tfvariable_sequential_updates_mlir_bridge",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:test",
|
"//tensorflow/compiler/xla:test",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
|
@ -30,6 +30,7 @@ from tensorflow.core.protobuf import saver_pb2
|
|||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import error_interpolation
|
||||||
from tensorflow.python.framework import function
|
from tensorflow.python.framework import function
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -154,11 +155,22 @@ def tftop_k(_):
|
|||||||
array_ops.identity(output[1], name='indices')
|
array_ops.identity(output[1], name='indices')
|
||||||
|
|
||||||
|
|
||||||
def tfvariable(_):
|
def tfvariable_readonly(_):
|
||||||
x = variables.Variable(1000.0, name='x')
|
x = variables.Variable(1000.0, name='x')
|
||||||
|
unused_y = variables.Variable(1000.0, name='y')
|
||||||
old_x = x.value()
|
old_x = x.value()
|
||||||
with ops.control_dependencies([old_x]):
|
with ops.control_dependencies([old_x]):
|
||||||
new_x = x.assign_add(42.0)
|
new_value = math_ops.add(old_x, 42.0)
|
||||||
|
array_ops.identity(new_value, name='result')
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(b/147908587): Change x and the two constants back to have a scalar shape
|
||||||
|
# when the bug is fixed.
|
||||||
|
def tfvariable(_):
|
||||||
|
x = variables.Variable([1000.0], name='x', shape=[1])
|
||||||
|
old_x = x.value()
|
||||||
|
with ops.control_dependencies([old_x]):
|
||||||
|
new_x = x.assign_add([42.0])
|
||||||
array_ops.stack([old_x, new_x], name='result')
|
array_ops.stack([old_x, new_x], name='result')
|
||||||
|
|
||||||
|
|
||||||
@ -174,7 +186,22 @@ def tfvariable_sequential_updates(_):
|
|||||||
array_ops.identity(updates, name='result')
|
array_ops.identity(updates, name='result')
|
||||||
|
|
||||||
|
|
||||||
def write_graph(build_graph, out_dir):
|
def export_debug_info(exported_graph):
|
||||||
|
"""Exports debug information from a graph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exported_graph: A Graph that has been created by tracing a saveable view.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Corresponding GraphDebugInfo with traces for all ops in exported_graph.
|
||||||
|
"""
|
||||||
|
exported_operations = []
|
||||||
|
for op in exported_graph.get_operations():
|
||||||
|
exported_operations.append(('', op))
|
||||||
|
return error_interpolation.create_graph_debug_info_def(exported_operations)
|
||||||
|
|
||||||
|
|
||||||
|
def write_graph(build_graph, out_dir, debug_info=False):
|
||||||
"""Build a graph using build_graph and write it out."""
|
"""Build a graph using build_graph and write it out."""
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
with g.as_default():
|
with g.as_default():
|
||||||
@ -183,10 +210,19 @@ def write_graph(build_graph, out_dir):
|
|||||||
with open(filename, 'wb') as f:
|
with open(filename, 'wb') as f:
|
||||||
f.write(six.ensure_binary(g.as_graph_def().SerializeToString()))
|
f.write(six.ensure_binary(g.as_graph_def().SerializeToString()))
|
||||||
|
|
||||||
|
if debug_info:
|
||||||
|
filename_debuginfo = os.path.join(
|
||||||
|
out_dir, 'test_debuginfo_%s.pb' % build_graph.__name__)
|
||||||
|
test_debuginfo = export_debug_info(g)
|
||||||
|
with open(filename_debuginfo, 'wb') as f:
|
||||||
|
f.write(
|
||||||
|
six.ensure_binary(
|
||||||
|
test_debuginfo.SerializeToString(deterministic=True)))
|
||||||
|
|
||||||
|
|
||||||
def main(_):
|
def main(_):
|
||||||
control_flow_util.enable_control_flow_v2()
|
control_flow_util.enable_control_flow_v2()
|
||||||
write_graph(tfadd, FLAGS.out_dir)
|
write_graph(tfadd, FLAGS.out_dir, debug_info=True)
|
||||||
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
|
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
|
||||||
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
|
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
|
||||||
write_graph(tfassert_eq, FLAGS.out_dir)
|
write_graph(tfassert_eq, FLAGS.out_dir)
|
||||||
@ -198,6 +234,7 @@ def main(_):
|
|||||||
write_graph(tfsplits, FLAGS.out_dir)
|
write_graph(tfsplits, FLAGS.out_dir)
|
||||||
write_graph(tftop_k, FLAGS.out_dir)
|
write_graph(tftop_k, FLAGS.out_dir)
|
||||||
write_graph(tfvariable, FLAGS.out_dir)
|
write_graph(tfvariable, FLAGS.out_dir)
|
||||||
|
write_graph(tfvariable_readonly, FLAGS.out_dir)
|
||||||
write_graph(tfvariable_sequential_updates, FLAGS.out_dir)
|
write_graph(tfvariable_sequential_updates, FLAGS.out_dir)
|
||||||
|
|
||||||
|
|
||||||
|
69
tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt
Normal file
69
tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s -dump-input-on-failure
|
||||||
|
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s -dump-input-on-failure
|
||||||
|
|
||||||
|
# Checks the error message produced by tfcompile with mlir_component
|
||||||
|
# Checks that source debug information is used in the output error message and
|
||||||
|
# the node x_y_sum = Add
|
||||||
|
# CHECK: INVALID ARGUMENTS: Dimensions must be equal, but are 2 and 3 for 'x_y_sum = Add[T=DT_INT32](aot_feed_0/x, aot_feed_0/y)'
|
||||||
|
# CHECK: math_ops.add(x, y, name='x_y_sum')
|
||||||
|
# CHECK: build_graph(out_dir)
|
||||||
|
|
||||||
|
# Checks the error message produced by tfcompile without mlir_component
|
||||||
|
# OLD: INVALID ARGUMENTS: Incompatible shapes: [2] vs. [3]
|
||||||
|
# OLD: x_y_sum
|
||||||
|
|
||||||
|
node: {
|
||||||
|
name: "x"
|
||||||
|
op: "Placeholder"
|
||||||
|
attr: {
|
||||||
|
key: "shape"
|
||||||
|
value: {
|
||||||
|
shape: {
|
||||||
|
dim: {
|
||||||
|
size: -1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr: {
|
||||||
|
key: "dtype"
|
||||||
|
value: {
|
||||||
|
type: DT_INT32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node: {
|
||||||
|
name: "y"
|
||||||
|
op: "Placeholder"
|
||||||
|
attr: {
|
||||||
|
key: "shape"
|
||||||
|
value: {
|
||||||
|
shape: {
|
||||||
|
dim: {
|
||||||
|
size: -1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr: {
|
||||||
|
key: "dtype"
|
||||||
|
value: {
|
||||||
|
type: DT_INT32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node: {
|
||||||
|
name: "x_y_sum"
|
||||||
|
op: "Add"
|
||||||
|
input: "x"
|
||||||
|
input: "y"
|
||||||
|
attr: {
|
||||||
|
key: "T"
|
||||||
|
value: {
|
||||||
|
type: DT_INT32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
versions: {
|
||||||
|
producer: 321
|
||||||
|
}
|
@ -0,0 +1,16 @@
|
|||||||
|
# Text form of tensorflow.tf2xla.Config proto.
|
||||||
|
feed {
|
||||||
|
id { node_name: "x" }
|
||||||
|
shape {
|
||||||
|
dim { size: 2 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
feed {
|
||||||
|
id { node_name: "y" }
|
||||||
|
shape {
|
||||||
|
dim { size: 3 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fetch {
|
||||||
|
id { node_name: "x_y_sum" }
|
||||||
|
}
|
@ -0,0 +1,28 @@
|
|||||||
|
files: "org_tensorflow/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.fake_py.debug"
|
||||||
|
traces: {
|
||||||
|
key: "x@"
|
||||||
|
value: {
|
||||||
|
file_line_cols: {
|
||||||
|
line: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
traces: {
|
||||||
|
key: "x_y_sum@"
|
||||||
|
value: {
|
||||||
|
file_line_cols: {
|
||||||
|
line: 3
|
||||||
|
}
|
||||||
|
file_line_cols: {
|
||||||
|
line: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
traces: {
|
||||||
|
key: "y@"
|
||||||
|
value: {
|
||||||
|
file_line_cols: {
|
||||||
|
line: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
x = value
|
||||||
|
y = value
|
||||||
|
math_ops.add(x, y, name='x_y_sum')
|
||||||
|
build_graph(out_dir)
|
@ -0,0 +1,20 @@
|
|||||||
|
# Text form of tensorflow.tf2xla.Config proto.
|
||||||
|
fetch {
|
||||||
|
id { node_name: "result" }
|
||||||
|
}
|
||||||
|
|
||||||
|
variable {
|
||||||
|
node_name: "x"
|
||||||
|
shape {
|
||||||
|
}
|
||||||
|
type: DT_FLOAT
|
||||||
|
readonly: true
|
||||||
|
}
|
||||||
|
|
||||||
|
variable {
|
||||||
|
node_name: "y"
|
||||||
|
shape {
|
||||||
|
}
|
||||||
|
type: DT_FLOAT
|
||||||
|
readonly: true
|
||||||
|
}
|
@ -32,12 +32,16 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h"
|
||||||
|
#include "tensorflow/compiler/aot/tests/test_graph_tffunction_mlir_bridge.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling_mlir_bridge.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits_mlir_bridge.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k_mlir_bridge.h"
|
||||||
|
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_mlir_bridge.h"
|
||||||
|
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly_mlir_bridge.h"
|
||||||
|
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates_mlir_bridge.h"
|
||||||
#else
|
#else
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
|
||||||
@ -52,6 +56,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable.h"
|
||||||
|
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -425,8 +430,6 @@ TEST(TFCompileTest, MatMulAndAdd1) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(bixia): the following tests failed with MLIR bridge.
|
|
||||||
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
|
|
||||||
TEST(TFCompileTest, Function) {
|
TEST(TFCompileTest, Function) {
|
||||||
// The function is equivalent to an addition
|
// The function is equivalent to an addition
|
||||||
FunctionComp add_fn;
|
FunctionComp add_fn;
|
||||||
@ -441,7 +444,6 @@ TEST(TFCompileTest, Function) {
|
|||||||
EXPECT_EQ(add_fn.result0_data()[0], 3);
|
EXPECT_EQ(add_fn.result0_data()[0], 3);
|
||||||
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
|
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
TEST(TFCompileTest, Splits) {
|
TEST(TFCompileTest, Splits) {
|
||||||
Eigen::ThreadPool tp(1);
|
Eigen::ThreadPool tp(1);
|
||||||
@ -495,8 +497,20 @@ TEST(TFCompileTest, TopK) {
|
|||||||
EXPECT_EQ(expected_indices[1], fn.result1(1));
|
EXPECT_EQ(expected_indices[1], fn.result1(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(bixia): the following tests failed with MLIR bridge.
|
TEST(TFCompileTest, VariableReadonly) {
|
||||||
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
|
Eigen::ThreadPool tp(1);
|
||||||
|
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
||||||
|
|
||||||
|
VariableReadonlyComp fn;
|
||||||
|
float x = 23;
|
||||||
|
fn.set_var_x_data(&x);
|
||||||
|
|
||||||
|
fn.set_thread_pool(&device);
|
||||||
|
fn.Run();
|
||||||
|
EXPECT_EQ(fn.result0(), 65);
|
||||||
|
EXPECT_EQ(fn.var_x(), 23);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(TFCompileTest, Variable) {
|
TEST(TFCompileTest, Variable) {
|
||||||
Eigen::ThreadPool tp(1);
|
Eigen::ThreadPool tp(1);
|
||||||
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
||||||
@ -569,7 +583,6 @@ TEST(TFCompileTest, VariableSequentialUpdatesNoAlloc) {
|
|||||||
fn.Run();
|
fn.Run();
|
||||||
EXPECT_NEAR(x, 0.594322f, 1e-6);
|
EXPECT_NEAR(x, 0.594322f, 1e-6);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
TEST(TFCompileTest, AssertEqAndReturnDiff) {
|
TEST(TFCompileTest, AssertEqAndReturnDiff) {
|
||||||
// Assert is converted into a no-op in XLA, so there is no failure even if the
|
// Assert is converted into a no-op in XLA, so there is no failure even if the
|
||||||
|
@ -26,6 +26,7 @@ def tf_library(
|
|||||||
name,
|
name,
|
||||||
graph,
|
graph,
|
||||||
config,
|
config,
|
||||||
|
debug_info = None,
|
||||||
freeze_checkpoint = None,
|
freeze_checkpoint = None,
|
||||||
freeze_saver = None,
|
freeze_saver = None,
|
||||||
cpp_class = None,
|
cpp_class = None,
|
||||||
@ -37,7 +38,7 @@ def tf_library(
|
|||||||
tfcompile_tool = "//tensorflow/compiler/aot:tfcompile",
|
tfcompile_tool = "//tensorflow/compiler/aot:tfcompile",
|
||||||
include_standard_runtime_deps = True,
|
include_standard_runtime_deps = True,
|
||||||
enable_xla_hlo_profiling = False,
|
enable_xla_hlo_profiling = False,
|
||||||
mlir_components = None,
|
mlir_components = "None",
|
||||||
deps = None,
|
deps = None,
|
||||||
tags = []):
|
tags = []):
|
||||||
"""Runs tfcompile to compile a TensorFlow graph into executable code.
|
"""Runs tfcompile to compile a TensorFlow graph into executable code.
|
||||||
@ -88,8 +89,8 @@ def tf_library(
|
|||||||
enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated
|
enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated
|
||||||
program, and emit metadata that lets us pretty-print the gathered
|
program, and emit metadata that lets us pretty-print the gathered
|
||||||
profile counters.
|
profile counters.
|
||||||
mlir_components: When the value is "Bridge", use MLIR to translate
|
mlir_components: When the value is "None", no components use MLIR. When
|
||||||
GraphDef to HLO.
|
the value is "Bridge", use MLIR to translate GraphDef to HLO.
|
||||||
deps: a list of deps to include on the build rules for the generated
|
deps: a list of deps to include on the build rules for the generated
|
||||||
library, added to the standard deps if standard_runtime_deps is True.
|
library, added to the standard deps if standard_runtime_deps is True.
|
||||||
tags: tags to apply to subsidiary build rules.
|
tags: tags to apply to subsidiary build rules.
|
||||||
@ -189,17 +190,17 @@ def tf_library(
|
|||||||
else:
|
else:
|
||||||
profiling_flag = ""
|
profiling_flag = ""
|
||||||
|
|
||||||
if mlir_components:
|
mlir_flag = "--mlir_components=" + mlir_components
|
||||||
mlir_flag = "--mlir_components=" + mlir_components
|
|
||||||
else:
|
srcs = [tfcompile_graph, config]
|
||||||
mlir_flag = ""
|
debug_info_flag = ""
|
||||||
|
if debug_info:
|
||||||
|
srcs.append(debug_info)
|
||||||
|
debug_info_flag = " --debug_info=$(location " + debug_info + ")"
|
||||||
|
|
||||||
native.genrule(
|
native.genrule(
|
||||||
name = ("gen_" + name),
|
name = ("gen_" + name),
|
||||||
srcs = [
|
srcs = srcs,
|
||||||
tfcompile_graph,
|
|
||||||
config,
|
|
||||||
],
|
|
||||||
outs = [
|
outs = [
|
||||||
header_file,
|
header_file,
|
||||||
metadata_object_file,
|
metadata_object_file,
|
||||||
@ -209,6 +210,7 @@ def tf_library(
|
|||||||
"CUDA_VISIBLE_DEVICES='' " +
|
"CUDA_VISIBLE_DEVICES='' " +
|
||||||
"$(location " + tfcompile_tool + ")" +
|
"$(location " + tfcompile_tool + ")" +
|
||||||
" --graph=$(location " + tfcompile_graph + ")" +
|
" --graph=$(location " + tfcompile_graph + ")" +
|
||||||
|
debug_info_flag +
|
||||||
" --config=$(location " + config + ")" +
|
" --config=$(location " + config + ")" +
|
||||||
" --entry_point=" + ep +
|
" --entry_point=" + ep +
|
||||||
" --cpp_class=" + cpp_class +
|
" --cpp_class=" + cpp_class +
|
||||||
@ -240,10 +242,7 @@ def tf_library(
|
|||||||
session_module_pb = name + "_session_module.pb"
|
session_module_pb = name + "_session_module.pb"
|
||||||
native.genrule(
|
native.genrule(
|
||||||
name = (name + "_session_module"),
|
name = (name + "_session_module"),
|
||||||
srcs = [
|
srcs = srcs,
|
||||||
tfcompile_graph,
|
|
||||||
config,
|
|
||||||
],
|
|
||||||
outs = [
|
outs = [
|
||||||
session_module_pb,
|
session_module_pb,
|
||||||
],
|
],
|
||||||
@ -251,6 +250,7 @@ def tf_library(
|
|||||||
"CUDA_VISIBLE_DEVICES='' " +
|
"CUDA_VISIBLE_DEVICES='' " +
|
||||||
"$(location " + tfcompile_tool + ")" +
|
"$(location " + tfcompile_tool + ")" +
|
||||||
" --graph=$(location " + tfcompile_graph + ")" +
|
" --graph=$(location " + tfcompile_graph + ")" +
|
||||||
|
debug_info_flag +
|
||||||
" --config=$(location " + config + ")" +
|
" --config=$(location " + config + ")" +
|
||||||
" --entry_point=" + ep +
|
" --entry_point=" + ep +
|
||||||
" --cpp_class=" + cpp_class +
|
" --cpp_class=" + cpp_class +
|
||||||
@ -410,5 +410,6 @@ def target_llvm_triple():
|
|||||||
"//tensorflow:ios_x86_64": "x86_64-apple-ios",
|
"//tensorflow:ios_x86_64": "x86_64-apple-ios",
|
||||||
"//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
|
"//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
|
||||||
"//tensorflow:macos": "x86_64-none-darwin",
|
"//tensorflow:macos": "x86_64-none-darwin",
|
||||||
|
"//tensorflow:windows": "x86_64-none-windows",
|
||||||
"//conditions:default": "x86_64-pc-linux",
|
"//conditions:default": "x86_64-pc-linux",
|
||||||
})
|
})
|
||||||
|
@ -65,7 +65,7 @@ int main(int argc, char** argv) {
|
|||||||
flags.out_metadata_object = "out_helper.o";
|
flags.out_metadata_object = "out_helper.o";
|
||||||
flags.out_header = "out.h";
|
flags.out_header = "out.h";
|
||||||
flags.entry_point = "entry";
|
flags.entry_point = "entry";
|
||||||
flags.tensorflow_header_root = "third_party/tensorflow";
|
flags.debug_info_path_begin_marker = "";
|
||||||
|
|
||||||
std::vector<tensorflow::Flag> flag_list;
|
std::vector<tensorflow::Flag> flag_list;
|
||||||
AppendMainFlags(&flag_list, &flags);
|
AppendMainFlags(&flag_list, &flags);
|
||||||
@ -82,12 +82,10 @@ int main(int argc, char** argv) {
|
|||||||
|
|
||||||
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
|
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
|
||||||
QCHECK(argc == 1) << "\nERROR: This command does not take any arguments "
|
QCHECK(argc == 1) << "\nERROR: This command does not take any arguments "
|
||||||
"other than flags\n\n"
|
"other than flags. See --help.\n\n";
|
||||||
<< usage;
|
|
||||||
tensorflow::Status status = tensorflow::tfcompile::Main(flags);
|
tensorflow::Status status = tensorflow::tfcompile::Main(flags);
|
||||||
if (status.code() == tensorflow::error::INVALID_ARGUMENT) {
|
if (status.code() == tensorflow::error::INVALID_ARGUMENT) {
|
||||||
std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n"
|
std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n";
|
||||||
<< usage;
|
|
||||||
return 1;
|
return 1;
|
||||||
} else {
|
} else {
|
||||||
TF_QCHECK_OK(status);
|
TF_QCHECK_OK(status);
|
||||||
|
@ -2,6 +2,7 @@ load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_
|
|||||||
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
|
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
|
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
|
||||||
load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
|
load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
|
||||||
|
load("//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags")
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = [":internal"],
|
default_visibility = [":internal"],
|
||||||
@ -13,6 +14,10 @@ package_group(
|
|||||||
includes = [
|
includes = [
|
||||||
"//tensorflow/compiler/tf2xla:internal",
|
"//tensorflow/compiler/tf2xla:internal",
|
||||||
],
|
],
|
||||||
|
packages = [
|
||||||
|
"//tensorflow/compiler/tests/...",
|
||||||
|
"//tensorflow/python/...",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
package_group(
|
package_group(
|
||||||
@ -56,6 +61,7 @@ cc_library(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":jit_compilation_passes",
|
":jit_compilation_passes",
|
||||||
|
":xla_kernel_creator", # buildcleaner: keep
|
||||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||||
@ -69,6 +75,7 @@ cc_library(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = if_cuda_or_rocm([
|
deps = if_cuda_or_rocm([
|
||||||
":jit_compilation_passes",
|
":jit_compilation_passes",
|
||||||
|
":xla_kernel_creator", # buildcleaner: keep
|
||||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||||
@ -77,19 +84,6 @@ cc_library(
|
|||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "xla_mlir_gpu_jit",
|
|
||||||
visibility = ["//visibility:public"],
|
|
||||||
deps = if_cuda_or_rocm([
|
|
||||||
":jit_compilation_passes",
|
|
||||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
|
||||||
"//tensorflow/compiler/xla/service:mlir_gpu_plugin",
|
|
||||||
]),
|
|
||||||
alwayslink = 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "xla_cpu_device",
|
name = "xla_cpu_device",
|
||||||
srcs = ["xla_cpu_device.cc"],
|
srcs = ["xla_cpu_device.cc"],
|
||||||
@ -124,6 +118,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
|
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
|
"//tensorflow/core:gpu_init",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
@ -168,7 +163,9 @@ XLA_DEVICE_DEPS = [
|
|||||||
":common",
|
":common",
|
||||||
":xla_launch_util",
|
":xla_launch_util",
|
||||||
":xla_tensor",
|
":xla_tensor",
|
||||||
|
"@com_google_absl//absl/base",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/synchronization",
|
"@com_google_absl//absl/synchronization",
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
"//tensorflow/compiler/jit/ops:xla_ops",
|
"//tensorflow/compiler/jit/ops:xla_ops",
|
||||||
@ -187,6 +184,7 @@ XLA_DEVICE_DEPS = [
|
|||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:dataset_ops_op_lib",
|
"//tensorflow/core:dataset_ops_op_lib",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
"//tensorflow/core:functional_ops_op_lib",
|
"//tensorflow/core:functional_ops_op_lib",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
@ -261,13 +259,26 @@ cc_library(
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Internal targets below this point.
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "flags",
|
name = "flags",
|
||||||
srcs = ["flags.cc"],
|
srcs = ["flags.cc"],
|
||||||
hdrs = ["flags.h"],
|
hdrs = ["flags.h"],
|
||||||
visibility = [":friends"],
|
visibility = [":friends"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/compiler/xla:parse_flags_from_env",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/base",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Header-only version of "flags" library, for linking from the shared object
|
||||||
|
# without ODR violations.
|
||||||
|
cc_library(
|
||||||
|
name = "flags_headers_only",
|
||||||
|
hdrs = ["flags.h"],
|
||||||
|
visibility = [":friends"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/compiler/xla:parse_flags_from_env",
|
"//tensorflow/compiler/xla:parse_flags_from_env",
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
@ -287,6 +298,8 @@ cc_library(
|
|||||||
visibility = [":friends"],
|
visibility = [":friends"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Internal targets below this point.
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "xla_launch_util",
|
name = "xla_launch_util",
|
||||||
srcs = ["xla_launch_util.cc"],
|
srcs = ["xla_launch_util.cc"],
|
||||||
@ -325,6 +338,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_activity_listener",
|
":xla_activity_listener",
|
||||||
":xla_activity_proto_cc",
|
":xla_activity_proto_cc",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
||||||
"//tensorflow/compiler/tf2xla:common",
|
"//tensorflow/compiler/tf2xla:common",
|
||||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
@ -408,6 +422,7 @@ cc_library(
|
|||||||
"xla_kernel_creator.h",
|
"xla_kernel_creator.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":flags",
|
||||||
":jit_compilation_passes",
|
":jit_compilation_passes",
|
||||||
":xla_kernel_creator_util",
|
":xla_kernel_creator_util",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
@ -636,6 +651,7 @@ cc_library(
|
|||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/stream_executor/lib",
|
"//tensorflow/stream_executor/lib",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/base",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
"@com_google_absl//absl/container:inlined_vector",
|
"@com_google_absl//absl/container:inlined_vector",
|
||||||
@ -767,7 +783,7 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
# TODO(b/141643254) Re-enable msan after fixing use-of-uninitialized-value
|
# TODO(b/141643254) Re-enable msan after fixing use-of-uninitialized-value
|
||||||
# error.
|
# error.
|
||||||
tags = ["nomsan"],
|
tags = ["nomsan"] + tf_cuda_tests_tags(),
|
||||||
deps = [
|
deps = [
|
||||||
":common",
|
":common",
|
||||||
":compilation_passes",
|
":compilation_passes",
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user