Merge branch 'upstream/master' into addsub_16x8

Change-Id: I51baab7d117fb4656fa09f4b9ff78ef34cbbe8a0
commit 4180f945a7
Author: Elena Zhelezina
Date: 2020-03-26 12:23:04 +00:00
6265 changed files with 342917 additions and 134534 deletions

.bazelrc

@ -46,7 +46,6 @@
# sycl_asan:
# sycl_trisycl:
# mkl: Enable full mkl support.
# mkl_open_source_only: Enable MKL support only using open source MKL libraries.
# tensorrt: Enable Tensorrt support.
# ngraph: Enable ngraph support.
# numa: Enable numa using hwloc.
@ -69,6 +68,7 @@
# rbe_linux_py3: Linux Python 3 RBE config
#
# rbe_win_py37: Windows Python 3.7 RBE config
# rbe_win_py38: Windows Python 3.8 RBE config
#
# tensorflow_testing_rbe_linux: RBE options to use RBE with tensorflow-testing project on linux
# tensorflow_testing_rbe_win: RBE options to use RBE with tensorflow-testing project on windows
@ -136,15 +136,9 @@ build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain
# environment variable "TF_MKL_ROOT" every time before build.
build:mkl --define=build_with_mkl=true --define=enable_mkl=true
build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl --define=build_with_mkl_dnn_v1_only=true
build:mkl -c opt
# This config option is used to enable MKL-DNN open source library only,
# without depending on MKL binary version.
build:mkl_open_source_only --define=build_with_mkl_dnn_only=true
build:mkl_open_source_only --define=build_with_mkl_dnn_v1_only=true
build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0
# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true
@ -221,6 +215,11 @@ build --define=grpc_no_ares=true
# archives in -whole_archive -no_whole_archive.
build --noincompatible_remove_legacy_whole_archive
# These are bazel 2.0's incompatible flags. Tensorflow needs to use bazel 2.0.0
# to use cc_shared_library, as part of the Tensorflow Build Improvements RFC:
# https://github.com/tensorflow/community/pull/179
build --noincompatible_prohibit_aapt1
# Modular TF build options
build:dynamic_kernels --define=dynamic_loaded_kernels=true
build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
@ -241,6 +240,7 @@ build:windows --copt=/w
# Tensorflow uses M_* math constants that only get defined by MSVC headers if
# _USE_MATH_DEFINES is defined.
build:windows --copt=/D_USE_MATH_DEFINES
build:windows --host_copt=/D_USE_MATH_DEFINES
# Default paths for TF_SYSTEM_LIBS
build:linux --define=PREFIX=/usr
@ -313,22 +313,26 @@ build:xla --define=with_xla_support=true
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS
# Options when using remote execution
# WARNING: THESE OPTIONS WON'T WORK IF YOU DO NOT HAVE PROPER AUTHENTICATION AND PERMISSIONS
# Flag to enable remote config
common --experimental_repo_remote_exec
build:rbe --action_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
build:rbe --auth_enabled=true
build:rbe --auth_scope=https://www.googleapis.com/auth/cloud-source-tools
build:rbe --google_default_credentials
build:rbe --bes_backend=buildeventservice.googleapis.com
build:rbe --bes_best_effort=false
build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations"
build:rbe --bes_timeout=600s
build:rbe --define=EXECUTOR=remote
build:rbe --distinct_host_configuration=false
build:rbe --flaky_test_attempts=3
build:rbe --jobs=200
build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com
build:rbe --remote_timeout=3600
build:rbe --spawn_strategy=remote,worker,standalone,local
test:rbe --test_env=USER=anon
# Attempt to minimize the amount of data transfer between bazel and the remote
# workers:
build:rbe --remote_download_toplevel
build:rbe_linux --config=rbe
build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin"
@ -343,6 +347,7 @@ build:rbe_linux --config=avx_linux
build:rbe_linux --config=short_logs
# TODO(gunan): Check why we need this specified in rbe, but not in other builds.
build:rbe_linux --linkopt=-lrt
build:rbe_linux --linkopt=-lm
build:rbe_cpu_linux --config=rbe_linux
build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
@ -351,21 +356,37 @@ build:rbe_cpu_linux --extra_execution_platforms="@org_tensorflow//third_party/to
build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --config=rbe_linux
build:rbe_linux_cuda_nvcc --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010,@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010-gpu"
build:rbe_linux_cuda_nvcc --host_platform="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/cuda10.1-cudnn7"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/tensorrt6.0"
build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_TENSORRT=1
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_VERSION=10
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDNN_VERSION=7
build:rbe_linux_cuda_nvcc --repo_env=REMOTE_GPU_TESTING=1
build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_CUDA=1
build:rbe_linux_cuda_base --config=rbe_linux
build:rbe_linux_cuda_base --repo_env=TF_NEED_TENSORRT=1
build:rbe_linux_cuda_base --repo_env=TF_CUDA_VERSION=10
build:rbe_linux_cuda_base --repo_env=TF_CUDNN_VERSION=7
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1
test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
test:rbe_linux_cuda_nvcc --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang --crosstool_top="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_clang --extra_toolchains="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_clang --extra_execution_platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang --host_platform="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang --platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_clang --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_clang --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_clang --define=using_cuda_clang=true
test:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
@ -375,29 +396,33 @@ build:rbe_linux_py2 --python_path="/usr/bin/python2"
build:rbe_linux_py2 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py"
build:rbe_linux_py3 --config=rbe_linux
build:rbe_linux_py3 --repo_env=PYTHON_BIN_PATH="/usr/bin/python3"
build:rbe_linux_py3 --python_path="/usr/bin/python3"
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py3"
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-manylinux2010-py3_config_python"
build:rbe_win --config=rbe
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:toolchain"
build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:cc-toolchain-x64_windows"
build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"
build:rbe_win --host_platform="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"
build:rbe_win --platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:toolchain"
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:cc-toolchain-x64_windows"
build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"
build:rbe_win --host_platform="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"
build:rbe_win --platforms="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"
build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
# TODO(gunan): Remove once we use MSVC 2019 with latest patches.
build:rbe_win --define=override_eigen_strong_inline=true
build:rbe_win --jobs=500
build:rbe_win_py37 --config=rbe
build:rbe_win_py37 --repo_env=PYTHON_BIN_PATH=C:\\Python37\\python.exe
build:rbe_win_py37 --repo_env=PYTHON_LIB_PATH=C:\\Python37\\lib\\site-packages
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py37
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO="@windows_py37_config_python"
build:rbe_win_py37 --python_path=C:\\Python37\\python.exe
build:rbe_win_py38 --config=rbe
build:rbe_win_py38 --repo_env=PYTHON_BIN_PATH=C:\\Python38\\python.exe
build:rbe_win_py38 --repo_env=PYTHON_LIB_PATH=C:\\Python38\\lib\\site-packages
build:rbe_win_py38 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py38
build:rbe_win_py38 --python_path=C:\\Python38\\python.exe
# These you may need to change for your own GCP project.
build:tensorflow_testing_rbe --project_id=tensorflow-testing
common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance
@ -407,7 +432,6 @@ build:tensorflow_testing_rbe_linux --config=rbe_linux
common:tensorflow_testing_rbe_win --remote_instance_name=projects/tensorflow-testing/instances/windows
build:tensorflow_testing_rbe_win --config=tensorflow_testing_rbe
build:tensorflow_testing_rbe_win --config=rbe_win
# END TF REMOTE BUILD EXECUTION OPTIONS
# Default options should come above this line


@ -1 +1 @@
1.2.1
2.0.0

.github/ISSUE_TEMPLATE/00-bug-issue.md

@ -0,0 +1,44 @@
---
name: Bug Issue
about: Use this template for reporting a bug
labels: 'type:bug'
---
<em>Please make sure that this is a bug. As per our
[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:bug_template</em>
**System information**
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh).
You can also obtain the TensorFlow version with:
1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**
**Describe the expected behavior**
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
**Other info / logs** Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full
traceback. Large logs and files should be attached.


@ -1,35 +0,0 @@
---
name: Bug/Performance Issue
about: Use this template for reporting a bug or a performance issue.
---
<em>Please make sure that this is a bug. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template</em>
**System information**
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with:
1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**
**Describe the expected behavior**
**Code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate the problem.
**Other info / logs**
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.


@ -1,6 +1,7 @@
---
name: Build/Installation Issue
about: Use this template for build/installation issues
labels: 'type:build/install'
---


@ -1,10 +1,11 @@
---
name: Documentation Issue
about: Use this template for documentation related
about: Use this template for documentation related issues
labels: 'type:docs'
---
Thank you for submitting a TensorFlow documentation issue. Per our GitHub
policy, we only address code/doc bugs, performance issues, feature requests, and
build/installation issues on GitHub.


@ -1,6 +1,7 @@
---
name: Feature Request
about: Use this template for raising a feature request
labels: 'type:feature'
---


@ -1,10 +1,10 @@
---
name: TensorFlow Lite Op Request
about: Use this template for reporting ops you are using or missing.
about: Use this template for reporting Lite ops you are using or missing
labels: 'comp:lite'
---
**System information**
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- TensorFlow installed from (source or binary):
@ -17,8 +17,14 @@ about: Use this template for reporting ops you are using or missing.
# Copy and paste here
```
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
Also, please include a link to a GraphDef or the model if possible.
**Any other info / logs**
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
Include any logs or source code that would be helpful to diagnose the problem.
If including tracebacks, please include the full traceback. Large logs and files
should be attached.


@ -1,6 +1,7 @@
---
name: Other Issues
about: Use this template for any other non-support related issues
labels: 'type:others'
---


@ -1,6 +1,7 @@
---
name: TensorFlow Lite New Converter Issue
about: Use this template for reporting issues during model conversion to TFLite.
about: Use this template for reporting issues during model conversion to TFLite
labels: 'TFLiteConverter'
---
@ -12,6 +13,7 @@ about: Use this template for reporting issues during model conversion to TFLite.
**Command used to run the converter or code if you're using the Python API**
If possible, please share a link to Colab/Jupyter/any notebook.
```
# Copy and paste here the exact command


@ -0,0 +1,45 @@
---
name: Performance Issue
about: Use this template for reporting a performance issue
labels: 'type:performance'
---
<em>Please make sure that this is an issue related to performance of TensorFlow.
As per our
[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:performance_template</em>
**System information**
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh).
You can also obtain the TensorFlow version with:
1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**
**Describe the expected behavior**
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
**Other info / logs** Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full
traceback. Large logs and files should be attached.

.gitignore

@ -22,6 +22,7 @@ tensorflow/contrib/cmake/_build/
/tensorflow/python/framework/fast_tensor_util.cpp
/tensorflow/lite/gen/**
/tensorflow/lite/tools/make/downloads/**
/tensorflow/lite/tools/make/gen/**
/api_init_files_list.txt
/estimator_api_init_files_list.txt
*.whl
@ -37,7 +38,9 @@ gradleBuild
*.pbxproj
*.xcworkspace
/*.podspec
/tensorflow/lite/**/[ios|objc|swift]*/BUILD
/tensorflow/lite/**/ios/BUILD
/tensorflow/lite/**/objc/BUILD
/tensorflow/lite/**/swift/BUILD
/tensorflow/lite/examples/ios/simple/data/*.tflite
/tensorflow/lite/examples/ios/simple/data/*.txt
Podfile.lock

.pylintrc (symbolic link)

@ -0,0 +1 @@
tensorflow/tools/ci_build/pylintrc


@ -70,7 +70,7 @@ $ python
3
>>> hello = tf.constant('Hello, TensorFlow!')
>>> hello.numpy()
'Hello, TensorFlow!'
b'Hello, TensorFlow!'
```
For more examples, see the
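The corrected REPL output reflects that `tf.string` tensors hold byte strings, so `.numpy()` returns Python `bytes` rather than `str`. A minimal sketch, assuming a TensorFlow 2.x install:

```python
import tensorflow as tf

hello = tf.constant('Hello, TensorFlow!')
value = hello.numpy()          # tf.string tensors are stored as bytes
print(value)                   # b'Hello, TensorFlow!'
print(value.decode('utf-8'))   # decode for a Python str: 'Hello, TensorFlow!'
```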
@ -130,18 +130,20 @@ Build Type | Status
## Resources
* [TensorFlow.org](https://www.tensorflow.org)
* [TensorFlow tutorials](https://www.tensorflow.org/tutorials/)
* [TensorFlow official models](https://github.com/tensorflow/models/tree/master/official)
* [TensorFlow examples](https://github.com/tensorflow/examples)
* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
* [TensorFlow Official Models](https://github.com/tensorflow/models/tree/master/official)
* [TensorFlow Examples](https://github.com/tensorflow/examples)
* [TensorFlow in Practice from Coursera](https://www.coursera.org/specializations/tensorflow-in-practice)
* [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment)
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
* [TensorFlow blog](https://blog.tensorflow.org)
* [TensorFlow Blog](https://blog.tensorflow.org)
* [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
* [TensorFlow Twitter](https://twitter.com/tensorflow)
* [TensorFlow YouTube](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
* [TensorFlow roadmap](https://www.tensorflow.org/community/roadmap)
* [TensorFlow white papers](https://www.tensorflow.org/about/bib)
* [TensorBoard visualization toolkit](https://github.com/tensorflow/tensorboard)
* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
* [TensorBoard Visualization Toolkit](https://github.com/tensorflow/tensorboard)
Learn more about the
[TensorFlow community](https://www.tensorflow.org/community) and how to


@ -1,3 +1,19 @@
# Release 2.0.1
## Bug Fixes and Other Changes
* Fixes a security vulnerability where converting a Python string to a `tf.float16` value produces a segmentation fault ([CVE-2020-5215](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-5215))
* Updates `curl` to `7.66.0` to handle [CVE-2019-5482](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-5482) and [CVE-2019-5481](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-5481)
* Updates `sqlite3` to `3.30.01` to handle [CVE-2019-19646](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19646), [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645) and [CVE-2019-16168](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-16168)
# Release 1.15.2
## Bug Fixes and Other Changes
* Fixes a security vulnerability where converting a Python string to a `tf.float16` value produces a segmentation fault ([CVE-2020-5215](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-5215))
* Updates `curl` to `7.66.0` to handle [CVE-2019-5482](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-5482) and [CVE-2019-5481](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-5481)
* Updates `sqlite3` to `3.30.01` to handle [CVE-2019-19646](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19646), [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645) and [CVE-2019-16168](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-16168)
# Release 2.1.0
TensorFlow 2.1 will be the last TF release supporting Python 2. Python 2 support [officially ends on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update). [As announced earlier](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ), TensorFlow will also stop supporting Python 2 starting January 1, 2020, and no more releases are expected in 2019.


@ -245,4 +245,4 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
### Known Vulnerabilities
For a list of known vulnerabilities and security advisories for TensorFlow,
[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md).
[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/README.md).


@ -1,13 +1,11 @@
workspace(name = "org_tensorflow")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("//third_party:repo.bzl", "tf_http_archive")
tf_http_archive(
http_archive(
name = "io_bazel_rules_closure",
sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
patch_file = "@org_tensorflow//third_party:rules_closure.patch",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13
@ -115,3 +113,32 @@ http_archive(
"https://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
],
)
# Required for dependency @com_github_grpc_grpc
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
grpc_deps()
load(
"@build_bazel_rules_apple//apple:repositories.bzl",
"apple_rules_dependencies",
)
apple_rules_dependencies()
load(
"@build_bazel_apple_support//lib:repositories.bzl",
"apple_support_dependencies",
)
apple_support_dependencies()
load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")
bazel_version_repository(name = "bazel_version")
load("//third_party/googleapis:repository_rules.bzl", "config_googleapis")
config_googleapis()


@ -49,8 +49,8 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
_TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
_TF_MIN_BAZEL_VERSION = '1.2.1'
_TF_MAX_BAZEL_VERSION = '1.2.1'
_TF_MIN_BAZEL_VERSION = '2.0.0'
_TF_MAX_BAZEL_VERSION = '2.0.0'
NCCL_LIB_PATHS = [
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
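With `_TF_MIN_BAZEL_VERSION` and `_TF_MAX_BAZEL_VERSION` both pinned to `2.0.0`, configure.py now accepts exactly one Bazel release. A minimal sketch of such a bounds check (`check_bazel_version` here is a hypothetical helper, not the actual configure.py function):

```python
def check_bazel_version(installed, min_version='2.0.0', max_version='2.0.0'):
  """Raise if the installed Bazel is outside [min_version, max_version]."""
  def parse(version):
    return tuple(int(part) for part in version.split('.'))
  if not parse(min_version) <= parse(installed) <= parse(max_version):
    raise RuntimeError(
        'Installed Bazel %s is outside the supported range [%s, %s]' %
        (installed, min_version, max_version))

check_bazel_version('2.0.0')  # passes; '1.2.1' would now raise
```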
@ -1155,7 +1155,7 @@ def set_trisycl_include_dir(environ_cp):
write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)
def system_specific_test_config(env):
def system_specific_test_config(environ_cp):
"""Add default build and test flags required for TF tests to bazelrc."""
write_to_bazelrc('test --flaky_test_attempts=3')
write_to_bazelrc('test --test_size_filters=small,medium')
@ -1171,14 +1171,14 @@ def system_specific_test_config(env):
test_only_filters = ['-oss_serial']
if is_windows():
test_and_build_filters.append('-no_windows')
if env.get('TF_NEED_CUDA', None) == '1':
if environ_cp.get('TF_NEED_CUDA', None) == '1':
test_and_build_filters += ['-no_windows_gpu', '-no_gpu']
else:
test_and_build_filters.append('-gpu')
elif is_macos():
test_and_build_filters += ['-gpu', '-nomac', '-no_mac']
elif is_linux():
if env.get('TF_NEED_CUDA', None) == '1':
if environ_cp.get('TF_NEED_CUDA', None) == '1':
test_and_build_filters.append('-no_gpu')
write_to_bazelrc('test --test_env=LD_LIBRARY_PATH')
else:
@ -1221,7 +1221,7 @@ def is_reduced_optimize_huge_functions_available(environ_cp):
only, as of 2019-11-19). TensorFlow needs this flag to massively reduce
compile times, but until 16.4 is officially released, we can't depend on it.
See also https://groups.google.com/a/tensorflow.org/g/build/c/SsW98Eo7l3o
See also https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
Because it's very annoying to check this manually (to check the MSVC installed
versions, you need to use the registry, and it's not clear if Bazel will be
@ -1390,9 +1390,8 @@ def main():
else:
environ_cp['TF_CONFIGURE_IOS'] = '0'
xla_enabled_by_default = is_linux() or is_macos()
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
xla_enabled_by_default, 'xla')
if environ_cp.get('TF_ENABLE_XLA', '1') == '1':
write_to_bazelrc('build --config=xla')
set_action_env_var(
environ_cp,
@ -1523,7 +1522,7 @@ def main():
create_android_ndk_rule(environ_cp)
create_android_sdk_rule(environ_cp)
system_specific_test_config(os.environ)
system_specific_test_config(environ_cp)
set_action_env_var(environ_cp, 'TF_CONFIGURE_IOS', 'iOS', False)
if environ_cp.get('TF_CONFIGURE_IOS') == '1':


@ -187,6 +187,12 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "fuchsia",
values = {"cpu": "fuchsia"},
visibility = ["//visibility:public"],
)
config_setting(
name = "ios_x86_64",
values = {
@ -448,19 +454,66 @@ config_setting(
visibility = ["//visibility:public"],
)
# Specifies via a config setting if this is a mobile build or not, makes
# it easier to combine settings later.
selects.config_setting_group(
name = "mobile",
match_any = [
":android",
":chromiumos",
":emscripten",
":ios",
],
)
config_setting(
name = "lite_protos_legacy",
values = {"define": "TENSORFLOW_PROTOS=lite"},
visibility = ["//visibility:private"],
)
config_setting(
name = "full_protos",
values = {"define": "TENSORFLOW_PROTOS=full"},
visibility = ["//visibility:public"],
)
selects.config_setting_group(
name = "lite_protos",
match_any = [":lite_protos_legacy"],
)
selects.config_setting_group(
name = "mobile_lite_protos",
match_all = [
":lite_protos",
":mobile",
],
)
selects.config_setting_group(
name = "mobile_full_protos",
match_all = [
":full_protos",
":mobile",
],
)
# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
# Instead, please use public APIs or public build rules TF provides.
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
package_group(
name = "internal",
packages = [
# To pass open source testing in the pip Kokoros.
"//bazel_pip/tensorflow/...",
"//learning/brain/swift/x10/...",
"//perftools/accelerators/xprof/api/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
"//tensorflow/...",
"//tensorflow_estimator/python/estimator/...",
"//tensorflow_models/official/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
],
)
@ -479,6 +532,7 @@ bzl_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core/platform:build_config_root_bzl",
"//tensorflow/core/platform:rules_cc_bzl",
"//tensorflow/core/platform/default:cuda_build_defs_bzl",
"//third_party/mkl:build_defs_bzl",
"//third_party/mkl_dnn:build_defs_bzl",
@ -493,8 +547,8 @@ cc_library(
name = "grpc",
visibility = ["//visibility:public"],
deps = select({
":linux_s390x": ["@grpc//:grpc_unsecure"],
"//conditions:default": ["@grpc"],
":linux_s390x": ["@com_github_grpc_grpc//:grpc_unsecure"],
"//conditions:default": ["@com_github_grpc_grpc//:grpc"],
}),
)
@ -502,8 +556,8 @@ cc_library(
name = "grpc++",
visibility = ["//visibility:public"],
deps = select({
":linux_s390x": ["@grpc//:grpc++_unsecure"],
"//conditions:default": ["@grpc//:grpc++"],
":linux_s390x": ["@com_github_grpc_grpc//:grpc++_unsecure"],
"//conditions:default": ["@com_github_grpc_grpc//:grpc++"],
}),
)
@ -588,6 +642,7 @@ tf_cc_shared_object(
"//tensorflow/core:gpu_runtime_impl",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
"//tensorflow/core:lib_internal_impl",
"//tensorflow/core/profiler:profiler_impl",
"//tensorflow/stream_executor:stream_executor_impl",
"//tensorflow:tf_framework_version_script.lds",
] + tf_additional_binary_deps(),
@ -647,6 +702,7 @@ tf_cc_shared_object(
"//tensorflow/c:exported_symbols.lds",
"//tensorflow/c:version_script.lds",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:tensorflow",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
],
@ -907,7 +963,6 @@ py_library(
"//conditions:default": [":tf_python_api_gen_v1"],
}) + [
":root_init_gen",
":virtual_root_init_gen",
"//tensorflow/python/keras/api:keras_python_api_gen",
"//tensorflow/python/keras/api:keras_python_api_gen_compat_v1",
"//tensorflow/python/keras/api:keras_python_api_gen_compat_v2",


@ -35,9 +35,11 @@ import inspect as _inspect
import logging as _logging
import os as _os
import site as _site
import six as _six
import sys as _sys
from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
# API IMPORTS PLACEHOLDER
@ -69,13 +71,13 @@ except ImportError:
_logging.warning(
"Limited tf.summary API due to missing TensorBoard installation.")
try:
from tensorflow_estimator.python.estimator.api._v2 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass
# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)
try:
from .python.keras.api._v2 import keras
@ -85,6 +87,13 @@ try:
except ImportError:
pass
# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
import typing as _typing
if _typing.TYPE_CHECKING:
from tensorflow_estimator.python.estimator.api._v2 import estimator
# pylint: enable=g-import-not-at-top
# Enable TF2 behaviors
from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top
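The `_LazyLoader` introduced above defers the relatively expensive estimator import until the attribute is first touched, while the `typing.TYPE_CHECKING` import restores IDE autocompletion at no runtime cost. A minimal sketch of the lazy-loader pattern, assuming the usual `types.ModuleType` subclass approach (TensorFlow's actual implementation lives in `tensorflow.python.util.lazy_loader`):

```python
import importlib
import types


class LazyLoader(types.ModuleType):
  """Defers importing `name` until an attribute is first accessed."""

  def __init__(self, local_name, parent_globals, name):
    self._local_name = local_name
    self._parent_globals = parent_globals
    super(LazyLoader, self).__init__(name)

  def _load(self):
    # Import the real module and swap it into the parent namespace so
    # subsequent lookups bypass this proxy entirely.
    module = importlib.import_module(self.__name__)
    self._parent_globals[self._local_name] = module
    self.__dict__.update(module.__dict__)
    return module

  def __getattr__(self, item):
    return getattr(self._load(), item)


# Usage mirroring the diff above:
# estimator = LazyLoader(
#     "estimator", globals(),
#     "tensorflow_estimator.python.estimator.api._v2.estimator")
```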


@ -22,12 +22,14 @@ import distutils as _distutils
import inspect as _inspect
import os as _os
import site as _site
import six as _six
import sys as _sys
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.platform import tf_logging as _logging
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
# API IMPORTS PLACEHOLDER
@ -64,13 +66,14 @@ elif _tf_api_dir not in __path__:
# reexport_tf_summary can get compat from sys.modules. Only needed if using
# lazy loading.
_current_module.compat.v2 # pylint: disable=pointless-statement
try:
from tensorflow_estimator.python.estimator.api._v1 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass
# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v1.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)
try:
from .python.keras.api._v1 import keras
@ -80,6 +83,13 @@ try:
except ImportError:
pass
# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
import typing as _typing
if _typing.TYPE_CHECKING:
from tensorflow_estimator.python.estimator.api._v1 import estimator
# pylint: enable=g-import-not-at-top
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
_CONTRIB_WARNING = """


@ -54,9 +54,10 @@ filegroup(
)
filegroup(
name = "pywrap_eager_hdrs",
name = "pywrap_required_hdrs",
srcs = [
"c_api_internal.h",
"python_api.h",
"tf_status_helper.h",
"tf_status_internal.h",
"tf_tensor_internal.h",
@ -98,6 +99,17 @@ tf_cuda_library(
],
)
filegroup(
name = "pywrap_tf_session_hdrs",
srcs = [
"python_api.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "tf_attrtype",
hdrs = ["tf_attrtype.h"],
@ -142,7 +154,10 @@ tf_cuda_library(
"c_api.h",
],
copts = tf_copts(),
visibility = ["//tensorflow/c:__subpackages__"],
visibility = [
"//tensorflow/c:__subpackages__",
"//third_party/llvm/llvm-project:__subpackages__",
],
deps = [
":c_api_internal",
":tf_attrtype",
@ -230,7 +245,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:android_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:framework",
@ -303,7 +318,6 @@ tf_cuda_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform",
"@com_google_absl//absl/strings",
@ -525,6 +539,7 @@ tf_cuda_cc_test(
"//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:math",
"//tensorflow/core/platform:resource_loader",
],
)
@ -537,6 +552,7 @@ tf_cc_test(
"//tensorflow:macos": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
tags = ["notsan"], # b/149031034
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
@ -635,6 +651,7 @@ tf_cuda_cc_test(
deps = [
":c_api",
":kernels",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -642,6 +659,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core/kernels:ops_testutil",
"//third_party/eigen3",
"@com_google_absl//absl/container:inlined_vector",
],
)
@ -683,4 +701,5 @@ tf_cuda_library(
# TODO(b/74620627): remove when _USE_C_SHAPES is removed
"//tensorflow/python:cpp_shape_inference_proto_cc",
],
alwayslink = 1,
)


@ -774,7 +774,7 @@ extern "C" {
static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
const char* op_type,
const char* oper_name)
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
return new TF_OperationDescription(graph, op_type, oper_name);
}
@ -1032,7 +1032,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
TF_Status* status)
EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
Node* ret = nullptr;
if (desc->graph->name_map.count(desc->node_builder.node_name())) {
@ -1706,7 +1706,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
const TF_ImportGraphDefOptions* opts,
TF_ImportGraphDefResults* tf_results,
TF_Status* status)
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
const int last_node_id = graph->graph.num_node_ids();
tensorflow::ImportGraphDefResults results;
status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,


@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
@ -32,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
@ -520,72 +520,6 @@ TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
return createTFEDequeue(ctx, TF_VARIANT, queue, status);
}
void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
auto* status = TF_NewStatus();
TF_Tensor* t = TFE_TensorHandleResolve(handle, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::Tensor dst;
TF_CHECK_OK(TF_TensorToTensor(t, &dst));
LOG(INFO) << dst.DebugString();
TF_DeleteTensor(t);
TF_DeleteStatus(status);
}
void TFE_OpPrintDebugString(TFE_Op* op) {
VLOG(1) << "TFE_OpPrintDebugString() over " << op;
LOG(INFO) << op->operation.DebugString();
}
struct TFE_ExecuteOpNotification {
TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {}
tensorflow::Notification n;
std::unique_ptr<tensorflow::Thread> thread;
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status;
};
TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
TFE_TensorHandle** retvals,
int* num_retvals,
TF_Status* status) {
TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;
n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread(
tensorflow::ThreadOptions(), "ExecuteOpThread",
[op, retvals, num_retvals, n]() {
TFE_Execute(op, retvals, num_retvals, n->status.get());
n->n.Notify();
}));
return n;
}
void TFE_ExecuteOpNotificationWaitAndDelete(
TFE_ExecuteOpNotification* notification, TF_Status* status) {
if (notification == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"Passed in notification is a nullptr.");
return;
}
if (notification->thread == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"Passed in notification didn't start a thread correctly. Cleaning up "
"this notification. Please re-execute the operation to get a new "
"notification.");
delete notification;
return;
}
notification->n.WaitForNotification();
status->status = notification->status->status;
delete notification;
}
void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
status->status = tensorflow::errors::Internal(errMsg);
}
@ -810,113 +744,6 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
status->status = EnableCollectiveOps(server_def, ctx);
}
void MakeTPUInitializationFunctionDef(
const tensorflow::string& tpu_system_device_name,
tensorflow::FunctionDef* function_def) {
tensorflow::OpDef* signature_def(function_def->mutable_signature());
signature_def->set_name("_eager_context_tpu_initialization");
signature_def->set_is_stateful(true);
signature_def->add_control_output("ConfigureDistributedTPU");
tensorflow::OpDef_ArgDef* arg_def(signature_def->add_output_arg());
arg_def->set_name("topology_proto");
arg_def->set_type(tensorflow::DataType::DT_STRING);
tensorflow::NodeDef* configure_node_def(function_def->add_node_def());
configure_node_def->set_name("ConfigureDistributedTPU");
configure_node_def->set_op("ConfigureDistributedTPU");
(*configure_node_def->mutable_attr())["compilation_failure_closes_chips"]
.set_b(false);
configure_node_def->set_device(tpu_system_device_name);
tensorflow::NodeDef* identity_node_def(function_def->add_node_def());
identity_node_def->set_name("Identity");
identity_node_def->set_op("Identity");
identity_node_def->add_input("ConfigureDistributedTPU:topology:0");
(*identity_node_def->mutable_attr())["T"].set_type(
tensorflow::DataType::DT_STRING);
(*function_def->mutable_ret())["topology_proto"] = "Identity:output:0";
(*function_def->mutable_control_ret())["ConfigureDistributedTPU"] =
"ConfigureDistributedTPU";
}
// NOTE(iga): ConfigureDistributedTPU is a dummy op whose sole purpose is to
// trigger DistributedTPURewritePass. This pass actually adds real ops that
// initialize the TPU system. Thus, we can't simply run ConfigureDistributedTPU
// eagerly. We need to wrap it in a function and trigger the rewrite passes on
// it. The easiest way to trigger a rewrite is to run it in a function.
// Running initialization as an operation rather than calling the underlying C++
// implementation directly allows us to run initialization on a remote device
// without a separate communication channel.
TF_CAPI_EXPORT extern void TFE_InitializeTPUSystem(TFE_Context* ctx,
const char* job,
TF_Buffer* tpu_topology,
TF_Status* status) {
if (tpu_topology->data != nullptr) {
status->status = InvalidArgument("Passing non-empty TF_Buffer is invalid.");
return;
}
tensorflow::string tpu_system_device_name = tensorflow::strings::StrCat(
"/job:", job, "/replica:0/task:0/device:TPU_SYSTEM:0");
tensorflow::Device* tpu_system_device = nullptr;
tensorflow::Status lookup_status = ctx->context->FindDeviceFromName(
tpu_system_device_name.c_str(), &tpu_system_device);
if (!lookup_status.ok() || tpu_system_device == nullptr) {
// There are no TPUs to initialize.
status->status = tensorflow::errors::NotFound(tensorflow::strings::StrCat(
"No TPUs are associated with the specified job '", job, "'"));
return;
}
tensorflow::FunctionDef function_def;
MakeTPUInitializationFunctionDef(tpu_system_device->name().c_str(),
&function_def);
tensorflow::string function_name = function_def.signature().name();
status->status = ctx->context->AddFunctionDef(function_def);
if (!status->status.ok()) return;
// Run the function, which may be a remote call. It returns a serialized
// topology proto.
const tensorflow::AttrTypeMap* attr_map;
bool is_function;
status->status = tensorflow::AttrTypeMapForOp(function_name.c_str(),
&attr_map, &is_function);
if (!status->status.ok()) return;
tensorflow::EagerOperation call_op(ctx->context, function_name.c_str(),
is_function, attr_map);
status->status = call_op.SetDeviceName(tpu_system_device_name.c_str());
if (!status->status.ok()) return;
tensorflow::TensorHandle* remote_topology_handle;
int num_retvals = 1;
status->status =
tensorflow::EagerExecute(&call_op, &remote_topology_handle, &num_retvals);
if (!status->status.ok()) return;
tensorflow::TensorHandle* local_topology_handle = nullptr;
status->status = tensorflow::EagerCopyToDevice(
remote_topology_handle, ctx->context, &ctx->context->Executor(),
ctx->context->HostCPU(), false, &local_topology_handle);
remote_topology_handle->Unref();
if (!status->status.ok()) return;
const tensorflow::Tensor* topology_proto_tensor;
status->status = local_topology_handle->Tensor(&topology_proto_tensor);
if (!status->status.ok()) return;
status->status = ctx->context->RemoveFunction(function_name);
if (!status->status.ok()) return;
// The function ran, so we put the result in the return buffer.
tensorflow::string result =
topology_proto_tensor->flat<tensorflow::tstring>()(0);
local_topology_handle->Unref();
void* topology_data = tensorflow::port::Malloc(result.size());
tpu_topology->data = topology_data;
if (tpu_topology->data == nullptr) {
status->status = tensorflow::errors::ResourceExhausted(
"Failed to allocate memory for topology proto (", result.size(),
" bytes)");
}
memcpy(topology_data, result.c_str(), result.size());
tpu_topology->length = result.size();
tpu_topology->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
status->status = tensorflow::Status::OK();
}
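For context, TF 2.x exposes a Python-level counterpart to what this deleted experimental C function did. A hedged sketch using public APIs (the `tpu=''` argument, meaning a locally or environment-configured TPU, is illustrative):

```python
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
# Wraps the TPU initialization op in a function and runs it, possibly on a
# remote worker, much like the removed C function did.
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
```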
TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
result->num_items = num_items;
@ -990,12 +817,15 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
const int num_inputs = input_shapes->num_items;
NodeDef node_def;
node_def.set_name(tfe_op->operation.Name());
node_def.set_op(tfe_op->operation.Name());
node_def.set_name(tfe_op->operation->Name());
node_def.set_op(tfe_op->operation->Name());
for (int i = 0; i < num_inputs; ++i) {
node_def.add_input("dummy_input");
}
tfe_op->operation.Attrs().FillAttrValueMap(node_def.mutable_attr());
tensorflow::down_cast<tensorflow::OperationInterface*>(
tfe_op->operation.get())
->Attrs()
.FillAttrValueMap(node_def.mutable_attr());
const tensorflow::OpRegistrationData* op_reg_data;
status->status =


@ -188,31 +188,6 @@ TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session,
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
TF_Session* session, int tensor_id, TF_Status* status);
// Prints `handle` in a human readable format to standard output for debugging.
TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString(
TFE_TensorHandle* handle);
TF_CAPI_EXPORT extern void TFE_OpPrintDebugString(TFE_Op* op);
typedef struct TFE_ExecuteOpNotification TFE_ExecuteOpNotification;
// Allows invoking a kernel asynchronously, and explicitly returns a
// notification that can be waited upon. This always executes the kernel in a
// new thread.
// 1. `retvals` and `num_retvals` can only be consumed after
// `TFE_ExecuteOp` returns successfully. They shouldn't be used
// if the return is unsuccessful
// 2. These new APIs cannot be used together with the TFE context level async
// support.
TF_CAPI_EXPORT extern TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(
TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status);
// Waits to complete the op execution, and cleans up the notification.
// Errors reported by op execution are set in `status`.
TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete(
TFE_ExecuteOpNotification* notification, TF_Status* status);
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
const char* errMsg);
@ -297,20 +272,6 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
size_t proto_len,
TF_Status* status);
// Runs operations necessary to initialize TPU devices associated with `job`
// (e.g. "localhost" for local TPUs), returning a serialized TopologyProto (same
// result as the "ConfigureDistributedTPU" operation) if TPUs were
// available. Sets a NotFound status if no TPUs were found associated with
// the job specified.
//
// TFE_InitializeTPUSystem should only be run once for a given TPU system;
// running it multiple times will invalidate tensors/variables placed on the
// affected TPUs.
TF_CAPI_EXPORT extern void TFE_InitializeTPUSystem(TFE_Context* ctx,
const char* job,
TF_Buffer* tpu_topology,
TF_Status* status);
// Information about the shape of a Tensor and its type.
struct TF_ShapeAndType {
// Number of dimensions. -1 indicates unknown rank.


@ -73,21 +73,6 @@ protocol: "grpc"
TF_DeleteStatus(status);
}
TEST(CAPI_EXPERIMENTAL, InitializeTPUSystemTest) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TF_Buffer* buf = TF_NewBuffer();
TFE_InitializeTPUSystem(ctx, "localhost", buf, status);
// Note that this assumes TPUs are not available for this test.
CHECK_EQ(TF_NOT_FOUND, TF_GetCode(status)) << TF_Message(status);
TF_DeleteBuffer(buf);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
}
TEST(CAPI_EXPERIMENTAL, IsStateful) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -99,127 +84,6 @@ TEST(CAPI_EXPERIMENTAL, IsStateful) {
EXPECT_EQ(id, 0);
}
TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Simple) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* matmul_op = MatMulOp(ctx, m, m);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
auto* r =
TFE_ExecuteOpInNewThread(matmul_op, &retvals[0], &num_retvals, status);
TFE_ExecuteOpNotificationWaitAndDelete(r, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteOp(matmul_op);
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
// Perform a send/recv test. Recv blocks, so they need to be executed
// asynchronously.
TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Blocking) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
// Returns a 2x2 float32 Tensor on the CPU, with data 1., 2., 3., 4.
TFE_TensorHandle* m = TestMatrixTensorHandle();
// Build a send op.
TFE_Op* send_op = TFE_NewOp(ctx, "_Send", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(send_op, m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
string tensor_name = "Tensor";
TFE_OpSetAttrType(send_op, "T", TF_FLOAT);
TFE_OpSetAttrString(send_op, "tensor_name", tensor_name.c_str(),
tensor_name.size());
string send_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_OpSetAttrString(send_op, "send_device", send_device.c_str(),
send_device.size());
TFE_OpSetAttrInt(send_op, "send_device_incarnation", 1234);
string recv_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_OpSetAttrString(send_op, "recv_device", recv_device.c_str(),
recv_device.size());
TFE_OpSetAttrBool(send_op, "client_terminated", true);
// Build a recv op.
TFE_Op* recv_op = TFE_NewOp(ctx, "_Recv", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(recv_op, "tensor_type", TF_FLOAT);
TFE_OpSetAttrString(recv_op, "tensor_name", tensor_name.c_str(),
tensor_name.size());
TFE_OpSetAttrString(recv_op, "send_device", send_device.c_str(),
send_device.size());
TFE_OpSetAttrInt(recv_op, "send_device_incarnation", 1234);
TFE_OpSetAttrString(recv_op, "recv_device", recv_device.c_str(),
recv_device.size());
TFE_OpSetAttrBool(recv_op, "client_terminated", true);
TFE_TensorHandle* send_retvals;
int send_num_retvals = 0;
auto* send_result = TFE_ExecuteOpInNewThread(send_op, &send_retvals,
&send_num_retvals, status);
TFE_TensorHandle* recv_retvals[1] = {nullptr};
int recv_num_retvals = 1;
auto* recv_result = TFE_ExecuteOpInNewThread(recv_op, &recv_retvals[0],
&recv_num_retvals, status);
TFE_ExecuteOpNotificationWaitAndDelete(send_result, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_ExecuteOpNotificationWaitAndDelete(recv_result, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(recv_retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(1, product[0]);
EXPECT_EQ(2, product[1]);
EXPECT_EQ(3, product[2]);
EXPECT_EQ(4, product[3]);
TFE_DeleteOp(send_op);
TFE_DeleteOp(recv_op);
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(recv_retvals[0]);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
class ShapeInferenceTest : public ::testing::Test {
protected:
ShapeInferenceTest()


@ -51,7 +51,7 @@ Status ProcessInputs(
const TF_Graph* fn_body, const char* fn_name, int ninputs,
const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
std::unordered_map<const Node*, std::vector<int>>* input_nodes)
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
input_tensors->reserve(ninputs);
for (int i = 0; i < ninputs; ++i) {
Node* node = &inputs[i].oper->node;
@ -87,7 +87,7 @@ Status ProcessInputs(
Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
int noutputs, const TF_Output* outputs,
std::vector<OutputTensor>* output_tensors)
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
output_tensors->reserve(noutputs);
for (int i = 0; i < noutputs; ++i) {
Node* node = &outputs[i].oper->node;
@ -111,7 +111,7 @@ Status ComputeBodyNodes(
const TF_Operation* const* opers,
const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
std::vector<const Node*>* body_nodes)
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
if (num_opers == -1) {
for (const Node* node : fn_body->graph.op_nodes()) {
const auto& iter = input_nodes.find(node);

View File

@ -71,14 +71,14 @@ struct TF_Graph {
TF_Graph();
tensorflow::mutex mu;
tensorflow::Graph graph GUARDED_BY(mu);
tensorflow::Graph graph TF_GUARDED_BY(mu);
// Runs shape inference.
tensorflow::ShapeRefiner refiner GUARDED_BY(mu);
tensorflow::ShapeRefiner refiner TF_GUARDED_BY(mu);
// Maps from name of an operation to the Node* in 'graph'.
std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
GUARDED_BY(mu);
TF_GUARDED_BY(mu);
// The keys of this map are all the active sessions using this graph. Each
// value records whether the graph has been mutated since the corresponding
@ -94,8 +94,8 @@ struct TF_Graph {
// TODO(b/74949947): mutations currently trigger a warning instead of a bad
// status, this should be reverted when possible.
tensorflow::gtl::FlatMap<TF_Session*, tensorflow::string> sessions
GUARDED_BY(mu);
bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph
TF_GUARDED_BY(mu);
bool delete_requested TF_GUARDED_BY(mu); // set true by TF_DeleteGraph
// Used to link graphs contained in TF_WhileParams to the parent graph that
// will eventually contain the full while loop.
@ -123,7 +123,7 @@ struct TF_Session {
tensorflow::Session* session;
TF_Graph* const graph;
tensorflow::mutex mu ACQUIRED_AFTER(TF_Graph::mu);
tensorflow::mutex mu TF_ACQUIRED_AFTER(TF_Graph::mu);
int last_num_graph_nodes;
// If true, TF_SessionRun and similar methods will call
@ -169,9 +169,9 @@ struct TF_ApiDefMap {
}
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
tensorflow::ApiDefMap api_def_map GUARDED_BY(lock);
tensorflow::ApiDefMap api_def_map TF_GUARDED_BY(lock);
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
bool update_docs_called GUARDED_BY(lock);
bool update_docs_called TF_GUARDED_BY(lock);
tensorflow::mutex lock;
};
@ -210,10 +210,10 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
void RecordMutation(TF_Graph* graph, const TF_Operation& op,
const char* mutation_type)
EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status)
LOCKS_EXCLUDED(session->graph->mu, session->mu);
TF_LOCKS_EXCLUDED(session->graph->mu, session->mu);
std::string getTF_OutputDebugString(TF_Output node);
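// A minimal sketch of how the TF_-prefixed thread-safety annotations adopted
// above are used; the struct and function names are illustrative only. With
// Clang's -Wthread-safety, misuse of `mu` becomes a compile-time warning.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

struct ExampleState {
  tensorflow::mutex mu;
  int counter TF_GUARDED_BY(mu) = 0;  // may only be touched while `mu` is held
};

// Caller must already hold `state->mu`.
void IncrementLocked(ExampleState* state)
    TF_EXCLUSIVE_LOCKS_REQUIRED(state->mu) {
  state->counter++;
}

// Acquires the lock itself, so callers must not already hold it.
void Increment(ExampleState* state) TF_LOCKS_EXCLUDED(state->mu) {
  tensorflow::mutex_lock l(state->mu);
  IncrementLocked(state);
}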

View File

@ -45,6 +45,8 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
@ -193,8 +195,9 @@ TEST(CAPI, LibraryLoadFunctions) {
{
// Load the library.
TF_Status* status = TF_NewStatus();
TF_Library* lib =
TF_LoadLibrary("tensorflow/c/test_op1.so", status);
string lib_path = tensorflow::GetDataDependencyFilepath(
tensorflow::io::JoinPath("tensorflow", "c", "test_op1.so"));
TF_Library* lib = TF_LoadLibrary(lib_path.c_str(), status);
TF_Code code = TF_GetCode(status);
string status_msg(TF_Message(status));
TF_DeleteStatus(status);
@ -1350,9 +1353,9 @@ TEST_F(CApiColocationTest, ClearViaProto) {
TEST(CAPI, SavedModel) {
// Load the saved model.
const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
const string saved_model_dir = tensorflow::io::JoinPath(
tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"half_plus_two", "00000123"));
TF_SessionOptions* opt = TF_NewSessionOptions();
TF_Buffer* run_options = TF_NewBufferFromString("", 0);
TF_Buffer* metagraph = TF_NewBuffer();
@ -1426,9 +1429,9 @@ TEST(CAPI, SavedModel) {
}
TEST(CAPI, SavedModelNullArgsAreValid) {
const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
const string saved_model_dir = tensorflow::io::JoinPath(
tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"half_plus_two", "00000123"));
TF_SessionOptions* opt = TF_NewSessionOptions();
TF_Status* s = TF_NewStatus();
const char* tags[] = {tensorflow::kSavedModelTagServe};
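// A minimal sketch of the data-dependency lookup pattern these hunks switch
// to: test files are resolved relative to the Bazel runfiles tree instead of
// a hard-coded source root. The path components below are illustrative.
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/resource_loader.h"

std::string ExampleTestDataPath() {
  return tensorflow::GetDataDependencyFilepath(
      tensorflow::io::JoinPath("tensorflow", "c", "testdata", "example.pb"));
}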

View File

@ -2,6 +2,7 @@
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
"tf_copts",
"tf_cuda_cc_test",
"tf_cuda_library",
@ -26,8 +27,9 @@ tf_cuda_library(
"c_api.cc",
"c_api_debug.cc",
"c_api_experimental.h",
"c_api_internal.cc",
"c_api_internal.h",
"operation_interface.cc",
"operation_interface.h",
"tensor_handle_interface.h",
],
hdrs = ["c_api.h"],
@ -56,6 +58,7 @@ tf_cuda_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:errors",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:traceme",
@ -82,18 +85,18 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/profiler/lib:profiler_lib",
"//tensorflow/core/profiler/lib:profiler_session",
"//tensorflow/core:gpu_runtime",
],
alwayslink = 1,
)
filegroup(
name = "pywrap_eager_hdrs",
name = "pywrap_required_hdrs",
srcs = [
"c_api_experimental.h",
"c_api_internal.h",
"dlpack.h",
"operation_interface.h",
"tensor_handle_interface.h",
],
visibility = [
@ -106,6 +109,7 @@ tf_cuda_library(
name = "c_api_internal",
srcs = [
"c_api_experimental.h",
"operation_interface.h",
"tensor_handle_interface.h",
],
hdrs = ["c_api_internal.h"],
@ -128,22 +132,9 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
"//tensorflow/core/profiler/lib:profiler_lib",
"//tensorflow/core/profiler/lib:profiler_session",
"@com_google_absl//absl/container:fixed_array",
],
)
@ -215,6 +206,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform:casts",
"@com_google_absl//absl/strings",
],
)
@ -272,8 +264,6 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/profiler/rpc:profiler_server",
"//tensorflow/core/profiler/rpc/client:capture_profile",
"//tensorflow/core:gpu_runtime",
],
alwayslink = 1,
@ -303,6 +293,27 @@ tf_cuda_cc_test(
],
)
tf_cc_test(
name = "custom_device_test",
size = "small",
srcs = [
"custom_device_test.cc",
],
deps = [
":c_api",
":c_api_experimental",
":c_api_test_util",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "tape",
hdrs = ["tape.h"],
@ -315,10 +326,37 @@ cc_library(
filegroup(
name = "headers",
srcs = ["c_api.h"],
srcs = [
"c_api.h",
"c_api_experimental.h",
"dlpack.h",
],
visibility = ["//tensorflow:__subpackages__"],
)
cc_library(
name = "dlpack",
srcs = ["dlpack.cc"],
hdrs = ["dlpack.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
visibility = ["//tensorflow:__subpackages__"],
deps = [
":c_api",
":c_api_experimental",
":c_api_internal",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@dlpack",
],
alwayslink = 1,
)
# TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime
# right now, remove this public rule when no longer needed (it should be
# replaced by TF Lite)
@ -332,6 +370,7 @@ filegroup(
exclude = [
"c_api_experimental.cc",
"*test*",
"*dlpack*",
],
),
visibility = ["//visibility:public"],

File diff suppressed because it is too large

View File

@ -213,7 +213,7 @@ TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
TFE_TensorDebugInfo* debug_info);
// Returns the number of dimensions used to represent the tensor on its device.
// The number of dimensions used to reprensent the tensor on device can be
// The number of dimensions used to represent the tensor on device can be
// different from the number returned by TFE_TensorHandleNumDims.
// The return value was current at the time of TFE_TensorDebugInfo creation.
TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(

View File

@ -66,7 +66,7 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
}
#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Device* device = handle_->device();
tensorflow::Device* device = absl::get<Device*>(handle_->device());
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
tensorflow::XlaDevice* xla_device =

View File

@ -18,62 +18,27 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
#include "tensorflow/core/profiler/rpc/profiler_server.h"
using tensorflow::string;
void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status,
TFE_Op* op_to_reset) {
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) {
if (op_to_reset) {
NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
op_to_reset);
status->status =
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
} else {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"op_to_reset should not be nullptr");
}
}
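// A minimal sketch of reusing one TFE_Op across calls with the new
// TFE_OpReset signature; `ctx`, `a`, and `b` are assumed to be set up by the
// caller and the names are illustrative. Resetting clears the op's inputs and
// attributes without reallocating it. Passing NULL for `raw_device_name`
// leaves the device unset, as documented.
void RunMatMulTwice(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b,
                    TF_Status* status) {
  TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
  if (TF_GetCode(status) != TF_OK) return;
  for (int i = 0; i < 2; ++i) {
    // Redundant on the first iteration, but harmless.
    TFE_OpReset(op, "MatMul", /*raw_device_name=*/nullptr, status);
    if (TF_GetCode(status) != TF_OK) break;
    TFE_OpAddInput(op, a, status);
    TFE_OpAddInput(op, b, status);
    TFE_TensorHandle* retval = nullptr;
    int num_retvals = 1;
    TFE_Execute(op, &retval, &num_retvals, status);
    if (TF_GetCode(status) == TF_OK) TFE_DeleteTensorHandle(retval);
  }
  TFE_DeleteOp(op);
}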
void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
op->operation.ConsumeInput(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle());
}
TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); }
bool TFE_ProfilerIsOk(TFE_Profiler* profiler) {
return profiler->profiler->Status().ok();
}
void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; }
void TFE_ProfilerSerializeToString(TFE_Profiler* profiler, TF_Buffer* buf,
TF_Status* status) {
string content;
status->status = profiler->profiler->SerializeToString(&content);
void* data = tensorflow::port::Malloc(content.length());
content.copy(static_cast<char*>(data), content.length(), 0);
buf->data = data;
buf->length = content.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
}
void TFE_StartProfilerServer(int port) {
// Release child thread intentionally. The child thread can be terminated by
// terminating the main thread.
tensorflow::StartProfilerServer(port).release();
}
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(true);
}
@ -82,46 +47,6 @@ void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(false);
}
bool TFE_ProfilerClientStartTracing(const char* service_addr,
const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms,
int num_tracing_attempts,
TF_Status* status) {
tensorflow::Status s =
tensorflow::profiler::client::ValidateHostPortPair(service_addr);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return false;
}
s = tensorflow::profiler::client::StartTracing(
service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
num_tracing_attempts);
tensorflow::Set_TF_Status_from_Status(status, s);
return s.ok();
}
void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
int monitoring_level, bool display_timestamp,
TF_Buffer* result, TF_Status* status) {
tensorflow::Status s =
tensorflow::profiler::client::ValidateHostPortPair(service_addr);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return;
}
string content;
s = tensorflow::profiler::client::Monitor(
service_addr, duration_ms, monitoring_level, display_timestamp, &content);
void* data = tensorflow::port::Malloc(content.length());
content.copy(static_cast<char*>(data), content.length(), 0);
result->data = data;
result->length = content.length();
result->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
tensorflow::Set_TF_Status_from_Status(status, s);
}
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
int64_t value) {
cell->cell.IncrementBy(value);
@ -589,8 +514,7 @@ void TFE_DeleteCancellationManager(
void TFE_OpSetCancellationManager(TFE_Op* op,
TFE_CancellationManager* cancellation_manager,
TF_Status* status) {
op->operation.SetCancellationManager(
&cancellation_manager->cancellation_manager);
status->status = op->operation->SetCancellationManager(cancellation_manager);
}
TFE_Executor* TFE_NewExecutor(bool is_async) {
@ -619,3 +543,41 @@ void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
return new TFE_Executor(&ctx->context->Executor());
}
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
ctx->context->HostCPU()->parsed_name());
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0);
buf->data = data;
buf->length = str.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
}
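// A minimal sketch of consuming the buffer filled in above; the function name
// is illustrative. Deleting the buffer runs the installed deallocator.
std::string HostAddressSpaceExample(TFE_Context* ctx) {
  TF_Buffer* buf = TF_NewBuffer();
  TFE_HostAddressSpace(ctx, buf);
  // The buffer holds text such as "/job:localhost/replica:0/task:0".
  std::string result(static_cast<const char*>(buf->data), buf->length);
  TF_DeleteBuffer(buf);
  return result;
}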
void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
TF_Status* status) {
h->handle->EnableImplicitMirroring();
status->status = tensorflow::Status::OK();
}
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
TF_Buffer* buf, TF_Status* status) {
auto* function_def = ctx->context->FindFunctionDef(function_name);
if (function_def == nullptr) {
status->status = tensorflow::errors::NotFound(
"Unable to find FunctionDef with name: ", function_name);
return;
}
string str = function_def->SerializeAsString();
void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0);
buf->data = data;
buf->length = str.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
status->status = tensorflow::Status::OK();
}
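// A minimal sketch of retrieving and parsing a FunctionDef via the call
// above, assuming a function named "f" has already been registered with
// `ctx` and that tensorflow/core/framework/function.pb.h is available.
bool GetFunctionDefExample(TFE_Context* ctx, TF_Status* status) {
  TF_Buffer* buf = TF_NewBuffer();
  TFE_ContextGetFunctionDef(ctx, "f", buf, status);
  bool ok = TF_GetCode(status) == TF_OK;
  if (ok) {
    tensorflow::FunctionDef fdef;
    ok = fdef.ParseFromArray(buf->data, buf->length);
  }
  TF_DeleteBuffer(buf);
  return ok;
}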

View File

@ -27,41 +27,12 @@ extern "C" {
// creating a new op every time. If `raw_device_name` is `NULL` or empty, it
// does not set the device name. If it's not `NULL`, then it attempts to parse
// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
// than seperately calling it because if the existing op has the same
// than separately calling it because if the existing op has the same
// `raw_device_name`, it skips parsing and just leaves it as is.
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Context* ctx,
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
const char* op_or_function_name,
const char* raw_device_name,
TF_Status* status, TFE_Op* op_to_reset);
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
TF_Status* status);
// A profiler which will start profiling when creating the object and will stop
// when the object is destroyed. It will profile all operations run under the
// given TFE_Context. Multiple instance of it can be created, but at most one
// of them will profile for each TFE_Context.
// Thread-safety: TFE_Profiler is thread-safe.
typedef struct TFE_Profiler TFE_Profiler;
TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler();
TF_CAPI_EXPORT extern bool TFE_ProfilerIsOk(TFE_Profiler* profiler);
TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler);
// The output string is a binary string of tensorflow.tpu.Trace. User can write
// the string to file for offline analysis by tensorboard.
TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Profiler* profiler,
TF_Buffer* buf,
TF_Status* status);
// Start a profiler grpc server which listens to specified port. It will start
// the server on its own thread. It can be shutdown by terminating tensorflow.
// It can be used in both Eager mode and graph mode. Creating multiple profiler
// server is allowed. The service defined in
// tensorflow/contrib/tpu/profiler/tpu_profiler.proto. Please use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture trace file
// following https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern void TFE_StartProfilerServer(int port);
TF_Status* status);
// Enables only graph collection in RunMetadata on the functions executed from
// this context.
@ -71,29 +42,6 @@ TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
// this context.
TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx);
// Send a grpc request to profiler server (service_addr) to perform on-demand
// profiling and save the result into logdir which can be visualized by
// TensorBoard. worker_list is the list of worker TPUs separated by ','. Set
// include_dataset_opts to false to profile longer traces. It will block the
// caller thread until receives tracing result.
// This API is designed for TensorBoard, for end user, please use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead following
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
const char* service_addr, const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms, int num_tracing_attempts,
TF_Status* status);
// Send a grpc request to profiler server (service_addr) to perform on-demand
// monitoring and return the result in a string. It will block the
// caller thread until receiving the monitoring result.
// This API is designed for TensorBoard, for end user, please use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead following
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern void TFE_ProfilerClientMonitor(
const char* service_addr, int duration_ms, int monitoring_level,
bool display_timestamp, TF_Buffer* result, TF_Status* status);
// TODO(fishx): Move these monitoring APIs into a separate file.
// -----------------------------------------------------------------------------
// Monitoring Counter APIs.
@ -434,6 +382,18 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
const char* worker_name,
TF_Status* status);
// Sync pending nodes in local executors (including the context default executor
// and thread executors) and streaming requests to remote executors, and get the
// combined status.
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
TF_Status* status);
// If the TensorHandle is copied to another device as part of an op execution,
// the copy is destroyed after the op has executed. Enabling implicit mirroring
// causes the copy to be held as a mirror for the lifetime of the TensorHandle.
TF_CAPI_EXPORT extern void TFE_TensorHandleEnableImplicitMirroring(
TFE_TensorHandle*, TF_Status*);
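// A minimal sketch mirroring how the remote-execution tests use this call:
// copy a handle to another device, then mark the copy so later executions
// reuse transfers of it. `device_name` is illustrative; error handling is
// elided.
TFE_TensorHandle* CopyAndMirror(TFE_Context* ctx, TFE_TensorHandle* h,
                                const char* device_name, TF_Status* status) {
  TFE_TensorHandle* copy =
      TFE_TensorHandleCopyToDevice(h, ctx, device_name, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  // Implicit copies of `copy` made during later op executions are now kept
  // as mirrors instead of being destroyed after each op.
  TFE_TensorHandleEnableImplicitMirroring(copy, status);
  return copy;
}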
// This function will block till the operation that produces `h` has
// completed. This is only valid on local TFE_TensorHandles. The pointer
// returned will be on the device in which the TFE_TensorHandle resides (so e.g.
@ -458,6 +418,114 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg, TF_Status* status);
// Retrieves the address space (i.e. job, replica, task) of the local host and
// saves it in the buffer.
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
TF_Buffer* buf);
// APIs for generically dealing with op attributes (e.g. when forwarding them
// through custom device implementations).
//
// TODO(allenl): Currently these are black boxes, but we should have some way to
// inspect values. This would let people e.g. copy over most attributes and then
// modify some based on their values.
// A reference to an op's name -> attribute mapping
typedef struct TFE_OpAttrs TFE_OpAttrs;
// Fetch a struct with a reference to information about attributes of `op`.
//
// The `attrs` struct does not own any memory, and `op` must outlive it.
TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
// Add attributes in `attrs` to `op`.
//
// Does not overwrite or update existing attributes, but adds new ones.
TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs);
// Serialize `attrs` as a tensorflow::NameAttrList protocol buffer (into `buf`),
// containing the op name and a map of its attributes.
TF_CAPI_EXPORT extern void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs,
TF_Buffer* buf,
TF_Status* status);
// Set an op's attribute from a serialized AttrValue protocol buffer.
//
// Analogous to TF_SetAttrValueProto for building graph operations.
TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
const char* attr_name,
const void* proto,
size_t proto_len,
TF_Status* status);
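// A minimal sketch tying the four attribute calls above together: capture one
// op's attributes, replay them onto another op, and serialize them. It
// assumes the TFE_OpAttrs definition is visible (at this point it lives in
// c_api_internal.h) and that `op` is fully configured; "Identity" is just an
// illustrative op name.
void CopyAndSerializeAttrs(TFE_Context* ctx, TFE_Op* op, TF_Status* status) {
  TFE_OpAttrs attrs;
  TFE_OpGetAttrs(op, &attrs);  // borrows from `op`; `op` must outlive `attrs`
  TFE_Op* clone = TFE_NewOp(ctx, "Identity", status);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpAddAttrs(clone, &attrs);  // adds attributes, never overwrites
  TF_Buffer* buf = TF_NewBuffer();
  // Serializes as a tensorflow::NameAttrList protocol buffer.
  TFE_OpAttrsSerialize(&attrs, buf, status);
  TF_DeleteBuffer(buf);
  TFE_DeleteOp(clone);
}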
#define TFE_CUSTOM_DEVICE_VERSION 2
// Struct to be filled in
typedef struct TFE_CustomDevice {
int version = TFE_CUSTOM_DEVICE_VERSION;
// Method to copy a tensor to the custom device.
TFE_TensorHandle* (*copy_tensor_to_device)(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status,
void* device_info) = nullptr;
// Method to copy a tensor from the custom device to a target device.
TFE_TensorHandle* (*copy_tensor_from_device)(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info);
// Method to execute an operation.
void (*execute)(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
// Method to delete a device.
void (*delete_device)(void* device_info);
} TFE_CustomDevice;
// Registers a custom device for use with eager execution.
//
// Eager operations may be placed on this device, e.g. `with
// tf.device("CUSTOM"):` from Python if `device_name` for this call is
// "/job:localhost/replica:0/task:0/device:CUSTOM:0".
//
// The custom device defines copy operations for moving TensorHandles on and
// off, and an execution operation for named operations. Often execution will
// simply wrap op execution on one or more physical devices.
//
// device_info is an opaque caller-defined type stored with the custom device
// which is passed to the functions referenced in the TFE_CustomDevice struct
// `device` (execute, delete_device, etc.). It can for example contain the
// names of wrapped devices.
//
// There are currently no graph semantics implemented for registered custom
// devices, so executing tf.functions which contain operations placed on custom
// devices will fail.
//
// `device_name` must not name an existing physical or custom device. It must
// follow the format:
//
// /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num>
//
// If the device is successfully registered, `status` is set to TF_OK. Otherwise
// the device is not usable. In case of a bad status, `device.delete_device` is
// still called on `device_info` (i.e. the caller does not retain ownership).
//
// This API is highly experimental, and in particular is expected to change when
// it starts supporting operations with attributes and when tf.function support
// is added.
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
const char* function_name,
TF_Buffer* buf,
TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/trace_events.pb.h"
using tensorflow::string;
@ -39,88 +38,6 @@ static bool HasSubstr(absl::string_view base, absl::string_view substr) {
return ok;
}
void ExecuteWithProfiling(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
TFE_Profiler* profiler = TFE_NewProfiler();
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
// Run op on GPU if it is present.
string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
}
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Buffer* profiler_result = TF_NewBuffer();
if (async) {
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
}
TFE_ProfilerSerializeToString(profiler, profiler_result, status);
TFE_DeleteProfiler(profiler);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
profiler::Trace profile_proto;
EXPECT_TRUE(profile_proto.ParseFromString(
{reinterpret_cast<const char*>(profiler_result->data),
profiler_result->length}));
string profile_proto_str = profile_proto.DebugString();
#ifndef TENSORFLOW_USE_ROCM
// TODO(rocm): enable once GPU profiling is supported in ROCm mode
if (!gpu_device_name.empty()) {
EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0"));
}
#endif
// "/host:CPU" is collected by TraceMe
EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU"));
EXPECT_TRUE(HasSubstr(profile_proto_str, "MatMul"));
TF_DeleteBuffer(profiler_result);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TF_DeleteStatus(status);
}
TEST(CAPI, ExecuteWithTracing) { ExecuteWithProfiling(false); }
TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithProfiling(true); }
TEST(CAPI, MultipleProfilerSession) {
TFE_Profiler* profiler1 = TFE_NewProfiler();
EXPECT_TRUE(TFE_ProfilerIsOk(profiler1));
TFE_Profiler* profiler2 = TFE_NewProfiler();
EXPECT_FALSE(TFE_ProfilerIsOk(profiler2));
TFE_DeleteProfiler(profiler1);
TFE_DeleteProfiler(profiler2);
}
TEST(CAPI, MonitoringCounter0) {
TF_Status* status = TF_NewStatus();
auto* counter =

View File

@ -1,66 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/host_info.h"
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status,
TFE_Op* op_to_reset) {
const char* name = op_or_function_name; // Shorthand
const tensorflow::AttrTypeMap* types;
bool is_function = false;
status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
if (!status->status.ok()) {
return nullptr;
}
if (op_to_reset && op_to_reset->ctx != ctx) {
status->status = tensorflow::errors::Internal(
"Cannot reset a TFE_Op from another TFE_Context");
return nullptr;
}
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
if (!is_function) {
const tensorflow::OpDef* op_def;
status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
if (!status->status.ok()) {
return nullptr;
}
inference_ctx.reset(new TFE_OpInferenceContext(op_def));
} else if (!ctx->context->FindFunctionByName(name)) {
status->status = tensorflow::errors::NotFound(
"'", name,
"' is neither a type of a primitive operation nor a name "
"of a function registered in binary running on ",
tensorflow::port::Hostname(),
". Make sure the operation or function is "
"registered in the binary running in this process.");
return nullptr;
}
if (op_to_reset) {
status->status = op_to_reset->Reset(
name, is_function, types, raw_device_name, std::move(inference_ctx));
return op_to_reset;
}
TFE_Op* new_op =
new TFE_Op(ctx, name, is_function, types, std::move(inference_ctx));
status->status = new_op->operation.SetDeviceName(raw_device_name);
return new_op;
}

View File

@ -27,12 +27,12 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
@ -48,7 +48,6 @@ limitations under the License.
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
#include "tensorflow/core/public/version.h"
struct TFE_ContextOptions {
@ -89,53 +88,8 @@ struct TFE_TensorDebugInfo {
std::vector<tensorflow::int64> dev_dims;
};
struct TFE_OpInferenceContext {
explicit TFE_OpInferenceContext(const tensorflow::OpDef* op_def)
: op_def(op_def) {}
const tensorflow::OpDef* op_def; // op definition from protobuf
int input_arg_idx = 0; // arg definition index for the next input to be added
tensorflow::gtl::FlatSet<std::string> attrs; // attributes inferred so far
};
struct TFE_Op {
TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
std::unique_ptr<TFE_OpInferenceContext> inference_ctx)
: ctx(ctx),
operation(ctx->context, op, is_function, t),
inference_ctx(std::move(inference_ctx)) {}
void Clear() {
operation.Clear();
inference_ctx.reset();
}
tensorflow::Status Reset(const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
const char* raw_device_name,
std::unique_ptr<TFE_OpInferenceContext> infer_ctx) {
inference_ctx = std::move(infer_ctx);
return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
nullptr);
}
void AddInput(TFE_TensorHandle* input, TF_Status* status);
void Execute(TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status);
TFE_Context* ctx;
tensorflow::EagerOperation operation;
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
};
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status,
TFE_Op* op_to_reset = nullptr);
struct TFE_Profiler {
explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }
std::unique_ptr<tensorflow::ProfilerSession> profiler;
std::unique_ptr<AbstractOperationInterface> operation;
};
struct TFE_MonitoringCounterCell {
@ -282,4 +236,17 @@ struct TFE_Executor {
tensorflow::EagerExecutor* unowned_executor;
};
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
// that sometimes do not require serialization.
struct TFE_OpAttrs {
explicit TFE_OpAttrs() : name(nullptr), attributes(nullptr) {}
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value,
const char* op_name)
: name(op_name), attributes(value) {}
const char* name;
const tensorflow::AttrBuilder* attributes;
};
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
@ -127,7 +128,7 @@ void TestRemoteExecute(bool async) {
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
void TestRemoteExecuteSilentCopies(bool async) {
void TestRemoteExecuteSilentCopies(bool async, bool remote) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
@ -166,10 +167,14 @@ void TestRemoteExecuteSilentCopies(bool async) {
auto* h1_task2 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandleEnableImplicitMirroring(h1_task2, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Handles are on task0 (local), and task2, but op is on task1.
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
TFE_OpSetDevice(matmul, task1_name, status);
if (remote) {
TFE_OpSetDevice(matmul, task1_name, status);
}
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
@ -177,6 +182,17 @@ void TestRemoteExecuteSilentCopies(bool async) {
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TODO(gjn): Add support for waiting on async local mirrors
if (!async) {
auto remote_arg = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
h1_task2->handle.get())
->Handle();
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
matmul->operation.get());
// The input handles should never change since they have been mirrored.
ASSERT_EQ(op->GetInput(1), remote_arg);
}
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@ -213,9 +229,17 @@ void TestRemoteExecuteSilentCopies(bool async) {
worker_server2.release();
}
TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); }
TEST(CAPI, RemoteExecuteSilentCopies) {
TestRemoteExecuteSilentCopies(false, true);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(true);
TestRemoteExecuteSilentCopies(true, true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
TestRemoteExecuteSilentCopies(false, false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
TestRemoteExecuteSilentCopies(true, false);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {

View File

@ -17,12 +17,15 @@ limitations under the License.
#include <string.h>
#include <string>
#include "absl/strings/match.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
@ -363,34 +366,63 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) {
TensorHandleCopyBetweenTwoGPUDevices(true);
}
void TensorHandleSilentCopy(bool async) {
void TensorHandleSilentCopy(bool async,
TFE_ContextDevicePlacementPolicy global_policy,
TFE_ContextDevicePlacementPolicy thread_policy,
bool cpu_op) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, global_policy);
TFE_Context* ctx = TFE_NewContext(opts, status.get());
if (thread_policy != global_policy) {
TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx, thread_policy);
}
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
// Disable the test if no GPU is present.
string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
if (cpu_op) {
string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status.get());
} else {
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
}
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
// Validate if the input was replaced with a different TensorHandle
auto arg0 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
hcpu->handle.get())
->Handle();
auto arg1 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
hgpu->handle.get())
->Handle();
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
matmul->operation.get());
// The input handles should never change since they have been mirrored.
EXPECT_EQ(op->GetInput(0), arg0);
EXPECT_EQ(op->GetInput(1), arg1);
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(hgpu);
@ -404,57 +436,21 @@ void TensorHandleSilentCopy(bool async) {
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
}
TEST(CAPI, TensorHandleSilentCopy) { TensorHandleSilentCopy(false); }
TEST(CAPI, TensorHandleSilentCopyAsync) { TensorHandleSilentCopy(true); }
void TensorHandleSilentCopyLocal(bool async) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx,
TFE_DEVICE_PLACEMENT_SILENT);
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Disable the test if no GPU is present.
string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(hgpu);
}
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
TEST(CAPI, TensorHandleSilentCopy) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT, false);
}
TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); }
TEST(CAPI, TensorHandleSilentCopyLocalAsync) {
TensorHandleSilentCopyLocal(true);
TEST(CAPI, TensorHandleSilentCopyAsync) {
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT, false);
}
TEST(CAPI, TensorHandleSilentCopyLocalPolicy) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_EXPLICIT,
TFE_DEVICE_PLACEMENT_SILENT, false);
}
TEST(CAPI, TensorHandleSilentCopyLocalPolicyAsync) {
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_EXPLICIT,
TFE_DEVICE_PLACEMENT_SILENT, false);
}
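// A minimal sketch of the global-versus-thread-local policy layering the
// tests above exercise; the function name is illustrative. The context-wide
// default forbids implicit copies, while the current thread opts back in.
TFE_Context* NewContextWithSilentThreadPolicy(TF_Status* status) {
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetDevicePlacementPolicy(opts,
                                             TFE_DEVICE_PLACEMENT_EXPLICIT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  // Overrides the global policy for this thread only.
  TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx,
                                                 TFE_DEVICE_PLACEMENT_SILENT);
  return ctx;
}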
void SetAndGetOpDevices(bool async) {
@ -590,6 +586,91 @@ TEST(CAPI, TensorHandleDevices) {
TFE_DeleteContext(ctx);
}
void ExecuteAdd(bool async, bool forward_input) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* n = TestMatrixTensorHandle100x100();
// If a GPU exists, copy the handle to GPU so that we can exercise
// unprotecting a mirror.
std::string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* n_gpu =
TFE_TensorHandleCopyToDevice(n, ctx, gpu_device_name.c_str(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandleEnableImplicitMirroring(n_gpu, status);
TFE_DeleteTensorHandle(n);
n = n_gpu;
}
TFE_TensorHandle* m = TestMatrixTensorHandle100x100();
// Store pointer to raw buffer for validation of forwarding behaviour.
TF_Tensor* orig = TFE_TensorHandleResolve(n, status);
void* orig_ptr = TF_TensorData(orig);
TF_DeleteTensor(orig);
TFE_Op* add_op = AddOp(ctx, n, m);
std::string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(add_op, cpu_device_name.c_str(), status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
if (forward_input) {
TFE_DeleteTensorHandle(n);
}
int num_retvals = 1;
if (async) {
// Enqueue dummy ops so we backlog async execution & actually test async.
for (int i = 0; i < 10000; ++i) {
TFE_TensorHandle* dummy = nullptr;
TFE_Execute(add_op, &dummy, &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(dummy);
}
}
TFE_TensorHandle* retval = nullptr;
TFE_Execute(add_op, &retval, &num_retvals, status);
EXPECT_EQ(1, num_retvals);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
if (!forward_input) {
TFE_DeleteTensorHandle(n);
}
TFE_DeleteOp(add_op);
TF_Tensor* t = TFE_TensorHandleResolve(retval, status);
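// When the input was forwarded, the caller gave up its reference before
// execution; with async execution the reference is released before the
// enqueued op actually runs. In both cases the kernel may reuse the input
// allocation for its output, so the output aliases the original buffer.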
if (forward_input || async) {
EXPECT_EQ(orig_ptr, TF_TensorData(t));
} else {
EXPECT_NE(orig_ptr, TF_TensorData(t));
}
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(retval);
TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float result[100 * 100] = {0};
EXPECT_EQ(sizeof(result), TF_TensorByteSize(t));
memcpy(&result[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
for (int i = 0; i < 100 * 100; ++i) {
EXPECT_EQ(2.0f, result[i]);
}
TF_DeleteStatus(status);
}
TEST(CAPI, ExecuteAdd) { ExecuteAdd(false, false); }
TEST(CAPI, ExecuteAddAsync) { ExecuteAdd(true, false); }
TEST(CAPI, ExecuteAddForward) { ExecuteAdd(false, true); }
TEST(CAPI, ExecuteAddForwardAsync) { ExecuteAdd(true, true); }
void Execute_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -1228,6 +1309,14 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
TFE_DeleteTensorHandle(h_shares_tensor);
}
tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
tensorflow::AttrValueMap attr_values;
tensorflow::down_cast<tensorflow::OperationInterface*>(op->operation.get())
->Attrs()
.FillAttrValueMap(&attr_values);
return attr_values;
}
TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -1244,8 +1333,7 @@ TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
TFE_OpAddInput(minOp, axis, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
minOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(minOp);
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
EXPECT_NE(attr_found, attr_values.cend());
EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
@ -1284,8 +1372,7 @@ TEST(CAPI, TestTFE_OpInferSingleTypeInputListAttrs) {
TFE_OpAddInputList(concatOp, inputs, 2, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
EXPECT_NE(attr_found, attr_values.cend());
EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
@ -1325,8 +1412,7 @@ TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) {
TFE_OpAddInputList(assertOp, data, 3, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
assertOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(assertOp);
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
EXPECT_NE(attr_found, attr_values.cend());
EXPECT_EQ(attr_found->second.list().type(0), tensorflow::DataType::DT_BOOL);
@ -1362,15 +1448,15 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
TFE_TensorHandle* inputs[] = {input1, input2};
TFE_OpAddInput(concatOp, dim, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CHECK(concatOp->inference_ctx);
CHECK(concatOp->operation->OpDef());
TFE_OpAddInput(concatOp, inputs[0], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_FALSE(concatOp->inference_ctx) << "Inference context is still present";
EXPECT_FALSE(concatOp->operation->OpDef())
<< "Inference context is still present";
TFE_OpAddInput(concatOp, inputs[1], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
EXPECT_EQ(attr_values.find("T"), attr_values.end());
EXPECT_EQ(attr_values.find("N"), attr_values.end());
@ -1457,4 +1543,88 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
TFE_DeleteContext(ctx);
}
TEST(CAPI, TestTFE_OpGetAttrs) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
TFE_OpAttrs attributes;
TFE_OpGetAttrs(var_op, &attributes);
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
TFE_OpAddAttrs(copy_op, &attributes);
unsigned char is_list = 0;
ASSERT_EQ(TF_ATTR_TYPE,
TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_ATTR_SHAPE,
TFE_OpGetAttrType(copy_op, "shape", &is_list, status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
copy_op->operation.get());
op->Attrs().FillAttrValueMap(&attr_values);
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
TF_DeleteStatus(status);
TFE_DeleteOp(var_op);
TFE_DeleteOp(copy_op);
TFE_DeleteContext(ctx);
}
TEST(CAPI, TestTFE_OpAttrsSerialize) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
TFE_OpAttrs attributes;
TFE_OpGetAttrs(var_op, &attributes);
TF_Buffer* serialized_attr_values = TF_NewBuffer();
TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::NameAttrList name_and_attrs;
ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,
serialized_attr_values->length));
ASSERT_EQ("VarHandleOp", name_and_attrs.name());
ASSERT_EQ(tensorflow::DT_INT64,
name_and_attrs.attr().find("dtype")->second.type());
TF_DeleteBuffer(serialized_attr_values);
TFE_Op* second_var_op = TFE_NewOp(ctx, "VarHandleOp", status);
string serialized_dtype;
ASSERT_TRUE(name_and_attrs.attr().find("dtype")->second.SerializeToString(
&serialized_dtype));
TFE_OpSetAttrValueProto(
second_var_op, "dtype",
reinterpret_cast<const void*>(serialized_dtype.c_str()),
serialized_dtype.length(), status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
second_var_op->operation.get());
op->Attrs().FillAttrValueMap(&attr_values);
EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type());
TF_DeleteStatus(status);
TFE_DeleteOp(var_op);
TFE_DeleteOp(second_var_op);
TFE_DeleteContext(ctx);
}
} // namespace

View File

@ -131,6 +131,21 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2() {
return th;
}
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "AddV2", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, a, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, b, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
return op;
}
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
TF_Status* status = TF_NewStatus();

View File

@ -42,6 +42,9 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();
// Return a tensor handle containing a 3x2 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle3X2();
// Return an add op adding `a` and `b`.
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
// Return a matmul op multiplying `a` by `b`.
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);

View File

@ -0,0 +1,398 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A simple logging device to test custom device registration.
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
namespace {
struct LoggingDevice {
tensorflow::string device_name;
tensorflow::string underlying_device;
// Set to true whenever a TensorHandle is copied onto the device
bool* arrived_flag;
// Set to true whenever an operation is executed
bool* executed_flag;
};
struct LoggedTensor {
TFE_TensorHandle* tensor;
LoggedTensor() = delete;
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<LoggedTensor*>(data);
}
TFE_TensorHandle* MakeLoggedTensorHandle(
TFE_Context* context, const tensorflow::string& logging_device_name,
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
auto dtype = TFE_TensorHandleDataType(t->tensor);
return TFE_NewTensorHandleFromDeviceMemory(
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, dev->underlying_device.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
auto dst = std::make_unique<LoggedTensor>(t);
*(dev->arrived_flag) = true;
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
status);
}
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
TF_SetStatus(status, TF_INTERNAL,
"Trying to copy a tensor out of a logging device.");
return nullptr;
}
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_Op* op(TFE_NewOp(context, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
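// Forward each input, unwrapping tensors that live on this logging device so
// the underlying op sees the wrapped handles.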
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
const char* input_device = TFE_TensorHandleDeviceName(input, s);
if (TF_GetCode(s) != TF_OK) return;
if (dev->device_name == input_device) {
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(input, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddInput(op, t->tensor, s);
} else {
TFE_OpAddInput(op, input, s);
}
if (TF_GetCode(s) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
TFE_Execute(op, op_outputs.data(), num_outputs, s);
TFE_DeleteOp(op);
if (TF_GetCode(s) != TF_OK) return;
for (int i = 0; i < *num_outputs; ++i) {
auto logged_tensor = std::make_unique<LoggedTensor>(op_outputs[i]);
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
std::move(logged_tensor), s);
}
*(dev->executed_flag) = true;
}
void DeleteLoggingDevice(void* device_info) {
delete reinterpret_cast<LoggingDevice*>(device_info);
}
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device.delete_device = &DeleteLoggingDevice;
custom_device.execute = &LoggingDeviceExecute;
LoggingDevice* device = new LoggingDevice;
device->arrived_flag = arrived_flag;
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* context = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
ASSERT_FALSE(arrived);
TFE_TensorHandle* hdevice =
TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
ASSERT_TRUE(arrived);
ASSERT_FALSE(executed);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
MatMulOp(context, hcpu, hdevice), TFE_DeleteOp);
TFE_OpSetDevice(matmul.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* retval;
int num_retvals = 1;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
TFE_DeleteTensorHandle(retval);
TFE_DeleteTensorHandle(hcpu);
TFE_DeleteTensorHandle(hdevice);
TFE_DeleteContext(context);
}
TEST(CUSTOM_DEVICE, ResetOperation) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts, status.get()), TFE_DeleteContext);
TFE_DeleteContextOptions(opts);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* custom_device_name =
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> reused_op(
TFE_NewOp(context.get(), "Identity", status.get()), TFE_DeleteOp);
TFE_OpReset(reused_op.get(), "Identity", custom_device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
tensorflow::string(custom_device_name));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpReset(reused_op.get(), "Identity",
"/job:localhost/replica:0/task:0/device:CPU:0", status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
tensorflow::string("/job:localhost/replica:0/task:0/device:CPU:0"));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, MakeVariable) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle placed on the custom device.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
TFE_OpSetAttrString(op.get(), "container", "", 0);
TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
executed = false;
TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
auto handle_cleaner = tensorflow::gtl::MakeCleanup(
[var_handle]() { TFE_DeleteTensorHandle(var_handle); });
// Assign to the variable, copying to the custom device.
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpAddInput(op.get(), one.get(), status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
executed = false;
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
// Read the variable's value.
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetDevice(op.get(), name, status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
executed = false;
num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
auto value_cleaner = tensorflow::gtl::MakeCleanup(
[var_value]() { TFE_DeleteTensorHandle(var_value); });
ASSERT_EQ(tensorflow::string(name),
tensorflow::string(
TFE_TensorHandleBackingDeviceName(var_value, status.get())));
TFE_TensorHandle* var_value_unpacked =
reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(var_value, status.get()))
->tensor;
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
TFE_TensorHandleResolve(var_value_unpacked, status.get()),
TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(111., *static_cast<float*>(TF_TensorData(resolved_value.get())));
// Free the backing buffer for the variable.
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle placed on the custom device.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
TFE_OpSetAttrString(op.get(), "container", "", 0);
TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
executed = false;
TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
auto handle_cleaner = tensorflow::gtl::MakeCleanup(
[var_handle]() { TFE_DeleteTensorHandle(var_handle); });
// Assign to the variable, copying to the custom device.
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpAddInput(op.get(), one.get(), status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
executed = false;
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
// Read the variable's value.
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
executed = false;
num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
EXPECT_FALSE(TF_GetCode(status.get()) == TF_OK)
<< "Execution should fail because the variable is being used on the "
"wrong device.";
// Free the backing buffer for the variable.
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get());
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
RegisterLoggingDevice(context.get(),
"/job:localhost/replica:0/task:0/device:CPU:0",
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
}
} // namespace
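Distilled from the tests above, a sketch of the registration flow (assuming a valid `TFE_Context* ctx`):

bool arrived = false, executed = false;
TF_Status* status = TF_NewStatus();
RegisterLoggingDevice(ctx, "/job:localhost/replica:0/task:0/device:CUSTOM:0",
&arrived, &executed, status);
// Any copy onto the custom device now sets `arrived`, and any op executed on
// it sets `executed`, while real work is delegated to the underlying CPU.
TF_DeleteStatus(status);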

View File

@ -0,0 +1,334 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/dlpack.h"
#include "include/dlpack/dlpack.h" // from @dlpack
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace {
// Owning context for a DLManagedTensor: manages the tensor's lifetime. When
// DLManagedTensor::deleter is called, the originating framework is notified of
// the destruction and this context is deleted as well.
struct TfDlManagedTensorCtx {
TensorReference reference;
std::vector<int64_t> shape;
std::vector<int64_t> strides;
DLManagedTensor tensor;
explicit TfDlManagedTensorCtx(const TensorReference& ref) : reference(ref) {}
};
// Gets tensor from eager tensor handle.
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed-in handle is null or invalid");
return nullptr;
}
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
"DLPack doesn't support remote tensor");
return nullptr;
}
const tensorflow::Tensor* tensor;
status->status = handle->Tensor(&tensor);
if (!status->status.ok()) {
return nullptr;
}
return tensor;
}
// Deleter for DLManagedTensor
void DLManagedTensorDeleter(DLManagedTensor* arg) {
TfDlManagedTensorCtx* owner =
static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
owner->reference.Unref();
delete owner;
}
// Converts a TF_DataType to a DLPack data type.
DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
DLDataType dtype;
dtype.lanes = 1;
dtype.bits = TF_DataTypeSize(data_type) * 8;
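// e.g. TF_FLOAT occupies 4 bytes, so dtype.bits == 32.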
switch (data_type) {
case TF_DataType::TF_HALF:
case TF_DataType::TF_FLOAT:
case TF_DataType::TF_DOUBLE:
dtype.code = DLDataTypeCode::kDLFloat;
break;
case TF_DataType::TF_INT8:
case TF_DataType::TF_INT16:
case TF_DataType::TF_INT32:
case TF_DataType::TF_INT64:
dtype.code = DLDataTypeCode::kDLInt;
break;
case TF_DataType::TF_BOOL:
case TF_DataType::TF_UINT8:
case TF_DataType::TF_UINT16:
case TF_DataType::TF_UINT32:
case TF_DataType::TF_UINT64:
dtype.code = DLDataTypeCode::kDLUInt;
break;
case TF_DataType::TF_BFLOAT16:
dtype.code = DLDataTypeCode::kDLBfloat;
break;
default:
status->status = tensorflow::errors::InvalidArgument(
DataType_Name(static_cast<DataType>(data_type)),
" is not supported by dlpack");
break;
}
return dtype;
}
// Gets DLPack's DLContext from eager tensor handle.
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
DLContext ctx;
const char* device_name = h->handle->DeviceName(&status->status);
DeviceNameUtils::ParsedName parsed_name;
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
std::string device_type = parsed_name.type;
int device_id = 0;
if (parsed_name.has_id) {
device_id = parsed_name.id;
}
ctx.device_id = device_id;
if (device_type == "CPU") {
ctx.device_type = DLDeviceType::kDLCPU;
} else if (device_type == "GPU") {
ctx.device_type = DLDeviceType::kDLGPU;
} else {
status->status = tensorflow::errors::InvalidArgument(
"Unsupported Device Type for dlpack");
}
return ctx;
}
// Converts DLContext to TF device name.
absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
TF_Status* status) {
switch (ctx.device_type) {
case DLDeviceType::kDLCPU:
return "CPU:0";
case DLDeviceType::kDLGPU:
return absl::StrCat("GPU:", ctx.device_id);
default:
return absl::nullopt;
}
}
// Converts a DLPack data type to a TF_DataType.
Status TfDataTypeFormDlDataType(const DLDataType& dtype,
TF_DataType* tf_dtype) {
switch (dtype.code) {
case DLDataTypeCode::kDLUInt:
switch (dtype.bits) {
case 8:
*tf_dtype = TF_DataType::TF_UINT8;
return Status::OK();
case 16:
*tf_dtype = TF_DataType::TF_UINT16;
return Status::OK();
case 32:
*tf_dtype = TF_DataType::TF_UINT32;
return Status::OK();
case 64:
*tf_dtype = TF_DataType::TF_UINT64;
return Status::OK();
default:
return tensorflow::errors::InvalidArgument("Unsupported UInt bits: ",
dtype.bits);
}
return Status::OK();
case DLDataTypeCode::kDLInt:
switch (dtype.bits) {
case 8:
*tf_dtype = TF_DataType::TF_INT8;
return Status::OK();
case 16:
*tf_dtype = TF_DataType::TF_INT16;
return Status::OK();
case 32:
*tf_dtype = TF_DataType::TF_INT32;
return Status::OK();
case 64:
*tf_dtype = TF_DataType::TF_INT64;
return Status::OK();
default:
return tensorflow::errors::InvalidArgument("Unsupported Int bits: ",
dtype.bits);
}
return Status::OK();
case DLDataTypeCode::kDLFloat:
switch (dtype.bits) {
case 16:
*tf_dtype = TF_DataType::TF_HALF;
return Status::OK();
case 32:
*tf_dtype = TF_DataType::TF_FLOAT;
return Status::OK();
case 64:
*tf_dtype = TF_DataType::TF_DOUBLE;
return Status::OK();
default:
return tensorflow::errors::InvalidArgument("Unsupported Float bits: ",
dtype.bits);
}
break;
case DLDataTypeCode::kDLBfloat:
switch (dtype.bits) {
case 16:
*tf_dtype = TF_DataType::TF_BFLOAT16;
return Status::OK();
default:
return tensorflow::errors::InvalidArgument(
"Unsupported BFloat bits: ", dtype.bits);
}
break;
default:
return tensorflow::errors::InvalidArgument("Unsupported Type Codes: ",
dtype.code);
}
}
// Wraps the deleter of a DLManagedTensor to match the deallocator signature
// expected by TFE_NewTensorHandleFromDeviceMemory.
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr);
dlmt->deleter(dlmt);
}
// Checks whether the stride array matches the layout of compact, row-major
// data.
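// For example, a tensor of shape [2, 3, 4] is compact row-major exactly when
// its strides (in elements, not bytes) are [12, 4, 1].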
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
int ndim) {
if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
return false;
}
for (int i = ndim - 2; i >= 0; --i) {
if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
return false;
}
}
return true;
}
} // namespace
void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
DLManagedTensor* dlMTensor = static_cast<DLManagedTensor*>(dlm_ptr);
if (dlMTensor->deleter != nullptr) {
dlMTensor->deleter(dlMTensor);
}
}
void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
const Tensor* tensor = GetTensorFromHandle(h, status);
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
tf_dlm_tensor_ctx->reference = tensor_ref;
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
dlm_tensor->deleter = &DLManagedTensorDeleter;
dlm_tensor->dl_tensor.ctx = GetDlContext(h, status);
int ndim = tensor->dims();
dlm_tensor->dl_tensor.ndim = ndim;
dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status);
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
shape_arr->resize(ndim);
stride_arr->resize(ndim, 1);
for (int i = 0; i < ndim; i++) {
(*shape_arr)[i] = tensor->dim_size(i);
}
for (int i = ndim - 2; i >= 0; --i) {
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
}
dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
// There are two ways to represent compact row-major data:
// 1) a nullptr strides array indicates the tensor is compact and row-major;
// 2) the strides array is filled in explicitly for the compact row-major case.
// We choose option 2 here, since some frameworks do not handle a nullptr
// strides argument properly.
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
dlm_tensor->dl_tensor.byte_offset =
0; // TF doesn't handle the strides and byte_offsets here
return static_cast<void*>(dlm_tensor);
}
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
TFE_DeleteContextOptions(opts);
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
DLTensor* dl_tensor = &dlmt->dl_tensor;
absl::optional<std::string> device_name =
DeviceNameFromDlContext(dl_tensor->ctx, status);
if (!device_name.has_value()) {
status->status =
tensorflow::errors::InvalidArgument("Unsupported Device Type");
return nullptr;
}
TF_DataType dtype;
Status s = TfDataTypeFormDlDataType(dl_tensor->dtype, &dtype);
if (!s.ok()) {
status->status = std::move(s);
return nullptr;
}
int num_dims = dl_tensor->ndim;
const int64_t* dims = dl_tensor->shape;
void* data = dl_tensor->data;
size_t total_bytes = dl_tensor->dtype.bits / 8;
for (int i = 0; i < num_dims; i++) {
total_bytes *= dims[i];
}
if (dl_tensor->strides != nullptr &&
!IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
num_dims)) {
status->status = tensorflow::errors::InvalidArgument(
"Invalid strides array from DLPack");
return nullptr;
}
// Pass the DLManagedTensor itself (not the address of the local pointer) as
// the deallocator argument; DeallocatorWrapperFunc casts it back.
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
total_bytes, &DeallocatorWrapperFunc, dlmt, status);
return handle;
}
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_DLPACK_H_
#define TENSORFLOW_C_EAGER_DLPACK_H_
#include "tensorflow/c/eager/c_api.h"
namespace tensorflow {
// PyCapsule name for DLPack Tensor
const char* const kDlTensorCapsuleName = "dltensor";
// Converts an eager tensor handle to DLPack (DLManagedTensor*), and returns
// the void* for further PyCapsule construction.
TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
TF_Status* status);
// Converts a DLPack (DLManagedTensor*) to an eager tensor handle.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
TF_Status* status);
// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule.
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_DLPACK_H_
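A minimal round-trip sketch of this API, assuming a valid `TFE_TensorHandle* handle` (error handling abbreviated):

TF_Status* status = TF_NewStatus();
// Export: the returned void* is a DLManagedTensor*, normally wrapped in a
// PyCapsule named kDlTensorCapsuleName on the Python side.
void* dlm = tensorflow::TFE_HandleToDLPack(handle, status);
if (TF_GetCode(status) == TF_OK) {
// Import: consumes the DLManagedTensor; its deleter runs when the imported
// handle's backing buffer is released.
TFE_TensorHandle* imported = tensorflow::TFE_HandleFromDLPack(dlm, status);
if (TF_GetCode(status) != TF_OK) {
// On failure the caller still owns the DLManagedTensor; release it.
tensorflow::TFE_CallDLManagedTensorDeleter(dlm);
} else {
TFE_DeleteTensorHandle(imported);
}
}
TF_DeleteStatus(status);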

View File

@ -0,0 +1,312 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/operation_interface.h"
#include "absl/container/fixed_array.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
OperationInterface::OperationInterface(TFE_Context* ctx)
: operation_(ctx->context) {}
const string& OperationInterface::DeviceName() const {
absl::variant<Device*, CustomDevice*> variant_device =
(operation_.Device() == kVariantDeviceNull)
? operation_.EagerContext().HostCPU()
: operation_.Device();
return absl::visit([](auto* d) -> const string& { return d->name(); },
variant_device);
}
Status OperationInterface::SetDeviceName(const char* name) {
return operation_.SetDeviceName(name);
}
Status OperationInterface::SetAttrString(const char* attr_name,
const char* data, size_t length) {
operation_.MutableAttrs()->Set(attr_name, StringPiece(data, length));
return Status::OK();
}
Status OperationInterface::SetAttrInt(const char* attr_name, int64_t value) {
operation_.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
return Status::OK();
}
Status OperationInterface::SetAttrFloat(const char* attr_name, float value) {
operation_.MutableAttrs()->Set(attr_name, value);
return Status::OK();
}
Status OperationInterface::SetAttrBool(const char* attr_name, bool value) {
operation_.MutableAttrs()->Set(attr_name, value);
return Status::OK();
}
Status OperationInterface::SetAttrType(const char* attr_name,
TF_DataType value) {
operation_.MutableAttrs()->Set(attr_name, static_cast<DataType>(value));
return Status::OK();
}
Status OperationInterface::SetAttrShape(const char* attr_name,
const int64_t* dims,
const int num_dims) {
if (num_dims > TensorShape::MaxDimensions()) {
return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
num_dims,
" dimensions which is over the limit of ",
TensorShape::MaxDimensions(), ".");
}
TensorShapeProto proto;
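// A negative `num_dims` encodes a shape of unknown rank.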
if (num_dims < 0) {
proto.set_unknown_rank(true);
} else {
for (int d = 0; d < num_dims; ++d) {
proto.add_dim()->set_size(dims[d]);
}
}
operation_.MutableAttrs()->Set(attr_name, proto);
return Status::OK();
}
Status OperationInterface::SetAttrFunction(
const char* attr_name,
const std::unique_ptr<AbstractOperationInterface>& value) {
AttrValue attr_value;
NameAttrList* func = attr_value.mutable_func();
func->set_name(value->Name());
OperationInterface* value_operation =
tensorflow::down_cast<OperationInterface*>(value.get());
value_operation->operation_.Attrs().FillAttrValueMap(func->mutable_attr());
operation_.MutableAttrs()->Set(attr_name, attr_value);
return Status::OK();
}
Status OperationInterface::SetAttrFunctionName(const char* attr_name,
const char* data,
size_t length) {
AttrValue attr_value;
NameAttrList* func = attr_value.mutable_func();
func->set_name(data, length);
operation_.MutableAttrs()->Set(attr_name, attr_value);
return Status::OK();
}
Status OperationInterface::SetAttrTensor(const char* attr_name,
TF_Tensor* tensor) {
Tensor t;
TF_RETURN_IF_ERROR(TF_TensorToTensor(tensor, &t));
operation_.MutableAttrs()->Set(attr_name, t);
return Status::OK();
}
Status OperationInterface::SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths,
int num_values) {
std::vector<StringPiece> v(num_values);
for (int i = 0; i < num_values; ++i) {
v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
}
operation_.MutableAttrs()->Set(attr_name, v);
return Status::OK();
}
Status OperationInterface::SetAttrFloatList(const char* attr_name,
const float* values,
int num_values) {
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const float>(values, num_values));
return Status::OK();
}
Status OperationInterface::SetAttrIntList(const char* attr_name,
const int64_t* values,
int num_values) {
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const int64>(
reinterpret_cast<const int64*>(values), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrTypeList(const char* attr_name,
const TF_DataType* values,
int num_values) {
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const DataType>(
reinterpret_cast<const DataType*>(values), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) {
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const bool>(b.get(), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrShapeList(const char* attr_name,
const int64_t** dims,
const int* num_dims,
int num_values) {
std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
for (int i = 0; i < num_values; ++i) {
const auto num_dims_i = num_dims[i];
if (num_dims_i > TensorShape::MaxDimensions()) {
return errors::InvalidArgument(
strings::StrCat("Value specified for `", attr_name, "` has ",
num_dims_i, " dimensions which is over the limit of ",
TensorShape::MaxDimensions(), "."));
}
if (num_dims_i < 0) {
proto[i].set_unknown_rank(true);
} else {
const int64_t* dims_i = dims[i];
auto proto_i = &proto[i];
for (int d = 0; d < num_dims_i; ++d) {
proto_i->add_dim()->set_size(dims_i[d]);
}
}
}
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrFunctionList(const char* attr_name,
const TFE_Op** value,
int num_values) {
std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
for (int i = 0; i < num_values; i++) {
auto value_operation =
tensorflow::down_cast<OperationInterface*>(value[i]->operation.get());
funcs[i].set_name(value_operation->operation_.Name());
value_operation->operation_.Attrs().FillAttrValueMap(
funcs[i].mutable_attr());
}
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
return Status::OK();
}
const OpDef* OperationInterface::GetOpDef(Status* status) {
const tensorflow::OpDef* op_def = operation_.OpDef();
if (op_def) return op_def;
*status = OpDefForOp(Name(), &op_def);
return op_def;
}
Status OperationInterface::InputLength(const char* input_name, int* length) {
Status status;
const tensorflow::OpDef* op_def = GetOpDef(&status);
if (!status.ok()) {
return status;
}
AttrValueMap attrs;
operation_.Attrs().FillAttrValueMap(&attrs);
NameRangeMap name_ranges;
TF_RETURN_IF_ERROR(
NameRangesForNode(AttrSlice(&attrs), *op_def, &name_ranges, nullptr));
auto iter = name_ranges.find(input_name);
if (iter == name_ranges.end()) {
return errors::InvalidArgument("Input '", input_name, "' not found");
}
*length = iter->second.second - iter->second.first;
return Status::OK();
}
Status OperationInterface::OutputLength(const char* output_name, int* length) {
Status status;
const tensorflow::OpDef* op_def = GetOpDef(&status);
if (!status.ok()) {
return status;
}
AttrValueMap attrs;
operation_.Attrs().FillAttrValueMap(&attrs);
NameRangeMap name_ranges;
TF_RETURN_IF_ERROR(
NameRangesForNode(AttrSlice(&attrs), *op_def, nullptr, &name_ranges));
auto iter = name_ranges.find(output_name);
if (iter == name_ranges.end()) {
return errors::InvalidArgument("Output '", output_name, "' not found");
}
*length = iter->second.second - iter->second.first;
return Status::OK();
}
Status OperationInterface::AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) {
TensorHandle* h =
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
operation_.AddInput(h);
return operation_.MaybeInferSingleInputAttrs(h);
}
Status OperationInterface::AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) {
for (auto& input : inputs) {
TensorHandle* h =
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
operation_.AddInput(h);
}
return operation_.InferInputListAttrs(inputs.size());
}
Status OperationInterface::Execute(
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
int* num_retvals) {
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
TF_RETURN_IF_ERROR(
EagerExecute(&operation_, handle_retvals.data(), num_retvals));
for (int i = 0; i < *num_retvals; ++i) {
retvals->at(i).reset(
new tensorflow::TensorHandleInterface(handle_retvals[i]));
}
return Status::OK();
}
Status OperationInterface::SetCancellationManager(
TFE_CancellationManager* cancellation_manager) {
operation_.SetCancellationManager(
&cancellation_manager->cancellation_manager);
return Status::OK();
}
Status OperationInterface::SetUseXla(bool enable) {
operation_.SetUseXla(enable);
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,188 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
#include <memory>
#include "absl/container/fixed_array.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
// Abstract interface to an operation.
class AbstractOperationInterface {
public:
virtual ~AbstractOperationInterface() {}
virtual void Clear() = 0;
virtual tensorflow::Status Reset(const char* op,
const char* raw_device_name) = 0;
virtual const tensorflow::string& Name() const = 0;
virtual const tensorflow::string& DeviceName() const = 0;
virtual tensorflow::Status SetDeviceName(const char* name) = 0;
virtual tensorflow::Status AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) = 0;
virtual tensorflow::Status AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) = 0;
virtual tensorflow::Status Execute(
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
int* num_retvals) = 0;
virtual const tensorflow::OpDef* OpDef() const = 0;
virtual tensorflow::Status SetAttrString(const char* attr_name,
const char* data, size_t length) = 0;
virtual tensorflow::Status SetAttrInt(const char* attr_name,
int64_t value) = 0;
virtual tensorflow::Status SetAttrFloat(const char* attr_name,
float value) = 0;
virtual tensorflow::Status SetAttrBool(const char* attr_name, bool value) = 0;
virtual tensorflow::Status SetAttrType(const char* attr_name,
TF_DataType value) = 0;
virtual tensorflow::Status SetAttrShape(const char* attr_name,
const int64_t* dims,
const int num_dims) = 0;
virtual tensorflow::Status SetAttrFunction(
const char* attr_name,
const std::unique_ptr<AbstractOperationInterface>& value) = 0;
virtual tensorflow::Status SetAttrFunctionName(const char* attr_name,
const char* value,
size_t length) = 0;
virtual tensorflow::Status SetAttrTensor(const char* attr_name,
TF_Tensor* tensor) = 0;
virtual tensorflow::Status SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths,
int num_values) = 0;
virtual tensorflow::Status SetAttrFloatList(const char* attr_name,
const float* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrIntList(const char* attr_name,
const int64_t* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrTypeList(const char* attr_name,
const TF_DataType* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrShapeList(const char* attr_name,
const int64_t** dims,
const int* num_dims,
int num_values) = 0;
virtual tensorflow::Status SetAttrFunctionList(const char* attr_name,
const TFE_Op** value,
int num_values) = 0;
virtual tensorflow::Status InputLength(const char* input_name,
int* length) = 0;
virtual tensorflow::Status OutputLength(const char* output_name,
int* length) = 0;
// Experimental
virtual tensorflow::Status SetUseXla(bool enable) {
return tensorflow::errors::Unimplemented("SetUseXla not implemented");
}
virtual tensorflow::Status SetCancellationManager(
TFE_CancellationManager* cancellation_manager) {
return tensorflow::errors::Unimplemented(
"SetCancellationManager not implemented");
}
};
namespace tensorflow {
class OpDef;
class OperationInterface : public AbstractOperationInterface {
public:
explicit OperationInterface(TFE_Context* ctx);
~OperationInterface() override {}
void Clear() override { operation_.Clear(); }
Status Reset(const char* op, const char* raw_device_name) override {
return operation_.Reset(op, raw_device_name, false, nullptr);
}
const string& Name() const override { return operation_.Name(); }
const string& DeviceName() const override;
Status SetDeviceName(const char* name) override;
Status AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) override;
Status AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) override;
Status Execute(
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
int* num_retvals) override;
const tensorflow::OpDef* OpDef() const override {
return operation_.OpDef();
}
Status SetAttrString(const char* attr_name, const char* data,
size_t length) override;
Status SetAttrInt(const char* attr_name, int64_t value) override;
Status SetAttrFloat(const char* attr_name, float value) override;
Status SetAttrBool(const char* attr_name, bool value) override;
Status SetAttrType(const char* attr_name, TF_DataType value) override;
Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) override;
Status SetAttrFunction(
const char* attr_name,
const std::unique_ptr<AbstractOperationInterface>& value) override;
Status SetAttrFunctionName(const char* attr_name, const char* data,
size_t length) override;
Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) override;
Status SetAttrStringList(const char* attr_name, const void* const* values,
const size_t* lengths, int num_values) override;
Status SetAttrFloatList(const char* attr_name, const float* values,
int num_values) override;
Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) override;
Status SetAttrTypeList(const char* attr_name, const TF_DataType* values,
int num_values) override;
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
int num_values) override;
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) override;
Status SetAttrFunctionList(const char* attr_name, const TFE_Op** value,
int num_values) override;
Status InputLength(const char* input_name, int* length) override;
Status OutputLength(const char* output_name, int* length) override;
Status SetUseXla(bool enable) override;
Status SetCancellationManager(
TFE_CancellationManager* cancellation_manager) override;
// TODO(gjn): Remove once TFE_InferShapes is removed
const tensorflow::AttrBuilder& Attrs() const { return operation_.Attrs(); }
tensorflow::AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); }
const TensorHandle* GetInput(int i) const { return operation_.Inputs()[i]; }
private:
const tensorflow::OpDef* GetOpDef(Status* status);
EagerOperation operation_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
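For orientation, a sketch of how C-API calls map onto this interface (each TFE_Op owns an OperationInterface; `ctx`, `a`, and `b` are assumed valid):

TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "MatMul", status); // backed by OperationInterface
TFE_OpSetAttrType(op, "T", TF_FLOAT); // -> SetAttrType("T", TF_FLOAT)
TFE_OpSetAttrBool(op, "transpose_a", 0); // -> SetAttrBool("transpose_a", false)
TFE_OpAddInput(op, a, status); // -> AddInput(...)
TFE_OpAddInput(op, b, status);
TFE_TensorHandle* out = nullptr;
int num_retvals = 1;
TFE_Execute(op, &out, &num_retvals, status); // -> Execute(&retvals, &num_retvals)
TFE_DeleteOp(op);
TF_DeleteStatus(status);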

View File

@ -284,7 +284,7 @@ class ForwardAccumulator {
// Temporarily push or pop transient state for this accumulator.
//
// Allows an accumulator which is currently processing an operation to
// temporarily reset its state. Without pushing and poping, accumulators
// temporarily reset its state. Without pushing and popping, accumulators
// ignore operations executed as a direct result of their own jvp
// computations.
void PushState() { call_state_.emplace(nullptr, false); }

View File

@ -55,6 +55,14 @@ class AbstractTensorHandleInterface {
// Return a copy of the handle.
virtual AbstractTensorHandleInterface* Copy() = 0;
// Maintain mirror tensors for any implicit copies to local devices. This
// setting is offered on a per tensor handle basis to avoid potential memory
// over utilization due to holding on to mirrors as well as the original
// tensor. Note this setting overrides the context mirroring policy whereby if
// the mirroring policy is MIRRORING_NONE, we will still continue to mirror
// this tensor.
virtual void EnableImplicitMirroring() = 0;
};
namespace tensorflow {
@ -77,6 +85,8 @@ class TensorHandleInterface : public AbstractTensorHandleInterface {
AbstractTensorHandleInterface* Copy() override;
void EnableImplicitMirroring() override;
// TODO(gjn): This is not a very generic interface, but is needed for specific
// use cases.
TensorHandle* Handle() { return handle_; }

View File

@ -18,37 +18,28 @@ cc_library(
],
)
# Core TensorFlow depends on this, this will be included in main library
cc_library(
name = "filesystem_interface_impl",
srcs = ["filesystem_interface.cc"],
hdrs = ["filesystem_interface.h"],
deps = [
":modular_filesystem",
"//tensorflow/c:tf_file_statistics",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_internal",
"//tensorflow/core:ptr_util",
"//tensorflow/core/platform:env",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/platform:stringpiece",
],
alwayslink = 1,
)
# Core TensorFlow depends on this, will be included in main library
cc_library(
name = "modular_filesystem",
srcs = ["modular_filesystem.cc"],
hdrs = ["modular_filesystem.h"],
srcs = [
"modular_filesystem.cc",
"modular_filesystem_registration.cc",
],
hdrs = [
"modular_filesystem.h",
"modular_filesystem_registration.h",
],
# TODO(mihaimaruseac): Visibility should be more restrictive once we
# convert to modular filesystems everywhere
visibility = ["//visibility:public"],
deps = [
":filesystem_interface",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core:lib",
"//tensorflow/c:tf_status_internal",
"//tensorflow/core:ptr_util",
"//tensorflow/core/platform:env",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
],
)
@ -63,16 +54,12 @@ tf_cc_test(
"notap", # b/139060984, requires implementing modular support for Google filesystem
],
deps = [
":filesystem_interface_impl",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_internal",
":modular_filesystem",
"//tensorflow/core:framework_internal",
"//tensorflow/core/lib/io:path",
"//tensorflow/core/platform:env",
"//tensorflow/core/platform:error",
"//tensorflow/core/platform:stacktrace_handler",
"//tensorflow/core/platform:str_util",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/platform:test",
],
)

View File

@ -1,366 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/util/ptr_util.h"
/// This translation unit is linked in core TensorFlow and provides the
/// functionality needed for plugin registration to check ABI/API compatibility,
/// to ensure required methods are present, to ensure plugins are not allowed to
/// change functionality after being loaded and to register the filesystems
/// provided by a plugin. Consult the header file for more information about
/// how this is achieved.
namespace tensorflow {
namespace {
// Checks if the plugin and core ABI numbers match, filling in `status`.
//
// If the numbers don't match, plugin cannot be loaded.
static bool CheckABIHelper(int pluginABI, int coreABI, StringPiece where,
TF_Status* status) {
if (pluginABI != coreABI) {
TF_SetStatus(
status, TF_FAILED_PRECONDITION,
strings::StrCat("Plugin ABI (", pluginABI, ") for ", where,
" operations doesn't match expected core ABI (",
coreABI, "). Plugin cannot be loaded.")
.c_str());
return false;
}
return true;
}
// Checks if the plugin and core ABI numbers match, for all operations.
//
// If the numbers don't match, plugin cannot be loaded.
//
// Uses the simpler `CheckABIHelper(int, int, StringPiece, TF_Status*)`
static bool CheckABI(
int plugin_filesystem_ops_ABI,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
int plugin_random_access_file_ops_ABI,
const TF_WritableFileOps* plugin_writable_file_ops,
int plugin_writable_file_ops_ABI,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
int plugin_read_only_memory_region_ops_ABI, TF_Status* status) {
if (!CheckABIHelper(plugin_filesystem_ops_ABI, TF_FILESYSTEM_OPS_ABI,
"filesystem", status))
return false;
if (plugin_random_access_file_ops != nullptr &&
!CheckABIHelper(plugin_random_access_file_ops_ABI,
TF_RANDOM_ACCESS_FILE_OPS_ABI, "random access file",
status))
return false;
if (plugin_writable_file_ops != nullptr &&
!CheckABIHelper(plugin_writable_file_ops_ABI, TF_WRITABLE_FILE_OPS_ABI,
"writable file", status))
return false;
if (plugin_read_only_memory_region_ops != nullptr &&
!CheckABIHelper(plugin_read_only_memory_region_ops_ABI,
TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
"read only memory region", status))
return false;
return true;
}
// Checks if the plugin and core API numbers match, logging mismatches.
static void CheckAPIHelper(int plugin_API, int core_API, StringPiece where) {
if (plugin_API != core_API) {
VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
<< " operations doesn't match expected core API (" << core_API
<< "). Plugin will be loaded but functionality might be missing.";
}
}
// Checks if the plugin and core API numbers match, for all operations.
//
// Uses the simpler `CheckAPIHelper(int, int, StringPiece)`.
static void CheckAPI(
int plugin_filesystem_ops_API,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
int plugin_random_access_file_ops_API,
const TF_WritableFileOps* plugin_writable_file_ops,
int plugin_writable_file_ops_API,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
int plugin_read_only_memory_region_ops_API) {
CheckAPIHelper(plugin_filesystem_ops_API, TF_FILESYSTEM_OPS_API,
"filesystem");
if (plugin_random_access_file_ops != nullptr)
CheckAPIHelper(plugin_random_access_file_ops_API,
TF_RANDOM_ACCESS_FILE_OPS_API, "random access file");
if (plugin_writable_file_ops != nullptr)
CheckAPIHelper(plugin_writable_file_ops_API, TF_WRITABLE_FILE_OPS_API,
"writable file");
if (plugin_read_only_memory_region_ops != nullptr)
CheckAPIHelper(plugin_read_only_memory_region_ops_API,
TF_READ_ONLY_MEMORY_REGION_OPS_API,
"read only memory region");
}
// Validates the filesystem operations supplied by the plugin.
static bool ValidateHelper(const TF_FilesystemOps* ops, TF_Status* status) {
if (ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without operations");
return false;
}
if (ops->init == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `init` operation");
return false;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation");
return false;
}
return true;
}
// Validates the random access file operations supplied by the plugin.
static bool ValidateHelper(const TF_RandomAccessFileOps* ops,
TF_Status* status) {
if (ops == nullptr) {
// We allow filesystems where files can only be written to (from TF code)
return true;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation on "
"random access files");
return false;
}
return true;
}
// Validates the writable file operations supplied by the plugin.
static bool ValidateHelper(const TF_WritableFileOps* ops, TF_Status* status) {
if (ops == nullptr) {
// We allow read-only filesystems
return true;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation on "
"writable files");
return false;
}
return true;
}
// Validates the read only memory region operations given by the plugin.
static bool ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops,
TF_Status* status) {
if (ops == nullptr) {
// read only memory region support is always optional
return true;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation on "
"read only memory regions");
return false;
}
if (ops->data == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `data` operation on "
"read only memory regions");
return false;
}
if (ops->length == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `length` operation on "
"read only memory regions");
return false;
}
return true;
}
// Validates the operations supplied by the plugin.
//
// Uses the 4 simpler `ValidateHelper(const TF_..., TF_Status*)` to validate
// each individual function table and then checks that the function table for a
// specific file type exists if the plugin offers support for creating that
// type of files.
static bool Validate(
const TF_FilesystemOps* plugin_filesystem_ops,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
const TF_WritableFileOps* plugin_writable_file_ops,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
TF_Status* status) {
if (!ValidateHelper(plugin_filesystem_ops, status)) return false;
if (!ValidateHelper(plugin_random_access_file_ops, status)) return false;
if (!ValidateHelper(plugin_writable_file_ops, status)) return false;
if (!ValidateHelper(plugin_read_only_memory_region_ops, status)) return false;
if (plugin_filesystem_ops->new_random_access_file != nullptr &&
plugin_random_access_file_ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Filesystem allows creation of random access files but no "
"operations on them have been supplied.");
return false;
}
if ((plugin_filesystem_ops->new_writable_file != nullptr ||
plugin_filesystem_ops->new_appendable_file != nullptr) &&
plugin_writable_file_ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Filesystem allows creation of writable files but no "
"operations on them have been supplied.");
return false;
}
if (plugin_filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
plugin_read_only_memory_region_ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Filesystem allows creation of readonly memory regions but no "
"operations on them have been supplied.");
return false;
}
return true;
}
// Copies a function table from plugin memory space to core memory space.
//
// This has three benefits:
// * allows having newer plugins than the current core TensorFlow: the
// additional entries in the plugin's table are just discarded;
// * allows having older plugins than the current core TensorFlow (though
// we are still warning users): the entries that core TensorFlow expects
// but plugins didn't provide will be set to `nullptr` values and core
// TensorFlow will know to not call these on behalf of users;
// * increased security as plugins will not be able to alter function table
// after loading up. Thus, malicious plugins can't alter functionality to
// probe for gadgets inside core TensorFlow. We can even protect the area
// of memory where the copies reside to not allow any more writes to it
// after all copies are created.
template <typename T>
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
size_t plugin_size) {
if (plugin_ops == nullptr) return nullptr;
size_t copy_size = sizeof(T);
if (plugin_size < copy_size) {
copy_size = plugin_size;
}
auto core_ops = tensorflow::MakeUnique<T>();
memcpy(const_cast<T*>(core_ops.get()), plugin_ops, copy_size);
return core_ops;
}
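
For illustration, a hedged sketch of the older-plugin case this copy guards
against; `MyInit` and `MyCleanup` are hypothetical plugin functions, and the
halved size stands in for a smaller, older table:

// Plugin built against an older header reports a smaller ops table.
TF_FilesystemOps old_ops;
memset(&old_ops, 0, sizeof(old_ops));
old_ops.init = MyInit;        // hypothetical
old_ops.cleanup = MyCleanup;  // hypothetical
// Only the reported prefix is copied; the rest of core's value-initialized
// copy stays `nullptr`, so core knows not to call those entries.
auto core_ops = CopyToCore<TF_FilesystemOps>(&old_ops, sizeof(old_ops) / 2);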
} // namespace
} // namespace tensorflow
void RegisterFilesystemPlugin(
int plugin_filesystem_ops_ABI, int plugin_filesystem_ops_API,
size_t plugin_filesystem_ops_size, int plugin_random_access_file_ops_ABI,
int plugin_random_access_file_ops_API,
size_t plugin_random_access_file_ops_size, int plugin_writable_file_ops_ABI,
int plugin_writable_file_ops_API, size_t plugin_writable_file_ops_size,
int plugin_read_only_memory_region_ops_ABI,
int plugin_read_only_memory_region_ops_API,
size_t plugin_read_only_memory_region_ops_size, const char* scheme,
const TF_FilesystemOps* plugin_filesystem_ops,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
const TF_WritableFileOps* plugin_writable_file_ops,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
TF_Status* status) {
if (scheme == nullptr) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"`scheme` argument must not be `nullptr`.");
return;
}
// ABI numbers must match exactly for plugin to be loaded
if (!tensorflow::CheckABI(
plugin_filesystem_ops_ABI, plugin_random_access_file_ops,
plugin_random_access_file_ops_ABI, plugin_writable_file_ops,
plugin_writable_file_ops_ABI, plugin_read_only_memory_region_ops,
plugin_read_only_memory_region_ops_ABI, status)) {
return;
}
// API numbers should match but mismatch doesn't block plugin load
tensorflow::CheckAPI(plugin_filesystem_ops_API, plugin_random_access_file_ops,
plugin_random_access_file_ops_API,
plugin_writable_file_ops, plugin_writable_file_ops_API,
plugin_read_only_memory_region_ops,
plugin_read_only_memory_region_ops_API);
// Plugin can only be loaded if all supplied ops are valid
if (!tensorflow::Validate(plugin_filesystem_ops,
plugin_random_access_file_ops,
plugin_writable_file_ops,
plugin_read_only_memory_region_ops, status)) {
return;
}
// Copy all the function tables to core TensorFlow memory space
auto core_filesystem_ops = tensorflow::CopyToCore<TF_FilesystemOps>(
plugin_filesystem_ops, plugin_filesystem_ops_size);
auto core_random_access_file_ops =
tensorflow::CopyToCore<TF_RandomAccessFileOps>(
plugin_random_access_file_ops, plugin_random_access_file_ops_size);
auto core_writable_file_ops = tensorflow::CopyToCore<TF_WritableFileOps>(
plugin_writable_file_ops, plugin_writable_file_ops_size);
auto core_read_only_memory_region_ops =
tensorflow::CopyToCore<TF_ReadOnlyMemoryRegionOps>(
plugin_read_only_memory_region_ops,
plugin_read_only_memory_region_ops_size);
// Initialize the opaque filesystem structure
auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
core_filesystem_ops->init(filesystem.get(), status);
if (!status->status.ok()) {
core_filesystem_ops->cleanup(filesystem.get());
return;
}
// Register new filesystem
status->status = tensorflow::Env::Default()->RegisterFileSystem(
scheme, tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
std::move(filesystem), std::move(core_filesystem_ops),
std::move(core_random_access_file_ops),
std::move(core_writable_file_ops),
std::move(core_read_only_memory_region_ops)));
}

View File

@ -56,7 +56,7 @@ extern "C" {
/// Lifetime: The wrapper data structures are owned by core TensorFlow. The data
/// pointed to by the `void*` members is always owned by the plugin. The plugin
/// will provide functions to call to allocate and deallocate this data (see
/// next section) and core TensorFlow ensures to call these at the proper time.
/// next sections) and core TensorFlow ensures to call these at the proper time.
///
/// Plugins will never receive a `TF_*` pointer that is `nullptr`. Core
/// TensorFlow will never touch the `void*` wrapped by these structures, except
@ -529,7 +529,7 @@ typedef struct TF_FilesystemOps {
/// If `statuses` is not null, plugins must fill each element with detailed
/// status for each file, as if calling `path_exists` on each one. Core
/// TensorFlow initializes the `statuses` array and plugins must use
/// `TF_SetStatus` to set each element instead of dirrectly assigning.
/// `TF_SetStatus` to set each element instead of directly assigning.
///
/// DEFAULT IMPLEMENTATION: Checks existence of every file. Needs
/// `path_exists`.
@ -601,6 +601,10 @@ typedef struct TF_FilesystemOps {
///
/// Plugins must not return `nullptr`. Returning empty strings is allowed.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// This function will be called by core TensorFlow to clean up all path
/// arguments for all other methods in the filesystem API.
///
@ -618,6 +622,10 @@ typedef struct TF_FilesystemOps {
/// In case of error, plugins must set `status` to a value different than
/// `TF_OK`, free memory allocated for `entries` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if all children were returned.
/// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to a
@ -654,6 +662,10 @@ typedef struct TF_FilesystemOps {
/// different than `TF_OK`, free any memory that might have been allocated for
/// `entries` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if all matches were returned.
/// * Might use any other error value for `status` to signal other errors.
@ -736,95 +748,132 @@ constexpr size_t TF_FILESYSTEM_OPS_SIZE = sizeof(TF_FilesystemOps);
/// SECTION 4. Plugin registration and initialization
/// ----------------------------------------------------------------------------
///
/// In this section we define two functions:
/// * `TF_InitPlugin`: must be present in the plugin shared object as it will
/// be called by core TensorFlow when the filesystem plugin is loaded;
/// * `RegisterFilesystemPlugin`: it is implemented by core TensorFlow but
/// plugins must call it in their `TF_InitPlugin`, usually using the macro
/// `TF_REGISTER_FILESYSTEM_PLUGIN`.
/// In this section we define the API used by core TensorFlow to initialize a
/// filesystem provided by a plugin. That is, we define the following:
/// * `TF_InitPlugin` function: must be present in the plugin shared object as
/// it will be called by core TensorFlow when the filesystem plugin is
/// loaded;
/// * `TF_FilesystemPluginOps` struct: used to transfer information between
/// plugins and core TensorFlow about the operations provided and metadata;
/// * `TF_FilesystemPluginInfo` struct: similar to the above structure, but
/// collects information about all the file schemes that the plugin provides
/// support for, as well as about the plugin's memory handling routines;
/// * `TF_SetFilesystemVersionMetadata` function: must be called by plugins in
/// their `TF_InitPlugin` to record the versioning information the plugins
/// are compiled against.
///
/// The `TF_InitPlugin` function is used by plugins to set up the data
/// structures that implement this interface, as presented in Section 2.
///
/// The `RegisterFilesystemPlugin` is used by core TensorFlow to check that
/// plugins satisfy the requirements expected by core TensorFlow, as follows:
/// 1. If ABI numbers don't match we don't load the plugin, else we continue.
/// 2. If the API numbers are mismatched, we warn the user and continue
/// loading the plugin.
/// 3. If any required operation is missing, we stop loading the plugin.
///
/// If all these checks succeed, we copy the plugin operations to a different
/// memory location so that core TensorFlow has the guarantee that they won't be
/// changed by plugins at a later time. Finally, we initialize the opaque
/// pointer of `TF_Filesystem` by calling the required `init` function of
/// `TF_FilesystemOps` and if that succeeds we register the filesystem.
/// structures that implement this interface, as presented in Section 2. In
/// order to avoid having plugin shared objects call back into symbols defined
/// in core TensorFlow, `TF_InitPlugin` has a `TF_FilesystemPluginInfo` argument
/// which the plugin must fill in (using `TF_SetFilesystemVersionMetadata` for
/// the metadata and setting up all the supported operations and URI schemes).
// Initializes a TensorFlow plugin.
//
// Must be implemented by the plugin DSO. It is called by TensorFlow runtime.
//
// Filesystem plugins can be loaded on demand by users via
// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
// paths (although this has a security risk if two plugins register for the
// same filesystem and the malicious one loads before the legitimate one -
// but we consider this to be something that users should care about and
// manage themselves). In both of these cases, core TensorFlow looks for
// the `TF_InitPlugin` symbol and calls that function.
//
// A plugin is loaded only if this `status` is `TF_OK` after the call.
TF_CAPI_EXPORT extern void TF_InitPlugin(TF_Status* status);
/// This structure incorporates the operations defined in Section 2 and the
/// metadata defined in Section 3, allowing plugins to define different ops
/// for different URI schemes.
///
/// Every URI scheme is of the form "fs" for URIs of the form
/// "fs:///path/to/file".
/// For local filesystems (i.e., when the URI is "/path/to/file"), the scheme
/// must be "". The scheme must never be `nullptr`.
///
/// Every plugin fills this in `TF_InitPlugin`, using the allocator passed as
/// argument to allocate memory. After `TF_InitPlugin` finishes, core
/// TensorFlow uses the information present in this to initialize filesystems
/// for the URI schemes that the plugin requests.
///
/// All pointers defined in this structure point to memory allocated by the DSO
/// using the allocation routines that the plugin itself supplies in
/// `TF_FilesystemPluginInfo` when `TF_InitPlugin` is called.
///
/// IMPORTANT: To maintain binary compatibility, the layout of this structure
/// must not change! In the unlikely case that a new type of file needs to be
/// supported, add the new ops and metadata at the end of the structure.
typedef struct TF_FilesystemPluginOps {
char* scheme;
int filesystem_ops_abi;
int filesystem_ops_api;
size_t filesystem_ops_size;
TF_FilesystemOps* filesystem_ops;
int random_access_file_ops_abi;
int random_access_file_ops_api;
size_t random_access_file_ops_size;
TF_RandomAccessFileOps* random_access_file_ops;
int writable_file_ops_abi;
int writable_file_ops_api;
size_t writable_file_ops_size;
TF_WritableFileOps* writable_file_ops;
int read_only_memory_region_ops_abi;
int read_only_memory_region_ops_api;
size_t read_only_memory_region_ops_size;
TF_ReadOnlyMemoryRegionOps* read_only_memory_region_ops;
} TF_FilesystemPluginOps;
/// Registers a filesystem plugin so that core TensorFlow can use it.
/// This structure gathers together all the operations provided by the plugin.
///
/// Must be called by the plugin during `TF_InitPlugin`, usually by using the
/// convenience `TF_REGISTER_FILESYSTEM_PLUGIN` macro.
/// Plugins must provide exactly `num_schemes` elements in the `ops` array.
///
/// Arguments (grouped by category):
/// * `..ABI`: ABI compatibility numbers (see Section 3.).
/// * `..API`: API compatibility numbers (see Section 3.).
/// * `..Size`: Sizes of the operation tables (see Section 3.).
/// * `scheme`: The URI scheme that plugin is registering filesystems for.
/// Must be of the form "fs" for URIs of form "fs:///path/to/file". For
/// local filesystems (i.e., when the URI is "/path/to/file"), `scheme`
/// must be "". Must never be `nullptr`.
/// * `..Ops`: The function tables provided by the plugin. Owned by the
/// plugin, but core TensorFlow makes a copy of these.
/// * `status`: The output variable for representing success/failure.
/// Since memory that is allocated by the DSO gets transferred to core
/// TensorFlow, we need to provide a way for the allocation and deallocation to
/// match. This is why this structure also defines `plugin_memory_allocate` and
/// `plugin_memory_free` members.
///
/// Sets `status` to `TF_OK` if plugin was registered and filesystem operations
/// can be invoked from anywhere during TensorFlow's runtime. Any other value of
/// `status` means that plugin failed to load properly and as such the
/// operations it provides cannot be used at all (i.e., core TensorFlow will
/// never run them, returning early with `TF_UNIMPLEMENTED` or similar error
/// values).
TF_CAPI_EXPORT extern void RegisterFilesystemPlugin(
int pluginFilesystemOpsABI, int pluginFilesystemOpsAPI,
size_t pluginFilesystemOpsSize, int pluginRandomAccessFileOpsABI,
int pluginRandomAccessFileOpsAPI, size_t pluginRandomAccessFileOpsSize,
int pluginWritableFileOpsABI, int pluginWritableFileOpsAPI,
size_t pluginWritableFileOpsSize, int pluginReadOnlyMemoryRegionOpsABI,
int pluginReadOnlyMemoryRegionOpsAPI,
size_t pluginReadOnlyMemoryRegionOpsSize, const char* scheme,
const TF_FilesystemOps* pluginFilesystemOps,
const TF_RandomAccessFileOps* pluginRandomAccessFileOps,
const TF_WritableFileOps* pluginWritableFileOps,
const TF_ReadOnlyMemoryRegionOps* pluginReadOnlyMemoryRegionOps,
TF_Status* status);
/// All memory allocated by the plugin that will be owned by core TensorFlow
/// must be allocated using the allocator in this structure. Core TensorFlow
/// will use the deallocator to free this memory once it no longer needs it.
///
/// IMPORTANT: To maintain binary compatibility, the layout of this structure
/// must not change! In the unlikely case that new global operations must be
/// provided, add them at the end of the structure.
typedef struct TF_FilesystemPluginInfo {
size_t num_schemes;
TF_FilesystemPluginOps* ops;
void* (*plugin_memory_allocate)(size_t size);
void (*plugin_memory_free)(void* ptr);
} TF_FilesystemPluginInfo;
/// This macro is just a convenience wrapper around `RegisterFilesystemPlugin`.
/// Plugins should prefer using this macro instead of a direct call.
#define TF_REGISTER_FILESYSTEM_PLUGIN( \
scheme, pluginFilesystemOps, pluginRandomAccessFileOps, \
pluginWritableFileOps, pluginReadOnlyMemoryRegionOps, status) \
RegisterFilesystemPlugin( \
TF_FILESYSTEM_OPS_ABI, TF_FILESYSTEM_OPS_API, TF_FILESYSTEM_OPS_SIZE, \
TF_RANDOM_ACCESS_FILE_OPS_ABI, TF_RANDOM_ACCESS_FILE_OPS_API, \
TF_RANDOM_ACCESS_FILE_OPS_SIZE, TF_WRITABLE_FILE_OPS_ABI, \
TF_WRITABLE_FILE_OPS_API, TF_WRITABLE_FILE_OPS_SIZE, \
TF_READ_ONLY_MEMORY_REGION_OPS_ABI, TF_READ_ONLY_MEMORY_REGION_OPS_API, \
TF_READ_ONLY_MEMORY_REGION_OPS_SIZE, scheme, pluginFilesystemOps, \
pluginRandomAccessFileOps, pluginWritableFileOps, \
pluginReadOnlyMemoryRegionOps, status)
/// Convenience function for setting the versioning metadata.
///
/// The argument is guaranteed to not be `nullptr`.
///
/// We want this to be defined in the plugin's memory space and we guarantee
/// that core TensorFlow will never call this.
static inline void TF_SetFilesystemVersionMetadata(
TF_FilesystemPluginOps* ops) {
ops->filesystem_ops_abi = TF_FILESYSTEM_OPS_ABI;
ops->filesystem_ops_api = TF_FILESYSTEM_OPS_API;
ops->filesystem_ops_size = TF_FILESYSTEM_OPS_SIZE;
ops->random_access_file_ops_abi = TF_RANDOM_ACCESS_FILE_OPS_ABI;
ops->random_access_file_ops_api = TF_RANDOM_ACCESS_FILE_OPS_API;
ops->random_access_file_ops_size = TF_RANDOM_ACCESS_FILE_OPS_SIZE;
ops->writable_file_ops_abi = TF_WRITABLE_FILE_OPS_ABI;
ops->writable_file_ops_api = TF_WRITABLE_FILE_OPS_API;
ops->writable_file_ops_size = TF_WRITABLE_FILE_OPS_SIZE;
ops->read_only_memory_region_ops_abi = TF_READ_ONLY_MEMORY_REGION_OPS_ABI;
ops->read_only_memory_region_ops_api = TF_READ_ONLY_MEMORY_REGION_OPS_API;
ops->read_only_memory_region_ops_size = TF_READ_ONLY_MEMORY_REGION_OPS_SIZE;
}
/// Initializes a TensorFlow plugin.
///
/// Must be implemented by the plugin DSO. It is called by TensorFlow runtime.
///
/// Filesystem plugins can be loaded on demand by users via
/// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
/// paths (although this has a security risk if two plugins register for the
/// same filesystem and the malicious one loads before the legitimate one -
/// but we consider this to be something that users should care about and
/// manage themselves). In both of these cases, core TensorFlow looks for
/// the `TF_InitPlugin` symbol and calls this function.
///
/// For every filesystem URI scheme that this plugin supports, the plugin must
/// add one `TF_FilesystemPluginOps` entry in `plugin_info->ops` and call
/// `TF_SetFilesystemVersionMetadata` for that entry.
///
/// Plugins must also initialize `plugin_info->plugin_memory_allocate` and
/// `plugin_info->plugin_memory_free` to ensure memory allocated by the plugin
/// is freed in a compatible way.
TF_CAPI_EXPORT extern void TF_InitPlugin(TF_FilesystemPluginInfo* plugin_info);
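
To make the registration contract concrete, here is a minimal plugin-side
sketch of `TF_InitPlugin` for a single hypothetical `ram` scheme; `RamInit`
and `RamCleanup` are assumptions, not part of this header:

static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }

void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
  // Advertise the plugin's memory routines so core can free what it receives.
  info->plugin_memory_allocate = plugin_memory_allocate;
  info->plugin_memory_free = plugin_memory_free;
  // One `TF_FilesystemPluginOps` entry per supported URI scheme.
  info->num_schemes = 1;
  info->ops = static_cast<TF_FilesystemPluginOps*>(
      plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
  TF_SetFilesystemVersionMetadata(&info->ops[0]);
  info->ops[0].scheme = strdup("ram");  // freed by core via plugin_memory_free
  info->ops[0].filesystem_ops = static_cast<TF_FilesystemOps*>(
      plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
  info->ops[0].filesystem_ops->init = RamInit;        // hypothetical
  info->ops[0].filesystem_ops->cleanup = RamCleanup;  // hypothetical
  // Remaining op tables stay `nullptr` (calloc), which validation permits as
  // long as the filesystem doesn't advertise those file types.
}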
#ifdef __cplusplus
} // end extern "C"

View File

@ -18,11 +18,10 @@ limitations under the License.
#include <string>
#include <utility>
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system_helper.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/util/ptr_util.h"
// TODO(mihaimaruseac): After all filesystems are converted, all calls to
@ -165,16 +164,18 @@ Status ModularFileSystem::GetChildren(const std::string& dir,
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
std::string translated_name = TranslateName(dir);
char** children;
// Note that `children` is allocated by the plugin and freed by core
// TensorFlow, so we need to use `plugin_memory_free_` here.
char** children = nullptr;
const int num_children =
ops_->get_children(filesystem_.get(), translated_name.c_str(), &children,
plugin_status.get());
if (num_children >= 0) {
for (int i = 0; i < num_children; i++) {
result->push_back(std::string(children[i]));
free(children[i]);
plugin_memory_free_(children[i]);
}
free(children);
plugin_memory_free_(children);
}
return StatusFromTF_Status(plugin_status.get());
@ -186,15 +187,17 @@ Status ModularFileSystem::GetMatchingPaths(const std::string& pattern,
return internal::GetMatchingPaths(this, Env::Default(), pattern, result);
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
char** matches;
// Note that `matches` is allocated by the plugin and freed by core
// TensorFlow, so we need to use `plugin_memory_free_` here.
char** matches = nullptr;
const int num_matches = ops_->get_matching_paths(
filesystem_.get(), pattern.c_str(), &matches, plugin_status.get());
if (num_matches >= 0) {
for (int i = 0; i < num_matches; i++) {
result->push_back(std::string(matches[i]));
free(matches[i]);
plugin_memory_free_(matches[i]);
}
free(matches);
plugin_memory_free_(matches);
}
return StatusFromTF_Status(plugin_status.get());
@ -358,7 +361,8 @@ std::string ModularFileSystem::TranslateName(const std::string& name) const {
CHECK(p != nullptr) << "TranslateName(" << name << ") returned nullptr";
std::string ret(p);
free(p);
// Since `p` is allocated by plugin, free it using plugin's method.
plugin_memory_free_(p);
return ret;
}
@ -435,4 +439,26 @@ Status ModularWritableFile::Tell(int64* position) {
return StatusFromTF_Status(plugin_status.get());
}
Status RegisterFilesystemPlugin(const std::string& dso_path) {
// Step 1: Load plugin
Env* env = Env::Default();
void* dso_handle;
TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle));
// Step 2: Load symbol for `TF_InitPlugin`
void* dso_symbol;
TF_RETURN_IF_ERROR(
env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol));
// Step 3: Call `TF_InitPlugin`
TF_FilesystemPluginInfo info;
memset(&info, 0, sizeof(info));
auto TF_InitPlugin =
reinterpret_cast<void (*)(TF_FilesystemPluginInfo*)>(dso_symbol);
TF_InitPlugin(&info);
// Step 4: Do the actual registration
return filesystem_registration::RegisterFilesystemPluginImpl(&info);
}
} // namespace tensorflow

View File

@ -32,7 +32,7 @@ namespace tensorflow {
// TODO(b/143949615): After all filesystems are converted, this file will be
// moved to core/platform, and this class can become a singleton and replace the
// need for `Env::Default()`. At that time, we might decide to remove the need
// for `Env::Default()` altoghether, but that's a different project, not in
// for `Env::Default()` altogether, but that's a different project, not in
// scope for now. I'm just mentioning this here as that transition will mean
// removal of the registration part from `Env` and adding it here instead: we
// will need tables to hold for each scheme the function tables that implement
@ -46,12 +46,16 @@ class ModularFileSystem final : public FileSystem {
std::unique_ptr<const TF_RandomAccessFileOps> random_access_file_ops,
std::unique_ptr<const TF_WritableFileOps> writable_file_ops,
std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
read_only_memory_region_ops)
read_only_memory_region_ops,
std::function<void*(size_t)> plugin_memory_allocate,
std::function<void(void*)> plugin_memory_free)
: filesystem_(std::move(filesystem)),
ops_(std::move(filesystem_ops)),
random_access_file_ops_(std::move(random_access_file_ops)),
writable_file_ops_(std::move(writable_file_ops)),
read_only_memory_region_ops_(std::move(read_only_memory_region_ops)) {}
read_only_memory_region_ops_(std::move(read_only_memory_region_ops)),
plugin_memory_allocate_(std::move(plugin_memory_allocate)),
plugin_memory_free_(std::move(plugin_memory_free)) {}
~ModularFileSystem() override { ops_->cleanup(filesystem_.get()); }
@ -93,6 +97,8 @@ class ModularFileSystem final : public FileSystem {
std::unique_ptr<const TF_WritableFileOps> writable_file_ops_;
std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
read_only_memory_region_ops_;
std::function<void*(size_t)> plugin_memory_allocate_;
std::function<void(void*)> plugin_memory_free_;
TF_DISALLOW_COPY_AND_ASSIGN(ModularFileSystem);
};
@ -156,6 +162,9 @@ class ModularReadOnlyMemoryRegion final : public ReadOnlyMemoryRegion {
TF_DISALLOW_COPY_AND_ASSIGN(ModularReadOnlyMemoryRegion);
};
// Registers a filesystem plugin so that core TensorFlow can use it.
Status RegisterFilesystemPlugin(const std::string& dso_path);
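
For illustration, a hedged sketch of loading a plugin at runtime from user
code; the DSO path is hypothetical:

tensorflow::Status s =
    tensorflow::RegisterFilesystemPlugin("/tmp/libmy_fs_plugin.so");
if (!s.ok()) LOG(ERROR) << "Filesystem plugin failed to load: " << s;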
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_H_

View File

@ -0,0 +1,327 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
// Checks that all schemes provided by a plugin are valid.
// TODO(mihaimaruseac): More validation could be done here, based on supported
// charset, maximum length, etc. Punting it for later.
static Status ValidateScheme(const char* scheme) {
if (scheme == nullptr)
return errors::InvalidArgument(
"Attempted to register filesystem with `nullptr` URI scheme");
return Status::OK();
}
// Checks if the plugin and core ABI numbers match.
//
// If the numbers don't match, plugin cannot be loaded.
static Status CheckABI(int pluginABI, int coreABI, StringPiece where) {
if (pluginABI != coreABI)
return errors::FailedPrecondition(
strings::StrCat("Plugin ABI (", pluginABI, ") for ", where,
" operations doesn't match expected core ABI (",
coreABI, "). Plugin cannot be loaded."));
return Status::OK();
}
// Checks if the plugin and core ABI numbers match, for all operations.
//
// If the numbers don't match, plugin cannot be loaded.
//
// Uses the simpler `CheckABI(int, int, StringPiece)`.
static Status ValidateABI(const TF_FilesystemPluginOps* ops) {
TF_RETURN_IF_ERROR(
CheckABI(ops->filesystem_ops_abi, TF_FILESYSTEM_OPS_ABI, "filesystem"));
if (ops->random_access_file_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(ops->random_access_file_ops_abi,
TF_RANDOM_ACCESS_FILE_OPS_ABI,
"random access file"));
if (ops->writable_file_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(ops->writable_file_ops_abi,
TF_WRITABLE_FILE_OPS_ABI, "writable file"));
if (ops->read_only_memory_region_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(ops->read_only_memory_region_ops_abi,
TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
"read only memory region"));
return Status::OK();
}
// Checks if the plugin and core API numbers match, logging mismatches.
static void CheckAPI(int plugin_API, int core_API, StringPiece where) {
if (plugin_API != core_API) {
VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
<< " operations doesn't match expected core API (" << core_API
<< "). Plugin will be loaded but functionality might be missing.";
}
}
// Checks if the plugin and core API numbers match, for all operations.
//
// Uses the simpler `CheckAPI(int, int, StringPiece)`.
static void ValidateAPI(const TF_FilesystemPluginOps* ops) {
CheckAPI(ops->filesystem_ops_api, TF_FILESYSTEM_OPS_API, "filesystem");
if (ops->random_access_file_ops != nullptr)
CheckAPI(ops->random_access_file_ops_api, TF_RANDOM_ACCESS_FILE_OPS_API,
"random access file");
if (ops->writable_file_ops != nullptr)
CheckAPI(ops->writable_file_ops_api, TF_WRITABLE_FILE_OPS_API,
"writable file");
if (ops->read_only_memory_region_ops != nullptr)
CheckAPI(ops->read_only_memory_region_ops_api,
TF_READ_ONLY_MEMORY_REGION_OPS_API, "read only memory region");
}
// Validates the filesystem operations supplied by the plugin.
static Status ValidateHelper(const TF_FilesystemOps* ops) {
if (ops == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without operations");
if (ops->init == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `init` operation");
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation");
return Status::OK();
}
// Validates the random access file operations supplied by the plugin.
static Status ValidateHelper(const TF_RandomAccessFileOps* ops) {
if (ops == nullptr) {
// We allow filesystems where files can only be written to (from TF code)
return Status::OK();
}
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation on random "
"access files");
return Status::OK();
}
// Validates the writable file operations supplied by the plugin.
static Status ValidateHelper(const TF_WritableFileOps* ops) {
if (ops == nullptr) {
// We allow read-only filesystems
return Status::OK();
}
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation on writable "
"files");
return Status::OK();
}
// Validates the read only memory region operations given by the plugin.
static Status ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops) {
if (ops == nullptr) {
// read only memory region support is always optional
return Status::OK();
}
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation on read "
"only memory regions");
if (ops->data == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `data` operation on read only "
"memory regions");
if (ops->length == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `length` operation on read only "
"memory regions");
return Status::OK();
}
// Validates the operations supplied by the plugin.
//
// Uses the 4 simpler `ValidateHelper(const TF_...*)` to validate each
// individual function table and then checks that the function table for a
// specific file type exists if the plugin offers support for creating that
// type of file.
static Status ValidateOperations(const TF_FilesystemPluginOps* ops) {
TF_RETURN_IF_ERROR(ValidateHelper(ops->filesystem_ops));
TF_RETURN_IF_ERROR(ValidateHelper(ops->random_access_file_ops));
TF_RETURN_IF_ERROR(ValidateHelper(ops->writable_file_ops));
TF_RETURN_IF_ERROR(ValidateHelper(ops->read_only_memory_region_ops));
if (ops->filesystem_ops->new_random_access_file != nullptr &&
ops->random_access_file_ops == nullptr)
return errors::FailedPrecondition(
"Filesystem allows creation of random access files but no "
"operations on them have been supplied.");
if ((ops->filesystem_ops->new_writable_file != nullptr ||
ops->filesystem_ops->new_appendable_file != nullptr) &&
ops->writable_file_ops == nullptr)
return errors::FailedPrecondition(
"Filesystem allows creation of writable files but no "
"operations on them have been supplied.");
if (ops->filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
ops->read_only_memory_region_ops == nullptr)
return errors::FailedPrecondition(
"Filesystem allows creation of readonly memory regions but no "
"operations on them have been supplied.");
return Status::OK();
}
// Copies a function table from plugin memory space to core memory space.
//
// This has three benefits:
// * allows having newer plugins than the current core TensorFlow: the
// additional entries in the plugin's table are just discarded;
// * allows having older plugins than the current core TensorFlow (though
// we are still warning users): the entries that core TensorFlow expects
// but plugins didn't provide will be set to `nullptr` values and core
// TensorFlow will know to not call these on behalf of users;
// * increased security as plugins will not be able to alter the function
//   table after loading. Thus, malicious plugins can't alter functionality to
// probe for gadgets inside core TensorFlow. We can even protect the area
// of memory where the copies reside to not allow any more writes to it
// after all copies are created.
template <typename T>
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
size_t plugin_size) {
if (plugin_ops == nullptr) return nullptr;
size_t copy_size = std::min(plugin_size, sizeof(T));
auto core_ops = tensorflow::MakeUnique<T>();
memset(core_ops.get(), 0, sizeof(T));
memcpy(core_ops.get(), plugin_ops, copy_size);
return core_ops;
}
// Registers one filesystem from the plugin.
//
// Must be called only with `index` a valid index in `info->ops`.
static Status RegisterFileSystem(const TF_FilesystemPluginInfo* info,
int index) {
// Step 1: Copy all the function tables to core TensorFlow memory space
auto core_filesystem_ops = CopyToCore<TF_FilesystemOps>(
info->ops[index].filesystem_ops, info->ops[index].filesystem_ops_size);
auto core_random_access_file_ops = CopyToCore<TF_RandomAccessFileOps>(
info->ops[index].random_access_file_ops,
info->ops[index].random_access_file_ops_size);
auto core_writable_file_ops =
CopyToCore<TF_WritableFileOps>(info->ops[index].writable_file_ops,
info->ops[index].writable_file_ops_size);
auto core_read_only_memory_region_ops =
CopyToCore<TF_ReadOnlyMemoryRegionOps>(
info->ops[index].read_only_memory_region_ops,
info->ops[index].read_only_memory_region_ops_size);
// Step 2: Initialize the opaque filesystem structure
auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
TF_Status* c_status = TF_NewStatus();
Status status = Status::OK();
core_filesystem_ops->init(filesystem.get(), c_status);
status = Status(c_status->status);
TF_DeleteStatus(c_status);
if (!status.ok()) return status;
// Step 3: Actual registration
return Env::Default()->RegisterFileSystem(
info->ops[index].scheme,
tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
std::move(filesystem), std::move(core_filesystem_ops),
std::move(core_random_access_file_ops),
std::move(core_writable_file_ops),
std::move(core_read_only_memory_region_ops),
info->plugin_memory_allocate, info->plugin_memory_free));
}
// Registers the filesystem at `index`, if the plugin provides valid
// information.
//
// Extracted to a separate function so that pointers inside `info` are freed
// by the caller regardless of whether validation/registration fails.
//
// Must be called only with `index` a valid index in `info->ops`.
static Status ValidateAndRegisterFilesystems(
const TF_FilesystemPluginInfo* info, int index) {
TF_RETURN_IF_ERROR(ValidateScheme(info->ops[index].scheme));
TF_RETURN_IF_ERROR(ValidateABI(&info->ops[index]));
ValidateAPI(&info->ops[index]); // we just warn on API number mismatch
TF_RETURN_IF_ERROR(ValidateOperations(&info->ops[index]));
TF_RETURN_IF_ERROR(RegisterFileSystem(info, index));
return Status::OK();
}
// Ensures that the plugin provides the required memory management operations.
static Status ValidatePluginMemoryRoutines(
const TF_FilesystemPluginInfo* info) {
if (info->plugin_memory_allocate == nullptr)
return errors::FailedPrecondition(
"Cannot load filesystem plugin which does not provide "
"`plugin_memory_allocate`");
if (info->plugin_memory_free == nullptr)
return errors::FailedPrecondition(
"Cannot load filesystem plugin which does not provide "
"`plugin_memory_free`");
return Status::OK();
}
namespace filesystem_registration {
Status RegisterFilesystemPluginImpl(const TF_FilesystemPluginInfo* info) {
TF_RETURN_IF_ERROR(ValidatePluginMemoryRoutines(info));
  // Validate and register all filesystems: try to register as many as
  // possible, freeing each entry's memory once it is no longer needed.
Status status;
  for (size_t i = 0; i < info->num_schemes; i++) {
status.Update(ValidateAndRegisterFilesystems(info, i));
info->plugin_memory_free(info->ops[i].scheme);
info->plugin_memory_free(info->ops[i].filesystem_ops);
info->plugin_memory_free(info->ops[i].random_access_file_ops);
info->plugin_memory_free(info->ops[i].writable_file_ops);
info->plugin_memory_free(info->ops[i].read_only_memory_region_ops);
}
info->plugin_memory_free(info->ops);
return status;
}
} // namespace filesystem_registration
} // namespace tensorflow

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
namespace filesystem_registration {
// Implementation for filesystem registration
//
// Don't call this directly. Instead call `RegisterFilesystemPlugin`.
// Exposed only for static registration of local filesystems.
Status RegisterFilesystemPluginImpl(const TF_FilesystemPluginInfo* info);
} // namespace filesystem_registration
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_

View File

@ -19,6 +19,7 @@ tf_cc_shared_object(
cc_library(
name = "posix_filesystem_impl",
srcs = ["posix_filesystem.cc"],
hdrs = ["posix_filesystem.h"],
deps = [
":posix_filesystem_helper",
"//tensorflow/c:tf_status",
@ -26,6 +27,20 @@ cc_library(
],
)
# Since building pip package and API tests require a filesystem, we provide a
# static registration target that they should link against.
cc_library(
name = "posix_filesystem_static",
srcs = ["posix_filesystem_static.cc"],
visibility = ["//visibility:public"],
deps = [
":posix_filesystem_impl",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"//tensorflow/c/experimental/filesystem:modular_filesystem",
],
alwayslink = 1,
)
# Library implementing helper functionality, so that the above only contains
# the API implementation for modular filesystems.
cc_library(

View File

@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h"
#include <dirent.h>
#include <errno.h>
#include <fcntl.h>
@ -24,15 +26,15 @@ limitations under the License.
#include <sys/stat.h>
#include <unistd.h>
#include <vector>
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for POSIX environments.
// This filesystem will support `file://` and empty (local) URI schemes.
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
@ -45,7 +47,9 @@ typedef struct PosixFile {
static void Cleanup(TF_RandomAccessFile* file) {
auto posix_file = static_cast<PosixFile*>(file->plugin_file);
close(posix_file->fd);
free(const_cast<char*>(posix_file->filename));
// This would be safe to free using `free` directly, as it is just an opaque
// buffer. However, it is better to be consistent everywhere.
plugin_memory_free(const_cast<char*>(posix_file->filename));
delete posix_file;
}
@ -100,7 +104,7 @@ typedef struct PosixFile {
static void Cleanup(TF_WritableFile* file) {
auto posix_file = static_cast<PosixFile*>(file->plugin_file);
free(const_cast<char*>(posix_file->filename));
plugin_memory_free(const_cast<char*>(posix_file->filename));
delete posix_file;
}
@ -383,12 +387,13 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
if (num_entries < 0) {
TF_SetStatusFromIOError(status, errno, path);
} else {
*entries = static_cast<char**>(calloc(num_entries, sizeof((*entries)[0])));
*entries = static_cast<char**>(
plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
for (int i = 0; i < num_entries; i++) {
(*entries)[i] = strdup(dir_entries[i]->d_name);
free(dir_entries[i]);
plugin_memory_free(dir_entries[i]);
}
free(dir_entries);
plugin_memory_free(dir_entries);
}
return num_entries;
@ -396,48 +401,59 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
} // namespace tf_posix_filesystem
void TF_InitPlugin(TF_Status* status) {
TF_RandomAccessFileOps random_access_file_ops = {
tf_random_access_file::Cleanup,
tf_random_access_file::Read,
};
TF_WritableFileOps writable_file_ops = {
tf_writable_file::Cleanup, tf_writable_file::Append,
tf_writable_file::Tell, tf_writable_file::Flush,
tf_writable_file::Sync, tf_writable_file::Close,
};
TF_ReadOnlyMemoryRegionOps read_only_memory_region_ops = {
tf_read_only_memory_region::Cleanup,
tf_read_only_memory_region::Data,
tf_read_only_memory_region::Length,
};
TF_FilesystemOps filesystem_ops = {
tf_posix_filesystem::Init,
tf_posix_filesystem::Cleanup,
tf_posix_filesystem::NewRandomAccessFile,
tf_posix_filesystem::NewWritableFile,
tf_posix_filesystem::NewAppendableFile,
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile,
tf_posix_filesystem::CreateDir,
/*recursively_create_dir=*/nullptr,
tf_posix_filesystem::DeleteFile,
tf_posix_filesystem::DeleteDir,
/*delete_recursively=*/nullptr,
tf_posix_filesystem::RenameFile,
tf_posix_filesystem::CopyFile,
tf_posix_filesystem::PathExists,
/*paths_exist=*/nullptr,
tf_posix_filesystem::Stat,
/*is_directory=*/nullptr,
/*get_file_size=*/nullptr,
/*translate_name=*/nullptr,
tf_posix_filesystem::GetChildren,
/*get_matching_paths=*/nullptr,
/*flush_caches=*/nullptr,
};
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
for (const char* scheme : {"", "file"})
TF_REGISTER_FILESYSTEM_PLUGIN(scheme, &filesystem_ops,
&random_access_file_ops, &writable_file_ops,
&read_only_memory_region_ops, status);
ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
ops->random_access_file_ops->read = tf_random_access_file::Read;
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->writable_file_ops->append = tf_writable_file::Append;
ops->writable_file_ops->tell = tf_writable_file::Tell;
ops->writable_file_ops->flush = tf_writable_file::Flush;
ops->writable_file_ops->sync = tf_writable_file::Sync;
ops->writable_file_ops->close = tf_writable_file::Close;
ops->read_only_memory_region_ops = static_cast<TF_ReadOnlyMemoryRegionOps*>(
plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
ops->read_only_memory_region_ops->cleanup =
tf_read_only_memory_region::Cleanup;
ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data;
ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_posix_filesystem::Init;
ops->filesystem_ops->cleanup = tf_posix_filesystem::Cleanup;
ops->filesystem_ops->new_random_access_file =
tf_posix_filesystem::NewRandomAccessFile;
ops->filesystem_ops->new_writable_file = tf_posix_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file =
tf_posix_filesystem::NewAppendableFile;
ops->filesystem_ops->new_read_only_memory_region_from_file =
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile;
ops->filesystem_ops->create_dir = tf_posix_filesystem::CreateDir;
ops->filesystem_ops->delete_file = tf_posix_filesystem::DeleteFile;
ops->filesystem_ops->delete_dir = tf_posix_filesystem::DeleteDir;
ops->filesystem_ops->rename_file = tf_posix_filesystem::RenameFile;
ops->filesystem_ops->copy_file = tf_posix_filesystem::CopyFile;
ops->filesystem_ops->path_exists = tf_posix_filesystem::PathExists;
ops->filesystem_ops->stat = tf_posix_filesystem::Stat;
ops->filesystem_ops->get_children = tf_posix_filesystem::GetChildren;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
info->plugin_memory_allocate = plugin_memory_allocate;
info->plugin_memory_free = plugin_memory_free;
info->num_schemes = 2;
info->ops = static_cast<TF_FilesystemPluginOps*>(
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
ProvideFilesystemSupportFor(&info->ops[0], "");
ProvideFilesystemSupportFor(&info->ops[1], "file");
}

View File

@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
// Initialize the POSIX filesystem.
//
// In general, the `TF_InitPlugin` symbol doesn't need to be exposed in a header
// file, since the plugin registration will look for the symbol in the DSO file
// that provides the filesystem functionality. However, the POSIX filesystem
// needs to be statically registered in some tests and utilities for building
// the API files at the time of creating the pip package. Hence, we need to
// expose this function so that this filesystem can be statically registered
// when needed.
void TF_InitPlugin(TF_FilesystemPluginInfo* info);
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_

View File

@ -44,7 +44,7 @@ int TransferFileContents(const char* src, const char* dst, mode_t mode,
}
// Both files have been opened, do the transfer.
// Since errno would be overriden by `close` below, save it here.
// Since errno would be overridden by `close` below, save it here.
int error_code = 0;
if (CopyFileContents(dst_fd, src_fd, size) < 0) error_code = errno;

View File

@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h"
namespace tensorflow {
// Register the POSIX filesystems statically.
// Return value will be unused
bool StaticallyRegisterLocalFilesystems() {
TF_FilesystemPluginInfo info;
TF_InitPlugin(&info);
Status status = filesystem_registration::RegisterFilesystemPluginImpl(&info);
if (!status.ok()) {
VLOG(0) << "Static POSIX filesystem could not be registered: " << status;
return false;
}
return true;
}
// Perform the actual registration
static bool unused = StaticallyRegisterLocalFilesystems();
} // namespace tensorflow

View File

@ -0,0 +1,36 @@
# Experimental windows filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
package(
licenses = ["notice"], # Apache 2.0
)
# Filesystem implementation for Windows environment
tf_cc_shared_object(
name = "windows_filesystem.dll",
framework_so = [],
linkstatic = False,
tags = [
"manual",
"nobuilder",
"notap",
],
visibility = ["//visibility:public"],
deps = [":windows_filesystem_impl"],
)
# The real implementation of the filesystem.
cc_library(
name = "windows_filesystem_impl",
srcs = ["windows_filesystem.cc"],
copts = get_win_copts(),
tags = [
"manual",
"nobuilder",
"notap",
],
deps = [
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
],
)

View File

@ -0,0 +1,73 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdlib.h>
#include <string.h>
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for Windows environments.
// This filesystem will support `file://` and empty (local) URI schemes.
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
// TODO(mihaimaruseac): Implement later
} // namespace tf_random_access_file
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
// TODO(mihaimaruseac): Implement later
} // namespace tf_writable_file
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
// ----------------------------------------------------------------------------
namespace tf_read_only_memory_region {
// TODO(mihaimaruseac): Implement later
} // namespace tf_read_only_memory_region
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_windows_filesystem {
// TODO(mihaimaruseac): Implement later
} // namespace tf_windows_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
info->plugin_memory_allocate = plugin_memory_allocate;
info->plugin_memory_free = plugin_memory_free;
info->num_schemes = 2;
info->ops = static_cast<TF_FilesystemPluginOps*>(
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
ProvideFilesystemSupportFor(&info->ops[0], "");
ProvideFilesystemSupportFor(&info->ops[1], "file");
}

View File

@ -27,14 +27,10 @@ namespace {
class DummyDevice : public DeviceBase {
public:
DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
bool RequiresRecordingAccessedTensors() const override { return save_; }
explicit DummyDevice(Env* env) : DeviceBase(env) {}
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
return cpu_allocator();
}
private:
bool save_;
};
void TestBitcastOp(Tensor* input_tensor, DataType out_type,
@ -61,7 +57,7 @@ void TestBitcastOp(Tensor* input_tensor, DataType out_type,
ASSERT_TRUE(status.ok()) << status.ToString();
OpKernelContext::Params params;
DummyDevice dummy_device(nullptr, false);
DummyDevice dummy_device(nullptr);
params.device = &dummy_device;
params.op_kernel = kernel.get();
gtl::InlinedVector<TensorValue, 4> inputs;

View File

@ -18,19 +18,40 @@ limitations under the License.
#include "tensorflow/c/kernels.h"
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <memory>
#include <string>
#include <utility>
#include "absl/container/inlined_vector.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
struct MyCustomKernel {
bool created;
@ -134,14 +155,10 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
class DummyDevice : public DeviceBase {
public:
DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
bool RequiresRecordingAccessedTensors() const override { return save_; }
explicit DummyDevice(Env* env) : DeviceBase(env) {}
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
return cpu_allocator();
}
private:
bool save_;
};
TEST(TestKernel, TestInputAndOutputCount) {
@ -202,7 +219,7 @@ TEST(TestKernel, TestInputAndOutputCount) {
{
OpKernelContext::Params p;
DummyDevice dummy_device(nullptr, false);
DummyDevice dummy_device(nullptr);
p.device = &dummy_device;
p.step_id = 43;

View File

@ -133,7 +133,7 @@ TEST(OpsTest, TestShapeInference_VectorizeFunction) {
TEST(OpsTest, AttributeAccessors) {
TF_OpDefinitionBuilder* builder =
TF_NewOpDefinitionBuilder("AttributeAccesorsOp");
TF_NewOpDefinitionBuilder("AttributeAccessorsOp");
TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2");
TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\"");
TF_OpDefinitionBuilderSetIsCommutative(builder, true);
@ -151,7 +151,7 @@ TEST(OpsTest, AttributeAccessors) {
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
bool found = false;
for (const auto& op : op_list.op()) {
if (op.name() == "AttributeAccesorsOp") {
if (op.name() == "AttributeAccessorsOp") {
ASSERT_TRUE(op.is_commutative());
ASSERT_TRUE(op.is_aggregate());
ASSERT_TRUE(op.allows_uninitialized_input());

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h"
#include <memory>
#include <vector>
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
@ -64,25 +65,41 @@ void deallocate_buffer(void* data, size_t len, void* arg) {
}
} // namespace tensorflow
namespace {
TF_Tensor* CreateTensor(TF_ManagedBuffer* buf, TF_DataType dtype,
const int64_t* dims, int num_dims, size_t len) {
std::vector<tensorflow::int64> dimvec(num_dims);
for (int i = 0; i < num_dims; ++i) {
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
// TODO(gjn): Make the choice of interface a compile-time configuration.
tensorflow::TensorInterface ret(
Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf));
buf->Unref();
size_t elem_size = TF_DataTypeSize(dtype);
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
return nullptr;
}
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
}
} // namespace
TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
int num_dims, size_t len) {
void* data = tensorflow::allocate_tensor("TF_AllocateTensor", len,
tensorflow::cpu_allocator());
return TF_NewTensor(dtype, dims, num_dims, data, len,
tensorflow::deallocate_buffer,
tensorflow::cpu_allocator());
TF_ManagedBuffer* buf =
new TF_ManagedBuffer(data, len, tensorflow::deallocate_buffer,
tensorflow::cpu_allocator(), /*owns_memory=*/true);
return CreateTensor(buf, dtype, dims, num_dims, len);
}
TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
void* data, size_t len,
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg) {
std::vector<tensorflow::int64> dimvec(num_dims);
for (int i = 0; i < num_dims; ++i) {
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
TF_ManagedBuffer* buf = nullptr;
if (dtype != TF_STRING && dtype != TF_RESOURCE &&
tensorflow::DataTypeCanUseMemcpy(
@ -97,24 +114,17 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
// Other types have the same representation, so copy only if it is safe to
// do so.
buf = new TF_ManagedBuffer(tensorflow::allocate_tensor("TF_NewTensor", len),
len, tensorflow::deallocate_buffer, nullptr);
len, tensorflow::deallocate_buffer, nullptr,
/*owns_memory=*/true);
std::memcpy(buf->data(), data, len);
// Free the original buffer.
deallocator(data, len, deallocator_arg);
} else {
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
/*owns_memory=*/false);
}
// TODO(gjn): Make the choice of interface a compile-time configuration.
tensorflow::TensorInterface ret(
Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf));
buf->Unref();
size_t elem_size = TF_DataTypeSize(dtype);
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
return nullptr;
}
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
return CreateTensor(buf, dtype, dims, num_dims, len);
}
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {
@ -383,7 +393,7 @@ Status TensorInterface::ToTensor(Tensor* dst) const {
if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
string(static_cast<const char*>(Data()), ByteSize()))) {
return InvalidArgument(
"Malformed TF_RESOUCE tensor: unable to parse resource handle");
"Malformed TF_RESOURCE tensor: unable to parse resource handle");
}
return Status::OK();
}
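
The refactor above funnels both constructors through the new CreateTensor helper: TF_AllocateTensor passes a buffer the tensor owns, while TF_NewTensor either wraps caller memory without owning it or, when it must copy for alignment, takes ownership of the copy. A minimal usage sketch of the two public entry points (error handling omitted):

```cpp
#include <cstdint>

#include "tensorflow/c/c_api.h"

int main() {
  const int64_t dims[1] = {4};

  // Library-allocated buffer; the tensor owns it (owns_memory=true).
  TF_Tensor* owned = TF_AllocateTensor(TF_FLOAT, dims, 1, 4 * sizeof(float));

  // Caller-provided buffer; TF either wraps it without owning it
  // (owns_memory=false) or copies it if alignment demands, invoking the
  // deallocator once the original bytes are no longer needed.
  static float data[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  TF_Tensor* wrapped = TF_NewTensor(
      TF_FLOAT, dims, 1, data, sizeof(data),
      [](void*, size_t, void*) { /* static storage; nothing to free */ },
      /*deallocator_arg=*/nullptr);

  TF_DeleteTensor(owned);
  TF_DeleteTensor(wrapped);
  return 0;
}
```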

View File

@ -58,9 +58,9 @@ extern "C" {
// start_offset: array[uint64]
// data: byte[...]
//
// The string length (as a varint), followed by the contents of the string
// is encoded at data[start_offset[i]]]. TF_StringEncode and TF_StringDecode
// facilitate this encoding.
// The string length (as a varint, start_offset[i + 1] - start_offset[i]),
// followed by the contents of the string is encoded at data[start_offset[i]].
// TF_StringEncode and TF_StringDecode facilitate this encoding.
typedef struct TF_Tensor TF_Tensor;
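
The layout described in the comment can be exercised with the TF_StringEncodedSize/TF_StringEncode helpers it mentions. A sketch building a scalar TF_STRING tensor in the legacy offset-table encoding (status checking trimmed for brevity):

```cpp
#include <cstring>

#include "tensorflow/c/c_api.h"

TF_Tensor* MakeScalarString(const char* src) {
  const size_t src_len = std::strlen(src);
  const size_t encoded_len = TF_StringEncodedSize(src_len);

  // One uint64 start_offset entry, then the varint-length-prefixed bytes.
  TF_Tensor* t = TF_AllocateTensor(TF_STRING, /*dims=*/nullptr,
                                   /*num_dims=*/0, 8 + encoded_len);
  char* base = static_cast<char*>(TF_TensorData(t));
  std::memset(base, 0, 8);  // start_offset[0] = 0

  TF_Status* status = TF_NewStatus();
  TF_StringEncode(src, src_len, base + 8, encoded_len, status);
  TF_DeleteStatus(status);
  return t;
}
```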

View File

@ -38,11 +38,12 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer {
public:
TF_ManagedBuffer(void* data, size_t len,
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg)
void* deallocator_arg, bool owns_memory)
: TensorBuffer(data),
len_(len),
deallocator_(deallocator),
deallocator_arg_(deallocator_arg) {}
deallocator_arg_(deallocator_arg),
owns_memory_(owns_memory) {}
~TF_ManagedBuffer() override {
(*deallocator_)(data(), len_, deallocator_arg_);
@ -57,13 +58,13 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer {
proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
}
// Prevents input forwarding from mutating this buffer.
bool OwnsMemory() const override { return false; }
bool OwnsMemory() const override { return owns_memory_; }
private:
const size_t len_;
void (*const deallocator_)(void* data, size_t len, void* arg);
void* const deallocator_arg_;
bool owns_memory_;
};
namespace tensorflow {

View File

@ -41,6 +41,16 @@ filegroup(
],
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"training/coordinator.h",
],
visibility = [
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "gradients",
srcs = [
@ -622,6 +632,7 @@ tf_gen_op_wrappers_cc(
"tpu_configuration_ops",
"tpu_cross_replica_ops",
"tpu_embedding_ops",
"tpu_embedding_load_retrieve_ops",
"tpu_functional_ops",
"tpu_heartbeat_ops",
"tpu_host_compute_ops",

View File

@ -41,7 +41,7 @@ class ClientSession::Impl {
std::shared_ptr<Graph> graph_;
mutable mutex mu_;
mutable int last_num_graph_nodes_ GUARDED_BY(mu_) = 0;
mutable int last_num_graph_nodes_ TF_GUARDED_BY(mu_) = 0;
};
ClientSession::ClientSession(const Scope& scope, const string& target)

View File

@ -96,7 +96,7 @@ class SymbolicGradientBuilder {
// Used to identify nodes at which to stop backprop.
std::unordered_set<int> GetStopBackpropNodes(
const std::vector<bool>& reachable_nodes,
const std::unordered_set<int>& output_nodes);
const std::unordered_set<int>& output_nodes) const;
const Scope& scope_;
const ops::GradOpRegistry* registry_;
@ -190,7 +190,7 @@ std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
std::unordered_set<int> SymbolicGradientBuilder::GetStopBackpropNodes(
const std::vector<bool>& reachable_nodes,
const std::unordered_set<int>& output_nodes) {
const std::unordered_set<int>& output_nodes) const {
// Output nodes that get transitively consumed by other `outputs_` are stored
// in `internal_outputs`.
std::unordered_set<int> internal_outputs;
@ -346,8 +346,8 @@ Status SymbolicGradientBuilder::SumGradients(const Output& src, Output* grad) {
"Unable to find backprop list for node.id ", src.node()->name());
}
const auto& grads = iter->second;
// Filter any backproped 'NoGradient' Outputs from 'grads' (if needed).
// Return any valid backproped gradients that remain after filtering,
// Filter any backpropped 'NoGradient' Outputs from 'grads' (if needed).
// Return any valid backpropped gradients that remain after filtering,
// or 'NoGradient' otherwise.
std::vector<Output> grads_to_keep;
for (const Output& o : grads) {
@ -519,17 +519,17 @@ Status SymbolicGradientBuilder::AddGradients() {
// Backprop along the in edges.
// TODO(andydavis) Find cleaner way to map each grad output returned by
// gradient function to the src node/output to which it should be
// backproped. Maybe grad functions can return a vector of Output pairs to
// backpropped. Maybe grad functions can return a vector of Output pairs to
// make this association explicit.
size_t dx_index = 0;
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) continue;
if (dx_index == dx.size()) {
int dx_index = e->dst_input();
if (dx_index >= dx.size()) {
return errors::Internal(
"Invalid gradient output index: ", dx_index, " size: ", dx.size());
}
TF_RETURN_IF_ERROR(
BackpropAlongEdge(dx[dx_index++], {e->src(), e->src_output()}));
BackpropAlongEdge(dx[dx_index], {e->src(), e->src_output()}));
}
}

View File

@ -503,6 +503,42 @@ TEST_F(GradientsTest, MultiOutputNodeDependentOutputs) {
EXPECT_EQ(grad_result[0].flat<float>()(0), 17610.0f);
}
TEST_F(GradientsTest, AddSymbolicGradientsTest) {
Scope scope = Scope::NewRootScope();
for (int cnt = 0; cnt < 100; ++cnt) {
int N = 5 + rand() % 10;
// Construct forward graph.
OutputList inputs;
for (int i = 0; i < N; ++i) {
auto a = Const(scope, i, {1});
inputs.push_back(a);
}
auto pack = Stack(scope, inputs);
TF_ASSERT_OK(scope.status());
// Construct grad inputs.
OutputList output_grads;
Tensor ts(DT_INT32, {N, 1});
auto v = ts.matrix<int32>();
for (int i = 0; i < N; ++i) {
v(i, 0) = i;
}
auto dy = Const(scope, ts);
output_grads.push_back(dy);
// Call AddSymbolicGradients.
std::vector<Output> grad_outputs;
TF_ASSERT_OK(AddSymbolicGradients(scope, {pack.output}, inputs,
output_grads, &grad_outputs));
ClientSession session((scope));
std::vector<Tensor> in_grad;
TF_ASSERT_OK(session.Run(grad_outputs, &in_grad));
for (int i = 0; i < N; ++i) {
test::ExpectTensorEqual<int>(in_grad[i], test::AsTensor<int>({i}, {1}));
}
}
}
// StopGradientSingleOutputMultiEdgeTest tests combinations of valid and
// 'NoGradient' (induced by StopGradient op) returned along multiple edges from
// a single nodes output.

View File

@ -64,7 +64,7 @@ bool IsZero(const Scope& scope, const Output& grad) {
// Multiply after broadcasting vec to match dimensions of mat.
// Args:
// vec: A 1-D tensor of dimension [D0]
// mat: A 2-D tensor of dimesnion [D0, D1]
// mat: A 2-D tensor of dimension [D0, D1]
//
// Returns:
// A tensor of dimension [D0, D1], the result of vec * mat.

View File

@ -68,6 +68,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/platform:resource_loader",
],
)
@ -224,3 +225,15 @@ filegroup(
"testdata/VarsAndArithmeticObjectGraph/**",
]),
)
exports_files(
glob([
"testdata/half_plus_two_pbtxt/**",
"testdata/half_plus_two_main_op/**",
"testdata/half_plus_two/**",
"testdata/half_plus_two_v2/**",
"testdata/x_plus_y_v2_debuginfo/**",
"testdata/CyclicModule/**",
"testdata/VarsAndArithmeticObjectGraph/**",
]),
)

View File

@ -21,15 +21,22 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
constexpr char kTestDataPbTxt[] =
"cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
constexpr char kTestDataSharded[] =
"cc/saved_model/testdata/half_plus_two/00000123";
string TestDataPbTxt() {
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"half_plus_two_pbtxt", "00000123");
}
string TestDataSharded() {
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"half_plus_two", "00000123");
}
class ReaderTest : public ::testing::Test {
protected:
@ -49,8 +56,7 @@ class ReaderTest : public ::testing::Test {
TEST_F(ReaderTest, TagMatch) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&meta_graph_def));
CheckMetaGraphDef(meta_graph_def);
@ -59,8 +65,7 @@ TEST_F(ReaderTest, TagMatch) {
TEST_F(ReaderTest, NoTagMatch) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
&meta_graph_def);
EXPECT_FALSE(st.ok());
@ -73,8 +78,7 @@ TEST_F(ReaderTest, NoTagMatch) {
TEST_F(ReaderTest, NoTagMatchMultiple) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
Status st = ReadMetaGraphDefFromSavedModel(
export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
EXPECT_FALSE(st.ok());
@ -87,8 +91,7 @@ TEST_F(ReaderTest, NoTagMatchMultiple) {
TEST_F(ReaderTest, PbtxtFormat) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt);
const string export_dir = GetDataDependencyFilepath(TestDataPbTxt());
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&meta_graph_def));
CheckMetaGraphDef(meta_graph_def);
@ -97,8 +100,7 @@ TEST_F(ReaderTest, PbtxtFormat) {
TEST_F(ReaderTest, InvalidExportPath) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
const string export_dir = GetDataDependencyFilepath("missing-path");
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&meta_graph_def);
EXPECT_FALSE(st.ok());
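
The change above replaces source-tree lookups with runfiles-aware resolution. A hedged sketch of the same pattern for any test whose fixtures are declared as Bazel data dependencies (using the testdata path from this file):

```cpp
#include <string>

#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/resource_loader.h"

// Resolves a runfiles-relative path. The file must appear in the test
// target's data attribute for the lookup to succeed at runtime.
std::string TestDataPath() {
  return tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath(
      "tensorflow", "cc", "saved_model", "testdata", "half_plus_two",
      "00000123"));
}
```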

View File

@ -114,14 +114,14 @@ class Coordinator {
condition_variable wait_for_stop_;
mutex mu_;
bool should_stop_ GUARDED_BY(mu_);
bool should_stop_ TF_GUARDED_BY(mu_);
mutex status_lock_;
Status status_ GUARDED_BY(status_lock_);
Status status_ TF_GUARDED_BY(status_lock_);
mutable mutex runners_lock_;
std::vector<std::unique_ptr<RunnerInterface>> runners_
GUARDED_BY(runners_lock_);
TF_GUARDED_BY(runners_lock_);
TF_DISALLOW_COPY_AND_ASSIGN(Coordinator);
};

View File

@ -119,8 +119,8 @@ class QueueRunner : public RunnerInterface {
std::unique_ptr<thread::ThreadPool> thread_pool_;
mutex mu_;
int runs_ = 0;
Status status_ GUARDED_BY(mu_);
Status enqueue_status_ GUARDED_BY(mu_);
Status status_ TF_GUARDED_BY(mu_);
Status enqueue_status_ TF_GUARDED_BY(mu_);
std::unique_ptr<BlockingCounter> counter_;
Coordinator* coord_;
@ -131,7 +131,7 @@ class QueueRunner : public RunnerInterface {
std::vector<std::function<void(Status)>> callbacks_;
mutable std::unique_ptr<mutex> cg_mu_;
std::unique_ptr<CostGraphDef> cost_graph_ GUARDED_BY(cg_mu_);
std::unique_ptr<CostGraphDef> cost_graph_ TF_GUARDED_BY(cg_mu_);
RunOptions run_options_;
};
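
Both headers here pick up the TF_-prefixed thread-safety annotations from tensorflow/core/platform/thread_annotations.h, prefixed to avoid clashing with identically named macros defined elsewhere. A minimal sketch of the annotated pattern:

```cpp
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

class Counter {
 public:
  void Increment() {
    tensorflow::mutex_lock l(mu_);  // Satisfies the annotation below.
    value_++;
  }

 private:
  tensorflow::mutex mu_;
  // Clang's thread-safety analysis flags any access to value_ made
  // without holding mu_.
  int value_ TF_GUARDED_BY(mu_) = 0;
};
```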

View File

@ -20,9 +20,11 @@ from __future__ import print_function as _print_function
import logging as _logging
import os as _os
import six as _six
import sys as _sys
from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
# pylint: disable=g-bad-import-order
@ -36,20 +38,19 @@ try:
from tensorboard.summary._tf import summary
_current_module.__path__ = (
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
# Make sure we get the correct summary module with lazy loading
setattr(_current_module, "summary", summary)
except ImportError:
_logging.warning(
"Limited tf.compat.v2.summary API due to missing TensorBoard "
"installation.")
try:
from tensorflow_estimator.python.estimator.api._v2 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass
# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)
try:
from tensorflow.python.keras.api._v2 import keras
@ -59,6 +60,13 @@ try:
except ImportError:
pass
# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
import typing as _typing
if _typing.TYPE_CHECKING:
from tensorflow_estimator.python.estimator.api._v2 import estimator
# pylint: enable=g-import-not-at-top
# We would like the following to work for fully enabling 2.0 in a 1.0 install:
#

View File

@ -20,8 +20,10 @@ from __future__ import print_function as _print_function
import os as _os
import sys as _sys
import six as _six
from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
# pylint: disable=g-bad-import-order
@ -31,13 +33,14 @@ from tensorflow.python.tools import module_util as _module_util
# Hook external TensorFlow modules.
_current_module = _sys.modules[__name__]
try:
from tensorflow_estimator.python.estimator.api._v1 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass
# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v1.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)
try:
from tensorflow.python.keras.api._v1 import keras
@ -47,6 +50,14 @@ try:
except ImportError:
pass
# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
import typing as _typing
if _typing.TYPE_CHECKING:
from tensorflow_estimator.python.estimator.api._v1 import estimator
# pylint: enable=g-import-not-at-top
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
_current_module.app.flags = flags # pylint: disable=undefined-variable

View File

@ -1,11 +1,5 @@
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available")
package(
@ -39,9 +33,11 @@ cc_library(
deps = [
":aot_only_var_handle_op",
":embedded_protocol_buffers",
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"//tensorflow/compiler/mlir/lite/quantization/xla:quantize",
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
@ -69,36 +65,7 @@ cc_library(
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
]),
)
# Necessary for the pywrap inclusion below.
tf_pybind_cc_library_wrapper(
name = "tfcompile_headers_lib",
deps = [
":tfcompile_lib",
],
)
tf_python_pybind_extension(
name = "_pywrap_tfcompile",
srcs = ["tfcompile_wrapper.cc"],
features = ["-layering_check"],
module_name = "_pywrap_tfcompile",
visibility = ["//tensorflow/python:__pkg__"],
deps = [
":tfcompile_headers_lib",
"@pybind11",
"//third_party/python_runtime:headers",
"//tensorflow/python:pybind11_lib",
"//tensorflow/python:pybind11_status",
# These headers cannot be brought in via cc_header_only_library
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
"//tensorflow/core:regexp_internal",
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
]),
@ -119,6 +86,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:resource_loader",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support", # fixdeps: keep
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
@ -131,6 +99,19 @@ tf_cc_binary(
deps = [":tfcompile_main"],
)
cc_library(
name = "llvm_targets",
visibility = ["//tensorflow/python:__pkg__"],
deps = [
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
]),
)
cc_library(
name = "tfcompile_main",
srcs = ["tfcompile_main.cc"],
@ -156,54 +137,108 @@ cc_library(
# tfcompile.bzl correctly handles usage from outside of the package that it is
# defined in.
# A simple test of tf_library from a text protobuf, mostly to enable the
# benchmark_test.
# A simple test of tf_library from a text protobuf, to enable benchmark_test.
# This test uses an incomplete graph with a node that is not defined. The
# compilation works because the undefined node is a feed node.
tf_library(
name = "test_graph_tfadd",
testonly = 1,
config = "test_graph_tfadd.config.pbtxt",
cpp_class = "AddComp",
graph = "test_graph_tfadd.pbtxt",
mlir_components = "None",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfadd_mlir_bridge",
testonly = 1,
config = "test_graph_tfadd.config.pbtxt",
cpp_class = "AddComp",
graph = "test_graph_tfadd.pbtxt",
mlir_components = "Bridge",
tags = [
"manual",
],
)
# A test of tf_library that includes a graph with an unknown op, but where
# the compilation works because the unknown op is not needed for the fetches.
# the compilation works because the node with the unknown op is not needed
# for the fetches.
tf_library(
name = "test_graph_tfunknownop",
testonly = 1,
config = "test_graph_tfunknownop.config.pbtxt",
cpp_class = "UnknownOpAddComp",
graph = "test_graph_tfunknownop.pbtxt",
mlir_components = "None",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfunknownop_mlir_bridge",
testonly = 1,
config = "test_graph_tfunknownop.config.pbtxt",
cpp_class = "UnknownOpAddComp",
graph = "test_graph_tfunknownop.pbtxt",
mlir_components = "Bridge",
tags = [
"manual",
],
)
# A test of tf_library that includes a graph with an unknown op, but where
# the compilation works because the op between the unknown op and the
# fetches is a feed.
# the compilation works because the node with the unknown op is only used as
# an input of a feed node.
tf_library(
name = "test_graph_tfunknownop2",
testonly = 1,
config = "test_graph_tfunknownop2.config.pbtxt",
cpp_class = "UnknownOpAddComp",
graph = "test_graph_tfunknownop.pbtxt",
mlir_components = "None",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfunknownop2_mlir_bridge",
testonly = 1,
config = "test_graph_tfunknownop2.config.pbtxt",
cpp_class = "UnknownOpAddComp",
graph = "test_graph_tfunknownop.pbtxt",
mlir_components = "Bridge",
tags = [
"manual",
],
)
# A test of tf_library that includes a graph with an unknown op, but where
# the compilation works because the unknown op is fed.
# the compilation works because the node with the unknown op is a feed node.
tf_library(
name = "test_graph_tfunknownop3",
testonly = 1,
config = "test_graph_tfunknownop3.config.pbtxt",
cpp_class = "UnknownOpAddComp",
graph = "test_graph_tfunknownop.pbtxt",
mlir_components = "None",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfunknownop3_mlir_bridge",
testonly = 1,
config = "test_graph_tfunknownop3.config.pbtxt",
cpp_class = "UnknownOpAddComp",
graph = "test_graph_tfunknownop.pbtxt",
mlir_components = "Bridge",
tags = [
"manual",
],
@ -283,9 +318,13 @@ test_suite(
tests = [
":benchmark_test",
":codegen_test",
":test_graph_tfadd_mlir_bridge_test",
":test_graph_tfadd_test",
":test_graph_tfunknownop2_mlir_bridge_test",
":test_graph_tfunknownop2_test",
":test_graph_tfunknownop3_mlir_bridge_test",
":test_graph_tfunknownop3_test",
":test_graph_tfunknownop_mlir_bridge_test",
":test_graph_tfunknownop_test",
"//tensorflow/compiler/aot/tests:all_tests",
],

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "absl/strings/str_split.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/service/compiler.h"
@ -169,7 +170,9 @@ Status GenArgMethods(const tf2xla::Config& config,
const xla::ProgramShapeProto& ps,
const CompileResult& compile_result, string* methods) {
size_t num_args = ps.parameters_size();
if (config.feed_size() + config.variable_size() != num_args) {
// feed_size() + variable_size() is the maximum number of args, as an
// implementation may not create an argument for an unused variable.
if (config.feed_size() + config.variable_size() < num_args) {
return errors::InvalidArgument(
"mismatch between feed_size(", config.feed_size(), ")+variable_size(",
config.variable_size(), ") and num_args(", num_args, ")");
@ -288,8 +291,8 @@ Status GenVariableMethods(const tf2xla::Config& config,
}
// Generates code implementing {Arg,Result}Names(), where T is one of
// tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string
// literal in the array, with nullptr terminating the array.
// tf2xla::{Feed,Fetch,Variable}. Each feed or fetch name results in a C-style
// string literal in the array, with nullptr terminating the array.
template <typename T>
string GenNameToIndexCode(const T& entries, bool generate) {
// No need for a static array if we're not supposed to generate the data.
@ -419,6 +422,16 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
// Generate metadata.
const string arg_names_code =
GenNameToIndexCode(config.feed(), opts.gen_name_to_index);
auto variable_copy = config.variable();
for (auto& var : variable_copy) {
if (var.name().empty()) {
var.set_name(var.node_name());
}
}
const string variable_names_code =
GenNameToIndexCode(variable_copy, opts.gen_name_to_index);
const string result_names_code =
GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
const string include_xla_data_proto =
@ -457,8 +470,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
{{INCLUDE_XLA_DATA_PROTO}}
{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}
#include "{{TF_HEADER_ROOT}}/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "{{TF_HEADER_ROOT}}/core/platform/types.h"
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/platform/types.h"
namespace Eigen { struct ThreadPoolDevice; }
namespace xla { class ExecutableRunOptions; }
@ -507,6 +520,9 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
// Number of input arguments for the compiled computation.
static constexpr size_t kNumArgs = {{ARG_NUM}};
// Number of variables for the compiled computation.
static constexpr size_t kNumVariables = {{VARIABLE_NUM}};
// Byte size of each argument buffer. There are kNumArgs entries.
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
@ -522,8 +538,10 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
set_static_data_num_buffers(data, kNumBuffers);
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
set_static_data_num_args(data, kNumArgs);
set_static_data_num_variables(data, kNumVariables);
set_static_data_result_index(data, kResultIndex);
set_static_data_arg_names(data, StaticArgNames());
set_static_data_variable_names(data, StaticVariableNames());
set_static_data_result_names(data, StaticResultNames());
set_static_data_program_shape(data, StaticProgramShape());
set_static_data_hlo_profile_printer_data(
@ -626,6 +644,9 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
// Array of names of each positional argument, terminated by nullptr.
static const char** StaticArgNames() {{ARG_NAMES_CODE}}
// Array of names of each positional variable, terminated by nullptr.
static const char** StaticVariableNames() {{VARIABLE_NAMES_CODE}}
// Array of names of each positional result, terminated by nullptr.
static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
@ -654,12 +675,12 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
{"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)},
{"{{ARG_NAMES_CODE}}", arg_names_code},
{"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())},
{"{{VARIABLE_NUM}}", absl::StrCat(config.variable_size())},
{"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
{"{{CLASS}}", opts.class_name},
{"{{DECLS_FROM_OBJ_FILE}}",
absl::StrJoin(metadata_result.header_variable_decls, "\n")},
{"{{TF_HEADER_ROOT}}", compile_result.tensorflow_header_root},
{"{{ENTRY}}", compile_result.entry_point},
{"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
metadata_result.hlo_profile_printer_data_access_shim},
@ -674,6 +695,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
metadata_result.program_shape_access_shim},
{"{{VARIABLE_NAMES_CODE}}", variable_names_code},
{"{{RESULT_INDEX}}", absl::StrCat(result_index)},
{"{{RESULT_NAMES_CODE}}", result_names_code},
{"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)},

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/codegen.h"
#include <algorithm>
#include <string>
#include <vector>
@ -29,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -139,14 +141,23 @@ TEST_F(ParseCppClassTest, ParseFail) {
static void CompareWithGoldenFile(
const string& tensorflow_relative_golden_file_name,
const string& expected_contents) {
const string& expected_contents, bool ignore_cr) {
// Get rid of all CR characters; we may be running under Windows.
string sanitized_expected_contents(expected_contents);
if (ignore_cr) {
sanitized_expected_contents.erase(
std::remove(sanitized_expected_contents.begin(),
sanitized_expected_contents.end(), '\r'),
sanitized_expected_contents.end());
}
// To update the golden file, flip update_golden to true and run the
// following:
// bazel test --test_strategy=local \
// third_party/tensorflow/compiler/aot:codegen_test
// "third_party/tensorflow/compiler/aot:codegen_test"
const bool update_golden = false;
const string golden_file_name = io::JoinPath(
testing::TensorFlowSrcRoot(), tensorflow_relative_golden_file_name);
string golden_file_name =
GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
if (update_golden) {
TF_EXPECT_OK(
@ -156,6 +167,11 @@ static void CompareWithGoldenFile(
string golden_file_contents;
TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
&golden_file_contents));
if (ignore_cr) {
golden_file_contents.erase(std::remove(golden_file_contents.begin(),
golden_file_contents.end(), '\r'),
golden_file_contents.end());
}
EXPECT_EQ(golden_file_contents, expected_contents);
}
@ -197,15 +213,20 @@ TEST(CodegenTest, Golden) {
variable3->mutable_shape()->add_dim()->set_size(5);
variable3->set_type(DT_INT32);
CompileResult compile_result;
compile_result.tensorflow_header_root = "third_party/tensorflow";
compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
{},
{BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
BufferInfo::MakeTempBuffer(2),
BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
5, {}));
BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/2),
BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/3),
BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/4),
BufferInfo::MakeTempBuffer(1), BufferInfo::MakeTempBuffer(120)},
11, {}));
compile_result.program_shape =
xla::ShapeUtil::MakeProgramShape(
{
@ -230,14 +251,18 @@ TEST(CodegenTest, Golden) {
// The other fields in metadata_result are tested as part of the generated
// header test.
CompareWithGoldenFile("compiler/aot/codegen_test_o.golden",
metadata_result.object_file_data);
// This specific golden test checks a binary file. It can potentially run into
// issues due to ABIs not being stable, but it has not so far.
// If we see any ABI issues, we should reconsider this specific test case.
CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_o.golden",
metadata_result.object_file_data, false);
string header;
TF_ASSERT_OK(
GenerateHeader(opts, config, compile_result, metadata_result, &header));
CompareWithGoldenFile("compiler/aot/codegen_test_h.golden", header);
CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_h.golden", header,
true);
}
} // namespace
} // namespace tfcompile

View File

@ -55,14 +55,17 @@ namespace bar {
// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5])
//
// Memory stats:
// arg bytes total: 104
// arg bytes aligned: 192
// arg bytes total: 392
// arg bytes aligned: 576
// temp bytes total: 126
// temp bytes aligned: 320
// temp bytes aligned: 512
class MyClass final : public tensorflow::XlaCompiledCpuFunction {
public:
// Number of input arguments for the compiled computation.
static constexpr size_t kNumArgs = 2;
static constexpr size_t kNumArgs = 5;
// Number of variables for the compiled computation.
static constexpr size_t kNumVariables = 3;
// Byte size of each argument buffer. There are kNumArgs entries.
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
@ -79,8 +82,10 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
set_static_data_num_buffers(data, kNumBuffers);
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
set_static_data_num_args(data, kNumArgs);
set_static_data_num_variables(data, kNumVariables);
set_static_data_result_index(data, kResultIndex);
set_static_data_arg_names(data, StaticArgNames());
set_static_data_variable_names(data, StaticVariableNames());
set_static_data_result_names(data, StaticResultNames());
set_static_data_program_shape(data, StaticProgramShape());
set_static_data_hlo_profile_printer_data(
@ -295,16 +300,22 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
private:
// Number of buffers for the compiled computation.
static constexpr size_t kNumBuffers = 6;
static constexpr size_t kNumBuffers = 12;
static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() {
static const ::xla::cpu_function_runtime::BufferInfo
kBufferInfos[kNumBuffers] = {
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({34ULL, 0ULL}),
::xla::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({386ULL, 1ULL}),
::xla::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({386ULL, 2ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({386ULL, 3ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({386ULL, 4ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({481ULL, ~0ULL})
};
return kBufferInfos;
@ -312,13 +323,13 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
static const ::tensorflow::int32* ArgIndexToBufferIndex() {
static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
1, 3
1, 3, 5, 7, 9
};
return kArgIndexToBufferIndex;
}
// The 0-based index of the result tuple in the temporary buffers.
static constexpr size_t kResultIndex = 5;
static constexpr size_t kResultIndex = 11;
// Array of names of each positional argument, terminated by nullptr.
static const char** StaticArgNames() {
@ -326,6 +337,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return kNames;
}
// Array of names of each positional variable, terminated by nullptr.
static const char** StaticVariableNames() {
static const char* kNames[] = {"myvar_readonly", "myvar", "myvar2", nullptr};
return kNames;
}
// Array of names of each positional result, terminated by nullptr.
static const char** StaticResultNames() {
static const char* kNames[] = {"myfetch", nullptr};

View File

@ -20,9 +20,11 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/base/call_once.h"
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
@ -38,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@ -85,7 +88,6 @@ Status CompileXla(xla::CompileOnlyClient* client,
xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
std::move(aot_or.ValueOrDie().back()));
compile_result->entry_point = aot_opts.entry_point_name();
compile_result->tensorflow_header_root = aot_opts.tensorflow_header_root();
compile_result->pointer_size =
xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
return Status::OK();
@ -105,14 +107,17 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
.ValueOrDie();
xla::XlaComputation computation;
if (flags.mlir_components == "Bridge") {
TF_RETURN_IF_ERROR(
ConvertGraphDefToXlaViaMlir(graph_def, config, &computation));
} else {
if (!flags.mlir_components.empty()) {
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
}
TF_RETURN_IF_ERROR(ConvertGraphDefToXlaViaMlir(
graph_def, config, &computation, flags.debug_info,
flags.debug_info_path_begin_marker));
} else if (flags.mlir_components.empty() || flags.mlir_components == "None") {
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
client, &computation));
} else {
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
}
if (flags.quantize) {
TF_RETURN_IF_ERROR(mlir::xla_hlo::XlaQuantize(config, &computation));
}
if (!flags.out_session_module.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
@ -130,8 +135,7 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
xla::cpu::CpuAotCompilationOptions aot_opts(
flags.target_triple, flags.target_cpu, flags.target_features,
flags.entry_point,
xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic,
flags.tensorflow_header_root);
xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);
return CompileXla(client, computation, aot_opts, compile_result);
}
@ -144,7 +148,7 @@ static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
}
}
static std::once_flag targets_init;
static absl::once_flag targets_init;
static void InitializeTargets() {
// Initialize all LLVM targets so we can cross compile.
@ -168,8 +172,25 @@ static void InitializeTargets() {
LLVMInitializeX86AsmPrinter();
}
// Replaces {{tag.type tag.name}} in the error message with tag.name.
// TODO(bixia): We currently only handle tag.type == "node".
//
// In the error message, a graph node is represented as {{tag.type tag.name}},
// to allow a Python debugger to insert source information about the graph
// node. For example, a Python add expression may be represented as
// {{node x_y_sum}} = Add(x, y) in the error message. See routine interpolate
// in tensorflow/python/framework/error_interpolation.py for more detail.
static std::string InterpolateErrorMessage(std::string message) {
// See _NAME_REGEX in tensorflow/python/framework/error_interpolation.py
// Change "prefix {{node tag.name}} suffix" to "prefix tag.name suffix".
static LazyRE2 pattern{"(.*){{node (.*)}}(.*)"};
RE2::GlobalReplace(&message, *pattern, "\\1\\2\\3");
return message;
}
Status Main(const MainFlags& flags) {
std::call_once(targets_init, &InitializeTargets);
absl::call_once(targets_init, &InitializeTargets);
// Process config.
tf2xla::Config config;
@ -194,8 +215,13 @@ Status Main(const MainFlags& flags) {
GraphDef graph_def;
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
CompileResult compile_result;
TF_RETURN_IF_ERROR(
CompileGraph(std::move(graph_def), config, flags, &compile_result));
Status status =
CompileGraph(std::move(graph_def), config, flags, &compile_result);
if (!status.ok()) {
return Status(status.code(),
InterpolateErrorMessage(status.error_message()));
}
// Write output files.
Env* env = Env::Default();
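
InterpolateErrorMessage above exists so the tfcompile CLI prints plain node names rather than the {{node ...}} tags meant for the Python interpolator. A standalone sketch of the same RE2 rewrite (the message text is hypothetical):

```cpp
#include <string>

#include "re2/re2.h"

int main() {
  std::string message = "Dimensions must be equal for {{node x_y_sum}}";
  // Same pattern and rewrite as InterpolateErrorMessage.
  static LazyRE2 pattern{"(.*){{node (.*)}}(.*)"};
  RE2::GlobalReplace(&message, *pattern, "\\1\\2\\3");
  // message is now "Dimensions must be equal for x_y_sum".
  return 0;
}
```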

View File

@ -35,7 +35,6 @@ struct CompileResult {
std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
xla::ProgramShapeProto program_shape; // Static shape of args and results.
string entry_point; // Name of generated function.
string tensorflow_header_root; // Prefix for tensorflow headers.
int pointer_size = 0; // Size of a pointer in bytes.
};

View File

@ -24,6 +24,13 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
"Input GraphDef file. If the file ends in '.pbtxt' it is expected to "
"be in the human-readable proto text format, otherwise it is expected "
"to be in the proto binary format."},
{"debug_info", &flags->debug_info,
"Graph debug info file. If the file ends in '.pbtxt' it is expected to "
"be in the human-readable proto text format, otherwise it is expected "
"to be in the proto binary format."},
{"debug_info_path_begin_marker", &flags->debug_info_path_begin_marker,
"If not none, only keep the file path in the debug information after the"
" marker. The default value is empty"},
{"config", &flags->config,
"Input file containing Config proto. If the file ends in '.pbtxt' it "
"is expected to be in the human-readable proto text format, otherwise "
@ -70,12 +77,12 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
"Output session module proto."},
{"mlir_components", &flags->mlir_components,
"The MLIR components to enable. Currently only Bridge is supported."},
{"quantize", &flags->quantize,
"If set, quantization will be applied before HLO code generation."},
{"gen_name_to_index", &flags->gen_name_to_index,
"Generate name-to-index data for Lookup{Arg,Result}Index methods."},
{"gen_program_shape", &flags->gen_program_shape,
"Generate program shape data for the ProgramShape method."},
{"tensorflow_header_root", &flags->tensorflow_header_root,
"Root directory of tensorflow headers."},
};
flag_list->insert(flag_list->end(), tmp.begin(), tmp.end());
}

View File

@ -28,6 +28,8 @@ namespace tfcompile {
struct MainFlags {
string graph;
string debug_info;
string debug_info_path_begin_marker;
string config;
bool dump_fetch_nodes = false;
string target_triple;
@ -40,7 +42,7 @@ struct MainFlags {
string out_header;
string out_session_module;
string mlir_components;
string tensorflow_header_root;
bool quantize = false;
// C++ codegen options
bool gen_name_to_index = false;

View File

@ -1,11 +1,37 @@
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
package(
default_visibility = ["//visibility:private"],
licenses = ["notice"], # Apache 2.0
)
glob_lit_tests(
data = [":filecheck_test_utilities"],
driver = "@llvm-project//mlir:run_lit.sh",
tags_override = {
"test_error_message.lit.pbtxt": ["no_oss"], # TODO(b/150957738): to be fixed on oss.
},
test_file_exts = ["lit.pbtxt"],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "filecheck_test_utilities",
testonly = True,
srcs = [
"test_error_message.lit.pbtxt.config.pbtxt",
"test_error_message.lit.pbtxt.debug.pbtxt",
"test_error_message.lit.pbtxt.fake_py.debug",
],
data = [
"//tensorflow/compiler/aot:tfcompile",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)
# We disable some tfcompile tests in the open source build with the
# "manual" tag to avoid making our OSS users build LLVM twice
# (once for host and once for target).
@ -25,6 +51,7 @@ test_suite(
":test_graph_tfmatmulandadd_test",
":test_graph_tfsplits_test",
":test_graph_tftop_k_test",
":test_graph_tfvariable_readonly_test",
":test_graph_tfvariable_sequential_updates_test",
":test_graph_tfvariable_test",
":tfcompile_test",
@ -59,6 +86,7 @@ genrule(
testonly = 1,
outs = [
"test_graph_tfadd.pb",
"test_debuginfo_tfadd.pb",
"test_graph_tfadd_with_ckpt.ckpt",
"test_graph_tfadd_with_ckpt.pb",
"test_graph_tfadd_with_ckpt_saver.ckpt",
@ -73,6 +101,7 @@ genrule(
"test_graph_tfsplits.pb",
"test_graph_tftop_k.pb",
"test_graph_tfvariable.pb",
"test_graph_tfvariable_readonly.pb",
"test_graph_tfvariable_sequential_updates.pb",
],
# Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any
@ -96,6 +125,7 @@ tf_library(
# compile but the others in this directory succeed, you may need to
# expand the "required by all tf_library targets" list in tfcompile.bzl.
include_standard_runtime_deps = False,
mlir_components = "None",
tags = [
"manual",
],
@ -108,6 +138,7 @@ tf_library(
cpp_class = "AddWithCkptComp",
freeze_checkpoint = "test_graph_tfadd_with_ckpt.ckpt",
graph = "test_graph_tfadd_with_ckpt.pb",
mlir_components = "None",
tags = [
"manual",
],
@ -121,6 +152,7 @@ tf_library(
freeze_checkpoint = "test_graph_tfadd_with_ckpt_saver.ckpt",
freeze_saver = "test_graph_tfadd_with_ckpt_saver.saver",
graph = "test_graph_tfadd_with_ckpt_saver.pb",
mlir_components = "None",
tags = [
"manual",
],
@ -132,6 +164,7 @@ tf_library(
config = "test_graph_tfassert_eq.config.pbtxt",
cpp_class = "AssertComp",
graph = "test_graph_tfassert_eq.pb",
mlir_components = "None",
tags = [
"manual",
],
@ -143,6 +176,7 @@ tf_library(
config = "test_graph_tfcond.config.pbtxt",
cpp_class = "CondComp",
graph = "test_graph_tfcond.pb",
mlir_components = "None",
tags = [
"manual",
],
@ -154,6 +188,7 @@ tf_library(
config = "test_graph_tffunction.config.pbtxt",
cpp_class = "FunctionComp",
graph = "test_graph_tffunction.pb",
mlir_components = "None",
tags = [
"manual",
],
@ -165,6 +200,7 @@ tf_library(
config = "test_graph_tfgather.config.pbtxt",
cpp_class = "GatherComp",
graph = "test_graph_tfgather.pb",
mlir_components = "None",
tags = [
"manual",
],
@ -176,6 +212,7 @@ tf_library(
config = "test_graph_tfmatmul.config.pbtxt",
cpp_class = "foo::bar::MatMulComp",
graph = "test_graph_tfmatmul.pb",
mlir_components = "None",
tags = [
"manual",
],
@ -187,6 +224,7 @@ tf_library(
config = "test_graph_tfmatmulandadd.config.pbtxt",
cpp_class = "::foo::bar::MatMulAndAddComp",
graph = "test_graph_tfmatmulandadd.pb",
mlir_components = "None",
tags = [
"manual",
],
@ -200,6 +238,7 @@ tf_library(
cpp_class = "MatMulAndAddCompWithProfiling",
enable_xla_hlo_profiling = True,
graph = "test_graph_tfmatmulandadd.pb",
mlir_components = "None",
tags = [
"manual",
],
@ -211,6 +250,7 @@ tf_library(
config = "test_graph_tfsplits.config.pbtxt",
cpp_class = "SplitsComp",
graph = "test_graph_tfsplits.pb",
mlir_components = "None",
tags = [
"manual",
],
@ -222,6 +262,7 @@ tf_library(
config = "test_graph_tftop_k.config.pbtxt",
cpp_class = "TopKComp",
graph = "test_graph_tftop_k.pb",
mlir_components = "None",
tags = [
"manual",
],
@ -233,6 +274,19 @@ tf_library(
config = "test_graph_tfvariable.config.pbtxt",
cpp_class = "VariableComp",
graph = "test_graph_tfvariable.pb",
mlir_components = "None",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfvariable_readonly",
testonly = 1,
config = "test_graph_tfvariable_readonly.config.pbtxt",
cpp_class = "VariableReadonlyComp",
graph = "test_graph_tfvariable_readonly.pb",
mlir_components = "None",
tags = [
"manual",
],
@ -244,6 +298,7 @@ tf_library(
config = "test_graph_tfvariable_sequential_updates.config.pbtxt",
cpp_class = "VariableSequentialUpdatesComp",
graph = "test_graph_tfvariable_sequential_updates.pb",
mlir_components = "None",
tags = [
"manual",
],
@ -269,6 +324,7 @@ tf_cc_test(
":test_graph_tfsplits",
":test_graph_tftop_k",
":test_graph_tfvariable",
":test_graph_tfvariable_readonly",
":test_graph_tfvariable_sequential_updates",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@ -288,6 +344,7 @@ tf_library(
testonly = 1,
config = "test_graph_tfadd.config.pbtxt",
cpp_class = "AddComp",
debug_info = "test_debuginfo_tfadd.pb",
graph = "test_graph_tfadd.pb",
include_standard_runtime_deps = False,
mlir_components = "Bridge",
@ -335,6 +392,18 @@ tf_library(
],
)
tf_library(
name = "test_graph_tffunction_mlir_bridge",
testonly = 1,
config = "test_graph_tffunction.config.pbtxt",
cpp_class = "FunctionComp",
graph = "test_graph_tffunction.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfassert_eq_mlir_bridge",
testonly = 1,
@ -421,6 +490,42 @@ tf_library(
],
)
tf_library(
name = "test_graph_tfvariable_readonly_mlir_bridge",
testonly = 1,
config = "test_graph_tfvariable_readonly.config.pbtxt",
cpp_class = "VariableReadonlyComp",
graph = "test_graph_tfvariable_readonly.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfvariable_mlir_bridge",
testonly = 1,
config = "test_graph_tfvariable.config.pbtxt",
cpp_class = "VariableComp",
graph = "test_graph_tfvariable.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfvariable_sequential_updates_mlir_bridge",
testonly = 1,
config = "test_graph_tfvariable_sequential_updates.config.pbtxt",
cpp_class = "VariableSequentialUpdatesComp",
graph = "test_graph_tfvariable_sequential_updates.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_cc_test(
name = "tfcompile_test_mlir_bridge",
srcs = ["tfcompile_test.cc"],
@ -434,12 +539,16 @@ tf_cc_test(
":test_graph_tfadd_with_ckpt_saver_mlir_bridge",
":test_graph_tfassert_eq_mlir_bridge",
":test_graph_tfcond_mlir_bridge",
":test_graph_tffunction_mlir_bridge",
":test_graph_tfgather_mlir_bridge",
":test_graph_tfmatmul_mlir_bridge",
":test_graph_tfmatmulandadd_mlir_bridge",
":test_graph_tfmatmulandadd_with_profiling_mlir_bridge",
":test_graph_tfsplits_mlir_bridge",
":test_graph_tftop_k_mlir_bridge",
":test_graph_tfvariable_mlir_bridge",
":test_graph_tfvariable_readonly_mlir_bridge",
":test_graph_tfvariable_sequential_updates_mlir_bridge",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto_cc",

View File

@ -30,6 +30,7 @@ from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -154,11 +155,22 @@ def tftop_k(_):
array_ops.identity(output[1], name='indices')
def tfvariable(_):
def tfvariable_readonly(_):
x = variables.Variable(1000.0, name='x')
unused_y = variables.Variable(1000.0, name='y')
old_x = x.value()
with ops.control_dependencies([old_x]):
new_x = x.assign_add(42.0)
new_value = math_ops.add(old_x, 42.0)
array_ops.identity(new_value, name='result')
# TODO(b/147908587): Change x and the two constants back to have a scalar shape
# when the bug is fixed.
def tfvariable(_):
x = variables.Variable([1000.0], name='x', shape=[1])
old_x = x.value()
with ops.control_dependencies([old_x]):
new_x = x.assign_add([42.0])
array_ops.stack([old_x, new_x], name='result')
@ -174,7 +186,22 @@ def tfvariable_sequential_updates(_):
array_ops.identity(updates, name='result')
def write_graph(build_graph, out_dir):
def export_debug_info(exported_graph):
"""Exports debug information from a graph.
Args:
exported_graph: A Graph that has been created by tracing a saveable view.
Returns:
Corresponding GraphDebugInfo with traces for all ops in exported_graph.
"""
exported_operations = []
for op in exported_graph.get_operations():
exported_operations.append(('', op))
return error_interpolation.create_graph_debug_info_def(exported_operations)
def write_graph(build_graph, out_dir, debug_info=False):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
with g.as_default():
@ -183,10 +210,19 @@ def write_graph(build_graph, out_dir):
with open(filename, 'wb') as f:
f.write(six.ensure_binary(g.as_graph_def().SerializeToString()))
if debug_info:
filename_debuginfo = os.path.join(
out_dir, 'test_debuginfo_%s.pb' % build_graph.__name__)
test_debuginfo = export_debug_info(g)
with open(filename_debuginfo, 'wb') as f:
f.write(
six.ensure_binary(
test_debuginfo.SerializeToString(deterministic=True)))
def main(_):
control_flow_util.enable_control_flow_v2()
write_graph(tfadd, FLAGS.out_dir)
write_graph(tfadd, FLAGS.out_dir, debug_info=True)
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
write_graph(tfassert_eq, FLAGS.out_dir)
@ -198,6 +234,7 @@ def main(_):
write_graph(tfsplits, FLAGS.out_dir)
write_graph(tftop_k, FLAGS.out_dir)
write_graph(tfvariable, FLAGS.out_dir)
write_graph(tfvariable_readonly, FLAGS.out_dir)
write_graph(tfvariable_sequential_updates, FLAGS.out_dir)

View File

@ -0,0 +1,69 @@
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s -dump-input-on-failure
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s -dump-input-on-failure
# Checks the error message produced by tfcompile with mlir_components=Bridge.
# Checks that source debug information is used in the output error message
# and that the node x_y_sum = Add is reported.
# CHECK: INVALID ARGUMENTS: Dimensions must be equal, but are 2 and 3 for 'x_y_sum = Add[T=DT_INT32](aot_feed_0/x, aot_feed_0/y)'
# CHECK: math_ops.add(x, y, name='x_y_sum')
# CHECK: build_graph(out_dir)
# Checks the error message produced by tfcompile with mlir_components=None.
# OLD: INVALID ARGUMENTS: Incompatible shapes: [2] vs. [3]
# OLD: x_y_sum
node: {
name: "x"
op: "Placeholder"
attr: {
key: "shape"
value: {
shape: {
dim: {
size: -1
}
}
}
}
attr: {
key: "dtype"
value: {
type: DT_INT32
}
}
}
node: {
name: "y"
op: "Placeholder"
attr: {
key: "shape"
value: {
shape: {
dim: {
size: -1
}
}
}
}
attr: {
key: "dtype"
value: {
type: DT_INT32
}
}
}
node: {
name: "x_y_sum"
op: "Add"
input: "x"
input: "y"
attr: {
key: "T"
value: {
type: DT_INT32
}
}
}
versions: {
producer: 321
}

View File

@ -0,0 +1,16 @@
# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x" }
shape {
dim { size: 2 }
}
}
feed {
id { node_name: "y" }
shape {
dim { size: 3 }
}
}
fetch {
id { node_name: "x_y_sum" }
}

View File

@ -0,0 +1,28 @@
files: "org_tensorflow/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.fake_py.debug"
traces: {
key: "x@"
value: {
file_line_cols: {
line: 1
}
}
}
traces: {
key: "x_y_sum@"
value: {
file_line_cols: {
line: 3
}
file_line_cols: {
line: 4
}
}
}
traces: {
key: "y@"
value: {
file_line_cols: {
line: 2
}
}
}

View File

@ -0,0 +1,4 @@
x = value
y = value
math_ops.add(x, y, name='x_y_sum')
build_graph(out_dir)

View File

@ -0,0 +1,20 @@
# Text form of tensorflow.tf2xla.Config proto.
fetch {
id { node_name: "result" }
}
variable {
node_name: "x"
shape {
}
type: DT_FLOAT
readonly: true
}
variable {
node_name: "y"
shape {
}
type: DT_FLOAT
readonly: true
}

View File

@ -32,12 +32,16 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tffunction_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates_mlir_bridge.h"
#else
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
@ -52,6 +56,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.h"
#endif
@ -425,8 +430,6 @@ TEST(TFCompileTest, MatMulAndAdd1) {
}
}
// TODO(bixia): the following tests failed with MLIR bridge.
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
TEST(TFCompileTest, Function) {
// The function is equivalent to an addition
FunctionComp add_fn;
@ -441,7 +444,6 @@ TEST(TFCompileTest, Function) {
EXPECT_EQ(add_fn.result0_data()[0], 3);
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
}
#endif
TEST(TFCompileTest, Splits) {
Eigen::ThreadPool tp(1);
@ -495,8 +497,20 @@ TEST(TFCompileTest, TopK) {
EXPECT_EQ(expected_indices[1], fn.result1(1));
}
// TODO(bixia): the following tests failed with MLIR bridge.
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
TEST(TFCompileTest, VariableReadonly) {
Eigen::ThreadPool tp(1);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
VariableReadonlyComp fn;
float x = 23;
fn.set_var_x_data(&x);
fn.set_thread_pool(&device);
fn.Run();
EXPECT_EQ(fn.result0(), 65);
EXPECT_EQ(fn.var_x(), 23);
}
TEST(TFCompileTest, Variable) {
Eigen::ThreadPool tp(1);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
@ -569,7 +583,6 @@ TEST(TFCompileTest, VariableSequentialUpdatesNoAlloc) {
fn.Run();
EXPECT_NEAR(x, 0.594322f, 1e-6);
}
#endif
TEST(TFCompileTest, AssertEqAndReturnDiff) {
// Assert is converted into a no-op in XLA, so there is no failure even if the

View File

@ -26,6 +26,7 @@ def tf_library(
name,
graph,
config,
debug_info = None,
freeze_checkpoint = None,
freeze_saver = None,
cpp_class = None,
@ -37,7 +38,7 @@ def tf_library(
tfcompile_tool = "//tensorflow/compiler/aot:tfcompile",
include_standard_runtime_deps = True,
enable_xla_hlo_profiling = False,
mlir_components = None,
mlir_components = "None",
deps = None,
tags = []):
"""Runs tfcompile to compile a TensorFlow graph into executable code.
@ -88,8 +89,8 @@ def tf_library(
enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated
program, and emit metadata that lets us pretty-print the gathered
profile counters.
mlir_components: When the value is "Bridge", use MLIR to translate
GraphDef to HLO.
mlir_components: When the value is "None", no components use MLIR. When
the value is "Bridge", use MLIR to translate GraphDef to HLO.
deps: a list of deps to include on the build rules for the generated
library, added to the standard deps if standard_runtime_deps is True.
tags: tags to apply to subsidiary build rules.
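For reference, a minimal sketch of how a BUILD file might call the updated macro, assuming it is loaded from tensorflow/compiler/aot/tfcompile.bzl; the target, graph, config, and debug-info file names below are hypothetical, and mlir_components now defaults to the string "None" rather than the Python None:

load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")

# Hypothetical invocation showing the new optional debug_info argument
# and the string-valued mlir_components parameter.
tf_library(
    name = "my_graph_compiled",
    graph = "my_graph.pb",
    config = "my_graph.config.pbtxt",
    debug_info = "my_graph.debug.pbtxt",  # new in this change
    cpp_class = "mynamespace::MyGraphComp",
    mlir_components = "Bridge",  # or "None" (the default) to bypass MLIR
)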
@@ -189,17 +190,17 @@ def tf_library(
else:
profiling_flag = ""
if mlir_components:
mlir_flag = "--mlir_components=" + mlir_components
else:
mlir_flag = ""
mlir_flag = "--mlir_components=" + mlir_components
srcs = [tfcompile_graph, config]
debug_info_flag = ""
if debug_info:
srcs.append(debug_info)
debug_info_flag = " --debug_info=$(location " + debug_info + ")"
native.genrule(
name = ("gen_" + name),
srcs = [
tfcompile_graph,
config,
],
srcs = srcs,
outs = [
header_file,
metadata_object_file,
@@ -209,6 +210,7 @@
"CUDA_VISIBLE_DEVICES='' " +
"$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +
debug_info_flag +
" --config=$(location " + config + ")" +
" --entry_point=" + ep +
" --cpp_class=" + cpp_class +
@@ -240,10 +242,7 @@ def tf_library(
session_module_pb = name + "_session_module.pb"
native.genrule(
name = (name + "_session_module"),
srcs = [
tfcompile_graph,
config,
],
srcs = srcs,
outs = [
session_module_pb,
],
@@ -251,6 +250,7 @@
"CUDA_VISIBLE_DEVICES='' " +
"$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +
debug_info_flag +
" --config=$(location " + config + ")" +
" --entry_point=" + ep +
" --cpp_class=" + cpp_class +
@@ -410,5 +410,6 @@ def target_llvm_triple():
"//tensorflow:ios_x86_64": "x86_64-apple-ios",
"//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
"//tensorflow:macos": "x86_64-none-darwin",
"//tensorflow:windows": "x86_64-none-windows",
"//conditions:default": "x86_64-pc-linux",
})

View File

@@ -65,7 +65,7 @@ int main(int argc, char** argv) {
flags.out_metadata_object = "out_helper.o";
flags.out_header = "out.h";
flags.entry_point = "entry";
flags.tensorflow_header_root = "third_party/tensorflow";
flags.debug_info_path_begin_marker = "";
std::vector<tensorflow::Flag> flag_list;
AppendMainFlags(&flag_list, &flags);
@@ -82,12 +82,10 @@ int main(int argc, char** argv) {
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
QCHECK(argc == 1) << "\nERROR: This command does not take any arguments "
"other than flags\n\n"
<< usage;
"other than flags. See --help.\n\n";
tensorflow::Status status = tensorflow::tfcompile::Main(flags);
if (status.code() == tensorflow::error::INVALID_ARGUMENT) {
std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n"
<< usage;
std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n";
return 1;
} else {
TF_QCHECK_OK(status);

View File

@@ -2,6 +2,7 @@ load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
load("//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags")
package(
default_visibility = [":internal"],
@@ -13,6 +14,10 @@ package_group(
includes = [
"//tensorflow/compiler/tf2xla:internal",
],
packages = [
"//tensorflow/compiler/tests/...",
"//tensorflow/python/...",
],
)
package_group(
@@ -56,6 +61,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":jit_compilation_passes",
":xla_kernel_creator", # buildcleaner: keep
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@@ -69,6 +75,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = if_cuda_or_rocm([
":jit_compilation_passes",
":xla_kernel_creator", # buildcleaner: keep
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
@@ -77,19 +84,6 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "xla_mlir_gpu_jit",
visibility = ["//visibility:public"],
deps = if_cuda_or_rocm([
":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/xla/service:mlir_gpu_plugin",
]),
alwayslink = 1,
)
cc_library(
name = "xla_cpu_device",
srcs = ["xla_cpu_device.cc"],
@@ -124,6 +118,7 @@ cc_library(
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:gpu_init",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -168,7 +163,9 @@ XLA_DEVICE_DEPS = [
":common",
":xla_launch_util",
":xla_tensor",
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:optional",
"//tensorflow/compiler/jit/ops:xla_ops",
@@ -187,6 +184,7 @@ XLA_DEVICE_DEPS = [
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:functional_ops_op_lib",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -261,13 +259,26 @@ cc_library(
}),
)
# Internal targets below this point.
cc_library(
name = "flags",
srcs = ["flags.cc"],
hdrs = ["flags.h"],
visibility = [":friends"],
deps = [
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/base",
"@com_google_absl//absl/strings",
],
)
# Header-only version of "flags" library, for linking from the shared object
# without ODR violations.
cc_library(
name = "flags_headers_only",
hdrs = ["flags.h"],
visibility = [":friends"],
deps = [
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
@@ -287,6 +298,8 @@ cc_library(
visibility = [":friends"],
)
# Internal targets below this point.
cc_library(
name = "xla_launch_util",
srcs = ["xla_launch_util.cc"],
@@ -325,6 +338,7 @@ cc_library(
deps = [
":xla_activity_listener",
":xla_activity_proto_cc",
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
@@ -408,6 +422,7 @@ cc_library(
"xla_kernel_creator.h",
],
deps = [
":flags",
":jit_compilation_passes",
":xla_kernel_creator_util",
"//tensorflow/core:core_cpu_internal",
@@ -636,6 +651,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
@@ -767,7 +783,7 @@ tf_cc_test(
],
# TODO(b/141643254) Re-enable msan after fixing use-of-uninitialized-value
# error.
tags = ["nomsan"],
tags = ["nomsan"] + tf_cuda_tests_tags(),
deps = [
":common",
":compilation_passes",

Some files were not shown because too many files have changed in this diff.