Merge branch 'master' into update-array-ops-docstrings

This commit is contained in:
Mihai Maruseac 2020-04-02 02:25:32 +00:00 committed by GitHub
commit 90f5f406db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8809 changed files with 576245 additions and 403294 deletions

116
.bazelrc
View File

@ -19,10 +19,10 @@
# Compiler options:
# cuda_clang: Use clang when building CUDA code.
# c++17: Build with C++17 options
# C++1z: Build with C++17 options
# c++1z: Build with C++17 options
# avx_linux: Build with avx instruction set on linux.
# avx2_linux: Build with avx2 instruction set on linux.
# arch_native_linux: Build with instruction sets available to the host machine on linux
# native_arch_linux: Build with instruction sets available to the host machine on linux
# avx_win: Build with avx instruction set on windows
# avx2_win: Build with avx2 instruction set on windows
#
@ -46,7 +46,6 @@
# sycl_asan:
# sycl_trisycl:
# mkl: Enable full mkl support.
# mkl_open_source_only: Enable MKL support only using open source MKL libraries.
# tensorrt: Enable Tensorrt support.
# ngraph: Enable ngraph support.
# numa: Enable numa using hwloc.
@ -69,6 +68,7 @@
# rbe_linux_py3: Linux Python 3 RBE config
#
# rbe_win_py37: Windows Python 3.7 RBE config
# rbe_win_py38: Windows Python 3.8 RBE config
#
# tensorflow_testing_rbe_linux: RBE options to use RBE with tensorflow-testing project on linux
# tensorflow_testing_rbe_win: RBE options to use RBE with tensorflow-testing project on windows
@ -136,15 +136,9 @@ build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain
# environment variable "TF_MKL_ROOT" every time before build.
build:mkl --define=build_with_mkl=true --define=enable_mkl=true
build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl --define=build_with_mkl_dnn_v1_only=true
build:mkl -c opt
# This config option is used to enable MKL-DNN open source library only,
# without depending on MKL binary version.
build:mkl_open_source_only --define=build_with_mkl_dnn_only=true
build:mkl_open_source_only --define=build_with_mkl_dnn_v1_only=true
build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0
# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true
@ -221,6 +215,11 @@ build --define=grpc_no_ares=true
# archives in -whole_archive -no_whole_archive.
build --noincompatible_remove_legacy_whole_archive
# These are bazel 2.0's incompatible flags. Tensorflow needs to use bazel 2.0.0
# to use cc_shared_library, as part of the Tensorflow Build Improvements RFC:
# https://github.com/tensorflow/community/pull/179
build --noincompatible_prohibit_aapt1
# Modular TF build options
build:dynamic_kernels --define=dynamic_loaded_kernels=true
build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
@ -238,6 +237,11 @@ build:linux --copt=-w
build:macos --copt=-w
build:windows --copt=/w
# Tensorflow uses M_* math constants that only get defined by MSVC headers if
# _USE_MATH_DEFINES is defined.
build:windows --copt=/D_USE_MATH_DEFINES
build:windows --host_copt=/D_USE_MATH_DEFINES
# Default paths for TF_SYSTEM_LIBS
build:linux --define=PREFIX=/usr
build:linux --define=LIBDIR=$(PREFIX)/lib
@ -258,6 +262,9 @@ build:windows --host_cxxopt=/std:c++14
# On windows, we still link everything into a single DLL.
build:windows --config=monolithic
# On linux, we dynamically link small amount of kernels
build:linux --config=dynamic_kernels
# Make sure to include as little of windows.h as possible
build:windows --copt=-DWIN32_LEAN_AND_MEAN
build:windows --host_copt=-DWIN32_LEAN_AND_MEAN
@ -272,7 +279,6 @@ build:windows --host_linkopt=/OPT:REF
build:windows --linkopt=/OPT:ICF
build:windows --host_linkopt=/OPT:ICF
build:windows --experimental_strict_action_env=true
build:windows --incompatible_windows_native_test_wrapper
# Verbose failure logs when something goes wrong
build:windows --verbose_failures
@ -307,22 +313,26 @@ build:xla --define=with_xla_support=true
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS
# Options when using remote execution
# WARNING: THESE OPTIONS WON'T WORK IF YOU DO NOT HAVE PROPER AUTHENTICATION AND PERMISSIONS
# Flag to enable remote config
common --experimental_repo_remote_exec
build:rbe --action_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
build:rbe --auth_enabled=true
build:rbe --auth_scope=https://www.googleapis.com/auth/cloud-source-tools
build:rbe --google_default_credentials
build:rbe --bes_backend=buildeventservice.googleapis.com
build:rbe --bes_best_effort=false
build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations"
build:rbe --bes_timeout=600s
build:rbe --define=EXECUTOR=remote
build:rbe --distinct_host_configuration=false
build:rbe --flaky_test_attempts=3
build:rbe --jobs=200
build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com
build:rbe --remote_timeout=3600
build:rbe --spawn_strategy=remote,worker,standalone,local
test:rbe --test_env=USER=anon
build:rbe --distinct_host_configuration=false
# Attempt to minimize the amount of data transfer between bazel and the remote
# workers:
build:rbe --remote_download_toplevel
build:rbe_linux --config=rbe
build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin"
@ -337,6 +347,7 @@ build:rbe_linux --config=avx_linux
build:rbe_linux --config=short_logs
# TODO(gunan): Check why we need this specified in rbe, but not in other builds.
build:rbe_linux --linkopt=-lrt
build:rbe_linux --linkopt=-lm
build:rbe_cpu_linux --config=rbe_linux
build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
@ -345,21 +356,37 @@ build:rbe_cpu_linux --extra_execution_platforms"=@org_tensorflow//third_party/to
build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --config=rbe_linux
build:rbe_linux_cuda_nvcc --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010,@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010-gpu"
build:rbe_linux_cuda_nvcc --host_platform="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/cuda10.1-cudnn7"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/tensorrt6.0"
build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_TENSORRT=1
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_VERSION=10
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDNN_VERSION=7
build:rbe_linux_cuda_nvcc --repo_env=REMOTE_GPU_TESTING=1
build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_CUDA=1
build:rbe_linux_cuda_base --config=rbe_linux
build:rbe_linux_cuda_base --repo_env=TF_NEED_TENSORRT=1
build:rbe_linux_cuda_base --repo_env=TF_CUDA_VERSION=10
build:rbe_linux_cuda_base --repo_env=TF_CUDNN_VERSION=7
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1
test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
test:rbe_linux_cuda_nvcc --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang --crosstool_top="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_clang --extra_toolchains="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_clang --extra_execution_platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang --host_platform="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang --platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_clang --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_clang --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_clang --define=using_cuda_clang=true
test:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
@ -369,29 +396,33 @@ build:rbe_linux_py2 --python_path="/usr/bin/python2"
build:rbe_linux_py2 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py"
build:rbe_linux_py3 --config=rbe_linux
build:rbe_linux_py3 --repo_env=PYTHON_BIN_PATH="/usr/bin/python3"
build:rbe_linux_py3 --python_path="/usr/bin/python3"
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py3"
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-manylinux2010-py3_config_python"
build:rbe_win --config=rbe
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_026:toolchain"
build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_026:cc-toolchain-x64_windows"
build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"
build:rbe_win --host_platform="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"
build:rbe_win --platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:toolchain"
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:cc-toolchain-x64_windows"
build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"
build:rbe_win --host_platform="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"
build:rbe_win --platforms="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"
build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
# TODO(gunan): Remove once we use MSVC 2019 with latest patches.
build:rbe_win --define=override_eigen_strong_inline=true
build:rbe_win --jobs=500
build:rbe_win_py37 --config=rbe
build:rbe_win_py37 --repo_env=PYTHON_BIN_PATH=C:\\Python37\\python.exe
build:rbe_win_py37 --repo_env=PYTHON_LIB_PATH=C:\\Python37\\lib\\site-packages
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py37
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO="@windows_py37_config_python"
build:rbe_win_py37 --python_path=C:\\Python37\\python.exe
build:rbe_win_py38 --config=rbe
build:rbe_win_py38 --repo_env=PYTHON_BIN_PATH=C:\\Python38\\python.exe
build:rbe_win_py38 --repo_env=PYTHON_LIB_PATH=C:\\Python38\\lib\\site-packages
build:rbe_win_py38 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py38
build:rbe_win_py38 --python_path=C:\\Python38\\python.exe
# These you may need to change for your own GCP project.
build:tensorflow_testing_rbe --project_id=tensorflow-testing
common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance
@ -401,7 +432,6 @@ build:tensorflow_testing_rbe_linux --config=rbe_linux
common:tensorflow_testing_rbe_win --remote_instance_name=projects/tensorflow-testing/instances/windows
build:tensorflow_testing_rbe_win --config=tensorflow_testing_rbe
build:tensorflow_testing_rbe_win --config=rbe_win
# END TF REMOTE BUILD EXECUTION OPTIONS
# Default options should come above this line

View File

@ -1 +1 @@
1.1.0
2.0.0

44
.github/ISSUE_TEMPLATE/00-bug-issue.md vendored Normal file
View File

@ -0,0 +1,44 @@
---
name: Bug Issue
about: Use this template for reporting a bug
labels: 'type:bug'
---
<em>Please make sure that this is a bug. As per our
[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:bug_template</em>
**System information**
- Have I written custom code (as opposed to using a stock
example script provided in TensorFlow):
- OS Platform and Distribution (e.g.,
Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**
**Describe the expected behavior**
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
**Other info / logs** Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full
traceback. Large logs and files should be attached.

View File

@ -1,35 +0,0 @@
---
name: Bug/Performance Issue
about: Use this template for reporting a bug or a performance issue.
---
<em>Please make sure that this is a bug. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template</em>
**System information**
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**
**Describe the expected behavior**
**Code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate the problem.
**Other info / logs**
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.

View File

@ -1,6 +1,7 @@
---
name: Build/Installation Issue
about: Use this template for build/installation issues
labels: 'type:build/install'
---

View File

@ -1,10 +1,11 @@
---
name: Documentation Issue
about: Use this template for documentation related
about: Use this template for documentation related issues
labels: 'type:docs'
---
Thank you for submitting a TensorFlow documentation issue. Per our GitHub
policy, we only address code/doc bugs, performance issues, feature requests, and
build/installation issues on GitHub.

View File

@ -1,6 +1,7 @@
---
name: Feature Request
about: Use this template for raising a feature request
labels: 'type:feature'
---

View File

@ -1,10 +1,10 @@
---
name: TensorFlow Lite Op Request
about: Use this template for reporting ops you are using or missing.
about: Use this template for reporting Lite ops you are using or missing
labels: 'comp:lite'
---
**System information**
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- TensorFlow installed from (source or binary):
@ -17,8 +17,14 @@ about: Use this template for reporting ops you are using or missing.
# Copy and paste here
```
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
Also, please include a link to a GraphDef or the model if possible.
**Any other info / logs**
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
Include any logs or source code that would be helpful to diagnose the problem.
If including tracebacks, please include the full traceback. Large logs and files
should be attached.

View File

@ -1,6 +1,7 @@
---
name: Other Issues
about: Use this template for any other non-support related issues
labels: 'type:others'
---

View File

@ -1,6 +1,7 @@
---
name: TensorFlow Lite New Converter Issue
about: Use this template for reporting issues during model conversion to TFLite.
about: Use this template for reporting issues during model conversion to TFLite
labels: 'TFLiteConverter'
---
@ -12,6 +13,7 @@ about: Use this template for reporting issues during model conversion to TFLite.
**Command used to run the converter or code if you're using the Python API**
If possible, please share a link to Colab/Jupyter/any notebook.
```
# Copy and paste here the exact command

View File

@ -0,0 +1,19 @@
---
name: TensorFlow Lite for Microcontrollers Issue
about: Use this template for reporting issues with TensorFlow Lite for microcontrollers
labels: 'comp:micro'
---
@tensorflow/micro
**System information**
- Host OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- TensorFlow installed from (source or binary):
- Tensorflow version (commit SHA if source):
- Target platform (e.g. Arm Mbed OS, Arduino Nano 33 etc.):
**Describe the problem**
**Please provide the exact sequence of commands/steps when you ran into the problem**

View File

@ -0,0 +1,45 @@
---
name: Performance Issue
about: Use this template for reporting a performance issue
labels: 'type:performance'
---
<em>Please make sure that this is an issue related to performance of TensorFlow.
As per our
[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:performance_template</em>
**System information**
- Have I written custom code (as opposed to using a stock
example script provided in TensorFlow):
- OS Platform and Distribution (e.g.,
Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**
**Describe the expected behavior**
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
**Other info / logs** Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full
traceback. Large logs and files should be attached.

5
.gitignore vendored
View File

@ -22,6 +22,7 @@ tensorflow/contrib/cmake/_build/
/tensorflow/python/framework/fast_tensor_util.cpp
/tensorflow/lite/gen/**
/tensorflow/lite/tools/make/downloads/**
/tensorflow/lite/tools/make/gen/**
/api_init_files_list.txt
/estimator_api_init_files_list.txt
*.whl
@ -37,7 +38,9 @@ gradleBuild
*.pbxproj
*.xcworkspace
/*.podspec
/tensorflow/lite/**/[ios|objc|swift]*/BUILD
/tensorflow/lite/**/ios/BUILD
/tensorflow/lite/**/objc/BUILD
/tensorflow/lite/**/swift/BUILD
/tensorflow/lite/examples/ios/simple/data/*.tflite
/tensorflow/lite/examples/ios/simple/data/*.txt
Podfile.lock

1
.pylintrc Symbolic link
View File

@ -0,0 +1 @@
tensorflow/tools/ci_build/pylintrc

View File

@ -13,55 +13,4 @@
/tensorflow/tensorboard/ @jart
/tensorflow/tools/docs/ @markdaoust
# contrib
# NEED OWNER: /tensorflow/contrib/all_reduce
/tensorflow/contrib/autograph/ @mdanatg @kkimdev
/tensorflow/contrib/batching/ @alextp @chrisolston
/tensorflow/contrib/bayesflow/ @ebrevdo @rsepassi @jvdillon
/tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva
/tensorflow/contrib/checkpoint/ @allenlavoie
/tensorflow/contrib/contrib/cluster_resolver/ @frankchn
/tensorflow/contrib/cmake/ @mrry
/tensorflow/contrib/copy_graph/ @tucker @poxvoculi
/tensorflow/contrib/crf/ @kentonl
/tensorflow/contrib/data/ @mrry
/tensorflow/tensorflow/contrib/distribute @joshl @priyag @sourabhbajaj @frankchn
/tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi
/tensorflow/contrib/eager @jaingaurav @alextp
/tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo
/tensorflow/contrib/ffmpeg/ @fredbertsch
/tensorflow/contrib/framework/ @ebrevdo
/tensorflow/contrib/graph_editor/ @purpledog
# NEED OWNER: /tensorflow/contrib/grid_rnn/
/tensorflow/contrib/hadoop @yongtang
/tensorflow/contrib/hvx/ @satok16
/tensorflow/contrib/integrate/ @shoyer
/tensorflow/contrib/kernel_methods/ @petrosmol
/tensorflow/contrib/ios_examples/ @petewarden
/tensorflow/contrib/labeled_tensor/ @shoyer
/tensorflow/contrib/layers/ @fchollet @martinwicke
/tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp
/tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis
/tensorflow/contrib/lookup/ @ysuematsu @andreasst
/tensorflow/contrib/losses/ @alextp @ispirmustafa
/tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg
/tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa
/tensorflow/contrib/opt/ @strategist333 @alextp
/tensorflow/contrib/pi_examples/ @maciekcc
/tensorflow/contrib/quantization/ @petewarden
/tensorflow/contrib/rnn/ @ebrevdo @scottzhu
/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenlavoie
/tensorflow/contrib/seq2seq/ @ebrevdo @lmthang
/tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh
/tensorflow/contrib/slim/ @sguada @thenbasilmanran
/tensorflow/contrib/stateless/ @girving @alextp
/tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank
/tensorflow/contrib/tensorrt/ @aaroey @smit-hinsu @azaks2
# NEED OWNER: /tensorflow/contrib/testing/
/tensorflow/contrib/timeseries/ @allenlavoie
/tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj
/tensorflow/contrib/training/ @joel-shor @ebrevdo
/tensorflow/contrib/util/ @sherrym
/third_party/systemlibs/ @perfinion

View File

@ -72,7 +72,7 @@ TensorFlow coding style.
[tensorflow/core](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core)
and
[tensorflow/python](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python).
TensorFlow has reached version 1 and hence cannot make
TensorFlow has passed version 1.0 and hence cannot make
non-backward-compatible API changes without a major release. Reviewers of
your pull request will comment on any API compatibility issues.
* When you contribute a new feature to TensorFlow, the maintenance burden is
@ -88,6 +88,9 @@ TensorFlow coding style.
submitting PRs to fix one typo, one warning, etc. We recommend fixing the
same issue at the file level at least (e.g.: fix all typos in a file, fix
all compiler warnings in a file, etc.)
* Tests should follow the
[testing best practices](https://www.tensorflow.org/community/contribute/tests)
guide.
#### License

View File

@ -37,23 +37,26 @@ See the [TensorFlow install guide](https://www.tensorflow.org/install) for the
[Docker container](https://www.tensorflow.org/install/docker), and
[build from source](https://www.tensorflow.org/install/source).
To install the current release for CPU-only:
To install the current release, which includes support for
[CUDA-enabled GPU cards](https://www.tensorflow.org/install/gpu) *(Ubuntu and
Windows)*:
```
$ pip install tensorflow
```
Use the GPU package for
[CUDA-enabled GPU cards](https://www.tensorflow.org/install/gpu) *(Ubuntu and
Windows)*:
A smaller CPU-only package is also available:
```
$ pip install tensorflow-gpu
$ pip install tensorflow-cpu
```
To update TensorFlow to the latest version, add `--upgrade` flag to the above
commands.
*Nightly binaries are available for testing using the
[tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) packages on PyPi.*
[tf-nightly-cpu](https://pypi.python.org/pypi/tf-nightly-cpu) packages on PyPi.*
#### *Try your first TensorFlow program*
@ -67,7 +70,7 @@ $ python
3
>>> hello = tf.constant('Hello, TensorFlow!')
>>> hello.numpy()
'Hello, TensorFlow!'
b'Hello, TensorFlow!'
```
For more examples, see the
@ -110,35 +113,38 @@ Build Type | Status
### Community Supported Builds
Build Type | Status | Artifacts
------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**Linux AMD ROCm GPU** Nightly | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/)
**Linux AMD ROCm GPU** Stable Release | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | Release [1.15](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/) / [2.x](http://ml-ci.amd.com:21096/job/tensorflow-rocm-v2-release/lastSuccessfulBuild/)
**Linux s390x** Nightly | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/)
**Linux s390x CPU** Stable Release | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/badge/icon)](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/)
**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/)
**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/)
**Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
**Linux CPU with Intel® MKL-DNN** <br> **Supports Python 2.7, 3.4, 3.5, 3.6 and 3.7** | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.14.0 PyPI](https://pypi.org/project/intel-tensorflow/)
**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br> Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/)
Build Type | Status | Artifacts
----------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**Linux AMD ROCm GPU** Nightly | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/)
**Linux AMD ROCm GPU** Stable Release | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | Release [1.15](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/) / [2.x](http://ml-ci.amd.com:21096/job/tensorflow-rocm-v2-release/lastSuccessfulBuild/)
**Linux s390x** Nightly | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/)
**Linux s390x CPU** Stable Release | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/badge/icon)](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/)
**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/)
**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/)
**Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
**Linux CPU with Intel® MKL-DNN** Stable Release | ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon) | Release [1.15](https://pypi.org/project/intel-tensorflow/1.15.0/) / [2.x](https://pypi.org/project/intel-tensorflow/)
**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br> Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/)
## Resources
* [TensorFlow.org](https://www.tensorflow.org)
* [TensorFlow tutorials](https://www.tensorflow.org/tutorials/)
* [TensorFlow official models](https://github.com/tensorflow/models/tree/master/official)
* [TensorFlow examples](https://github.com/tensorflow/examples)
* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
* [TensorFlow Official Models](https://github.com/tensorflow/models/tree/master/official)
* [TensorFlow Examples](https://github.com/tensorflow/examples)
* [TensorFlow in Practice from Coursera](https://www.coursera.org/specializations/tensorflow-in-practice)
* [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment)
* [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2)
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
* [TensorFlow blog](https://blog.tensorflow.org)
* [TensorFlow Blog](https://blog.tensorflow.org)
* [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
* [TensorFlow Twitter](https://twitter.com/tensorflow)
* [TensorFlow YouTube](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
* [TensorFlow roadmap](https://www.tensorflow.org/community/roadmap)
* [TensorFlow white papers](https://www.tensorflow.org/about/bib)
* [TensorBoard visualization toolkit](https://github.com/tensorflow/tensorboard)
* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
* [TensorBoard Visualization Toolkit](https://github.com/tensorflow/tensorboard)
Learn more about the
[TensorFlow community](https://www.tensorflow.org/community) and how to

File diff suppressed because one or more lines are too long

View File

@ -245,4 +245,4 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
### Known Vulnerabilities
For a list of known vulnerabilities and security advisories for TensorFlow,
[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md).
[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/README.md).

View File

@ -1,6 +1,6 @@
workspace(name = "org_tensorflow")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
http_archive(
name = "io_bazel_rules_closure",
@ -48,38 +48,6 @@ load("//third_party/toolchains/preconfig/generate:workspace.bzl",
remote_config_workspace()
# Apple and Swift rules.
http_archive(
name = "build_bazel_rules_apple",
sha256 = "a045a436b642c70fb0c10ca84ff0fd2dcbd59cc89100d597a61e8374afafb366",
urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.18.0/rules_apple.0.18.0.tar.gz"],
) # https://github.com/bazelbuild/rules_apple/releases
http_archive(
name = "build_bazel_rules_swift",
sha256 = "18cd4df4e410b0439a4935f9ca035bd979993d42372ba79e7f2d4fafe9596ef0",
urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.12.1/rules_swift.0.12.1.tar.gz"],
) # https://github.com/bazelbuild/rules_swift/releases
http_archive(
name = "build_bazel_apple_support",
sha256 = "122ebf7fe7d1c8e938af6aeaee0efe788a3a2449ece5a8d6a428cb18d6f88033",
urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.7.1/apple_support.0.7.1.tar.gz"],
) # https://github.com/bazelbuild/apple_support/releases
http_archive(
name = "bazel_skylib",
sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0",
urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel-skylib.0.9.0.tar.gz"],
) # https://github.com/bazelbuild/bazel-skylib/releases
http_archive(
name = "com_github_apple_swift_swift_protobuf",
type = "zip",
strip_prefix = "swift-protobuf-1.6.0/",
urls = ["https://github.com/apple/swift-protobuf/archive/1.6.0.zip"],
) # https://github.com/apple/swift-protobuf/releases
http_file(
name = "xctestrunner",
executable = 1,
urls = ["https://github.com/google/xctestrunner/releases/download/0.2.9/ios_test_runner.par"],
) # https://github.com/google/xctestrunner/releases
# Use `swift_rules_dependencies` to fetch the toolchains. With the
# `git_repository` rules above, the following call will skip redefining them.
load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies")
@ -145,3 +113,32 @@ http_archive(
"https://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
],
)
# Required for dependency @com_github_grpc_grpc
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
grpc_deps()
load(
"@build_bazel_rules_apple//apple:repositories.bzl",
"apple_rules_dependencies",
)
apple_rules_dependencies()
load(
"@build_bazel_apple_support//lib:repositories.bzl",
"apple_support_dependencies",
)
apple_support_dependencies()
load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")
bazel_version_repository(name = "bazel_version")
load("//third_party/googleapis:repository_rules.bzl", "config_googleapis")
config_googleapis()

View File

@ -33,7 +33,7 @@ except ImportError:
from distutils.spawn import find_executable as which
# pylint: enable=g-import-not-at-top
_DEFAULT_CUDA_VERSION = '10.1'
_DEFAULT_CUDA_VERSION = '10'
_DEFAULT_CUDNN_VERSION = '7'
_DEFAULT_TENSORRT_VERSION = '6'
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0'
@ -49,8 +49,8 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
_TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
_TF_MIN_BAZEL_VERSION = '1.0.0'
_TF_MAX_BAZEL_VERSION = '1.1.0'
_TF_MIN_BAZEL_VERSION = '2.0.0'
_TF_MAX_BAZEL_VERSION = '2.0.0'
NCCL_LIB_PATHS = [
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
@ -147,14 +147,16 @@ def write_action_env_to_bazelrc(var_name, var):
write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var)))
def run_shell(cmd, allow_non_zero=False):
def run_shell(cmd, allow_non_zero=False, stderr=None):
if stderr is None:
stderr = sys.stdout
if allow_non_zero:
try:
output = subprocess.check_output(cmd)
output = subprocess.check_output(cmd, stderr=stderr)
except subprocess.CalledProcessError as e:
output = e.output
else:
output = subprocess.check_output(cmd)
output = subprocess.check_output(cmd, stderr=stderr)
return output.decode('UTF-8').strip()
@ -169,10 +171,12 @@ def get_python_path(environ_cp, python_bin_path):
if environ_cp.get('PYTHONPATH'):
python_paths = environ_cp.get('PYTHONPATH').split(':')
try:
stderr = open(os.devnull, 'wb')
library_paths = run_shell([
python_bin_path, '-c',
'import site; print("\\n".join(site.getsitepackages()))'
]).split('\n')
],
stderr=stderr).split('\n')
except subprocess.CalledProcessError:
library_paths = [
run_shell([
@ -1151,7 +1155,7 @@ def set_trisycl_include_dir(environ_cp):
write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)
def system_specific_test_config(env):
def system_specific_test_config(environ_cp):
"""Add default build and test flags required for TF tests to bazelrc."""
write_to_bazelrc('test --flaky_test_attempts=3')
write_to_bazelrc('test --test_size_filters=small,medium')
@ -1167,22 +1171,29 @@ def system_specific_test_config(env):
test_only_filters = ['-oss_serial']
if is_windows():
test_and_build_filters.append('-no_windows')
if env.get('TF_NEED_CUDA', None) == '1':
if environ_cp.get('TF_NEED_CUDA', None) == '1':
test_and_build_filters += ['-no_windows_gpu', '-no_gpu']
else:
test_and_build_filters.append('-gpu')
elif is_macos():
test_and_build_filters += ['-gpu', '-nomac', '-no_mac']
elif is_linux():
if env.get('TF_NEED_CUDA', None) == '1':
if environ_cp.get('TF_NEED_CUDA', None) == '1':
test_and_build_filters.append('-no_gpu')
write_to_bazelrc('test --test_env=LD_LIBRARY_PATH')
else:
test_and_build_filters.append('-gpu')
write_to_bazelrc('test --test_tag_filters=%s' %
# Disable tests with "v1only" tag in "v2" Bazel config, but not in "v1" config
write_to_bazelrc('test:v1 --test_tag_filters=%s' %
','.join(test_and_build_filters + test_only_filters))
write_to_bazelrc('test --build_tag_filters=%s' %
write_to_bazelrc('test:v1 --build_tag_filters=%s' %
','.join(test_and_build_filters))
write_to_bazelrc(
'test:v2 --test_tag_filters=%s' %
','.join(test_and_build_filters + test_only_filters + ['-v1only']))
write_to_bazelrc('test:v2 --build_tag_filters=%s' %
','.join(test_and_build_filters + ['-v1only']))
def set_system_libs_flag(environ_cp):
@ -1210,7 +1221,7 @@ def is_reduced_optimize_huge_functions_available(environ_cp):
only, as of 2019-11-19). TensorFlow needs this flag to massively reduce
compile times, but until 16.4 is officially released, we can't depend on it.
See also https://groups.google.com/a/tensorflow.org/g/build/c/SsW98Eo7l3o
See also https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
Because it's very annoying to check this manually (to check the MSVC installed
versions, you need to use the registry, and it's not clear if Bazel will be
@ -1379,9 +1390,8 @@ def main():
else:
environ_cp['TF_CONFIGURE_IOS'] = '0'
xla_enabled_by_default = is_linux() or is_macos()
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
xla_enabled_by_default, 'xla')
if environ_cp.get('TF_ENABLE_XLA', '1') == '1':
write_to_bazelrc('build --config=xla')
set_action_env_var(
environ_cp,
@ -1512,7 +1522,7 @@ def main():
create_android_ndk_rule(environ_cp)
create_android_sdk_rule(environ_cp)
system_specific_test_config(os.environ)
system_specific_test_config(environ_cp)
set_action_env_var(environ_cp, 'TF_CONFIGURE_IOS', 'iOS', False)
if environ_cp.get('TF_CONFIGURE_IOS') == '1':

View File

@ -2,6 +2,7 @@
# TensorFlow is a computational framework, primarily for use in machine
# learning applications.
load("@bazel_skylib//lib:selects.bzl", "selects")
load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
load(
"//tensorflow/core/platform:build_config.bzl",
@ -186,6 +187,12 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "fuchsia",
values = {"cpu": "fuchsia"},
visibility = ["//visibility:public"],
)
config_setting(
name = "ios_x86_64",
values = {
@ -195,6 +202,12 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "chromiumos",
values = {"crosstool_top": "//external:android/chromiumos"},
visibility = ["//visibility:public"],
)
config_setting(
name = "linux_aarch64",
values = {"cpu": "aarch64"},
@ -441,18 +454,66 @@ config_setting(
visibility = ["//visibility:public"],
)
# Specifies via a config setting if this is a mobile build or not, makes
# it easier to combine settings later.
selects.config_setting_group(
name = "mobile",
match_any = [
":android",
":chromiumos",
":emscripten",
":ios",
],
)
config_setting(
name = "lite_protos_legacy",
values = {"define": "TENSORFLOW_PROTOS=lite"},
visibility = ["//visibility:private"],
)
config_setting(
name = "full_protos",
values = {"define": "TENSORFLOW_PROTOS=full"},
visibility = ["//visibility:public"],
)
selects.config_setting_group(
name = "lite_protos",
match_any = [":lite_protos_legacy"],
)
selects.config_setting_group(
name = "mobile_lite_protos",
match_all = [
":lite_protos",
":mobile",
],
)
selects.config_setting_group(
name = "mobile_full_protos",
match_all = [
":full_protos",
":mobile",
],
)
# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
# Instead, please use public APIs or public build rules TF provides.
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
package_group(
name = "internal",
packages = [
# To pass open source testing in the pip Kokoros.
"//bazel_pip/tensorflow/...",
"//learning/brain/swift/x10/...",
"//perftools/accelerators/xprof/api/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
"//tensorflow/...",
"//tensorflow_estimator/python/estimator/...",
"//tensorflow_models/official/...",
"//third_party/py/autograph/...",
],
)
@ -471,7 +532,8 @@ bzl_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core/platform:build_config_root_bzl",
"//tensorflow/core/platform:cuda_build_defs_bzl",
"//tensorflow/core/platform:rules_cc_bzl",
"//tensorflow/core/platform/default:cuda_build_defs_bzl",
"//third_party/mkl:build_defs_bzl",
"//third_party/mkl_dnn:build_defs_bzl",
"//third_party/ngraph:build_defs_bzl",
@ -485,8 +547,8 @@ cc_library(
name = "grpc",
visibility = ["//visibility:public"],
deps = select({
":linux_s390x": ["@grpc//:grpc_unsecure"],
"//conditions:default": ["@grpc"],
":linux_s390x": ["@com_github_grpc_grpc//:grpc_unsecure"],
"//conditions:default": ["@com_github_grpc_grpc//:grpc"],
}),
)
@ -494,8 +556,8 @@ cc_library(
name = "grpc++",
visibility = ["//visibility:public"],
deps = select({
":linux_s390x": ["@grpc//:grpc++_unsecure"],
"//conditions:default": ["@grpc//:grpc++"],
":linux_s390x": ["@com_github_grpc_grpc//:grpc++_unsecure"],
"//conditions:default": ["@com_github_grpc_grpc//:grpc++"],
}),
)
@ -580,6 +642,7 @@ tf_cc_shared_object(
"//tensorflow/core:gpu_runtime_impl",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
"//tensorflow/core:lib_internal_impl",
"//tensorflow/core/profiler:profiler_impl",
"//tensorflow/stream_executor:stream_executor_impl",
"//tensorflow:tf_framework_version_script.lds",
] + tf_additional_binary_deps(),
@ -639,6 +702,7 @@ tf_cc_shared_object(
"//tensorflow/c:exported_symbols.lds",
"//tensorflow/c:version_script.lds",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:tensorflow",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
],
@ -853,7 +917,7 @@ gen_api_init_files(
output_files = TENSORFLOW_API_INIT_FILES_V1,
output_package = "tensorflow._api.v1",
root_file_name = "v1.py",
root_init_template = "api_template_v1.__init__.py",
root_init_template = "$(location api_template_v1.__init__.py)",
)
gen_api_init_files(
@ -876,7 +940,7 @@ gen_api_init_files(
output_files = TENSORFLOW_API_INIT_FILES_V2,
output_package = "tensorflow._api.v2",
root_file_name = "v2.py",
root_init_template = "api_template.__init__.py",
root_init_template = "$(location api_template.__init__.py)",
)
py_library(
@ -899,7 +963,6 @@ py_library(
"//conditions:default": [":tf_python_api_gen_v1"],
}) + [
":root_init_gen",
":virtual_root_init_gen",
"//tensorflow/python/keras/api:keras_python_api_gen",
"//tensorflow/python/keras/api:keras_python_api_gen_compat_v1",
"//tensorflow/python/keras/api:keras_python_api_gen_compat_v2",

View File

@ -23,10 +23,6 @@ from __future__ import print_function
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
from tensorflow.python.util.lazy_loader import LazyLoader
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
from tensorflow.python.platform import app # pylint: disable=g-import-not-at-top
app.flags = flags

View File

@ -13,16 +13,16 @@
# limitations under the License.
# ==============================================================================
"""
Top-level module of TensorFlow. By convention, we refer to this module as
`tf` instead of `tensorflow`, following the common practice of importing
Top-level module of TensorFlow. By convention, we refer to this module as
`tf` instead of `tensorflow`, following the common practice of importing
TensorFlow via the command `import tensorflow as tf`.
The primary function of this module is to import all of the public TensorFlow
interfaces into a single place. The interfaces themselves are located in
The primary function of this module is to import all of the public TensorFlow
interfaces into a single place. The interfaces themselves are located in
sub-modules, as described below.
Note that the file `__init__.py` in the TensorFlow source code tree is actually
only a placeholder to enable test cases to run. The TensorFlow build replaces
Note that the file `__init__.py` in the TensorFlow source code tree is actually
only a placeholder to enable test cases to run. The TensorFlow build replaces
this file with a file generated from [`api_template.__init__.py`](https://www.github.com/tensorflow/tensorflow/blob/master/tensorflow/api_template.__init__.py)
"""
@ -35,9 +35,16 @@ import inspect as _inspect
import logging as _logging
import os as _os
import site as _site
import six as _six
import sys as _sys
from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
# Make sure code inside the TensorFlow codebase can use tf2.enabled() at import.
_os.environ['TF2_BEHAVIOR'] = '1'
from tensorflow.python import tf2 as _tf2
_tf2.enable()
# API IMPORTS PLACEHOLDER
@ -69,13 +76,13 @@ except ImportError:
_logging.warning(
"Limited tf.summary API due to missing TensorBoard installation.")
try:
from tensorflow_estimator.python.estimator.api._v2 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass
# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)
try:
from .python.keras.api._v2 import keras
@ -85,10 +92,18 @@ try:
except ImportError:
pass
# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
import typing as _typing
if _typing.TYPE_CHECKING:
from tensorflow_estimator.python.estimator.api._v2 import estimator
# pylint: enable=g-import-not-at-top
# Enable TF2 behaviors
from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top
_compat.enable_v2_behavior()
_major_api_version = 2
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
@ -119,8 +134,14 @@ def _running_from_pip_package():
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
if _running_from_pip_package():
# TODO(gunan): Add sanity checks to loaded modules here.
for _s in _site_packages_dirs:
# TODO(gunan): Add sanity checks to loaded modules here.
# Load first party dynamic kernels.
_main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels')
if _fi.file_exists(_main_dir):
_ll.load_library(_main_dir)
# Load third party dynamic kernels.
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _fi.file_exists(_plugin_dir):
_ll.load_library(_plugin_dir)

View File

@ -22,12 +22,14 @@ import distutils as _distutils
import inspect as _inspect
import os as _os
import site as _site
import six as _six
import sys as _sys
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.platform import tf_logging as _logging
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
# API IMPORTS PLACEHOLDER
@ -64,13 +66,14 @@ elif _tf_api_dir not in __path__:
# reexport_tf_summary can get compat from sys.modules. Only needed if using
# lazy loading.
_current_module.compat.v2 # pylint: disable=pointless-statement
try:
from tensorflow_estimator.python.estimator.api._v1 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass
# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v1.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)
try:
from .python.keras.api._v1 import keras
@ -80,6 +83,13 @@ try:
except ImportError:
pass
# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
import typing as _typing
if _typing.TYPE_CHECKING:
from tensorflow_estimator.python.estimator.api._v1 import estimator
# pylint: enable=g-import-not-at-top
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
_CONTRIB_WARNING = """
@ -104,6 +114,8 @@ from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-
_current_module.app.flags = flags # pylint: disable=undefined-variable
setattr(_current_module, "flags", flags)
_major_api_version = 1
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
# running under pip.
# TODO(gunan): Enable setting an environment variable to define arbitrary plugin
@ -132,8 +144,14 @@ def _running_from_pip_package():
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
if _running_from_pip_package():
# TODO(gunan): Add sanity checks to loaded modules here.
for _s in _site_packages_dirs:
# TODO(gunan): Add sanity checks to loaded modules here.
# Load first party dynamic kernels.
_main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels')
if _fi.file_exists(_main_dir):
_ll.load_library(_main_dir)
# Load third party dynamic kernels.
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _fi.file_exists(_plugin_dir):
_ll.load_library(_plugin_dir)

View File

@ -23,6 +23,7 @@ filegroup(
srcs = [
"c_api.h",
"c_api_experimental.h",
"tensor_interface.h",
"tf_attrtype.h",
"tf_datatype.h",
"tf_file_statistics.h",
@ -53,6 +54,22 @@ filegroup(
visibility = ["//visibility:public"],
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"c_api_internal.h",
"python_api.h",
"tensor_interface.h",
"tf_status_helper.h",
"tf_status_internal.h",
"tf_tensor_internal.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
tf_cuda_library(
name = "c_api_internal",
hdrs = [
@ -84,6 +101,17 @@ tf_cuda_library(
],
)
filegroup(
name = "pywrap_tf_session_hdrs",
srcs = [
"python_api.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "tf_attrtype",
hdrs = ["tf_attrtype.h"],
@ -108,6 +136,7 @@ tf_cuda_library(
":tf_attrtype",
":tf_status_internal",
":tf_file_statistics",
":tf_tensor_internal",
] + select({
"//tensorflow:with_xla_support": [
"//tensorflow/compiler/tf2xla:xla_compiler",
@ -127,7 +156,10 @@ tf_cuda_library(
"c_api.h",
],
copts = tf_copts(),
visibility = ["//tensorflow/c:__subpackages__"],
visibility = [
"//tensorflow/c:__subpackages__",
"//third_party/llvm/llvm-project:__subpackages__",
],
deps = [
":c_api_internal",
":tf_attrtype",
@ -208,6 +240,16 @@ cc_library(
visibility = ["//visibility:public"],
)
cc_library(
name = "tensor_interface",
hdrs = ["tensor_interface.h"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "tf_datatype",
srcs = ["tf_datatype.cc"],
@ -215,7 +257,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:android_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:framework",
@ -234,6 +276,7 @@ cc_library(
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":tensor_interface",
":tf_datatype",
":tf_status",
":tf_status_helper",
@ -241,6 +284,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:casts",
],
}),
)
@ -251,14 +295,18 @@ tf_cuda_library(
"tf_tensor.h",
"tf_tensor_internal.h",
],
visibility = ["//tensorflow:internal"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":tensor_interface",
":tf_datatype",
":tf_status",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:casts",
],
}),
)
@ -285,6 +333,9 @@ tf_cuda_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform",
"@com_google_absl//absl/strings",
@ -506,6 +557,7 @@ tf_cuda_cc_test(
"//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:math",
"//tensorflow/core/platform:resource_loader",
],
)
@ -518,6 +570,7 @@ tf_cc_test(
"//tensorflow:macos": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
tags = ["notsan"], # b/149031034
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
@ -616,6 +669,7 @@ tf_cuda_cc_test(
deps = [
":c_api",
":kernels",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -623,6 +677,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core/kernels:ops_testutil",
"//third_party/eigen3",
"@com_google_absl//absl/container:inlined_vector",
],
)
@ -664,4 +719,5 @@ tf_cuda_library(
# TODO(b/74620627): remove when _USE_C_SHAPES is removed
"//tensorflow/python:cpp_shape_inference_proto_cc",
],
alwayslink = 1,
)

View File

@ -56,16 +56,16 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/coding.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
@ -458,7 +458,7 @@ static void TF_Run_Helper(
EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
continue;
}
c_outputs[i] = TF_TensorFromTensor(src, status);
c_outputs[i] = TF_TensorFromTensor(src, &status->status);
if (!status->status.ok()) return;
}
}
@ -774,7 +774,7 @@ extern "C" {
static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
const char* op_type,
const char* oper_name)
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
return new TF_OperationDescription(graph, op_type, oper_name);
}
@ -1032,7 +1032,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
TF_Status* status)
EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
Node* ret = nullptr;
if (desc->graph->name_map.count(desc->node_builder.node_name())) {
@ -1493,7 +1493,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
Tensor t;
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
if (!status->status.ok()) return;
*value = TF_TensorFromTensor(t, status);
*value = TF_TensorFromTensor(t, &status->status);
}
void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
@ -1504,7 +1504,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
if (!status->status.ok()) return;
const auto len = std::min(max_values, static_cast<int>(ts.size()));
for (int i = 0; i < len; ++i) {
values[i] = TF_TensorFromTensor(ts[i], status);
values[i] = TF_TensorFromTensor(ts[i], &status->status);
}
}
@ -1706,7 +1706,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
const TF_ImportGraphDefOptions* opts,
TF_ImportGraphDefResults* tf_results,
TF_Status* status)
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
const int last_node_id = graph->graph.num_node_ids();
tensorflow::ImportGraphDefResults results;
status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
@ -2398,7 +2398,7 @@ unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
graph->graph.versions().producer(), &evaluated, &result_tensor);
if (evaluated) {
DCHECK(status->status.ok());
*result = TF_TensorFromTensor(result_tensor, status);
*result = TF_TensorFromTensor(result_tensor, &status->status);
if (!status->status.ok()) evaluated = false;
}
return evaluated;

View File

@ -23,16 +23,19 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
@ -518,72 +521,6 @@ TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
return createTFEDequeue(ctx, TF_VARIANT, queue, status);
}
void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
auto* status = TF_NewStatus();
TF_Tensor* t = TFE_TensorHandleResolve(handle, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::Tensor dst;
TF_CHECK_OK(TF_TensorToTensor(t, &dst));
LOG(INFO) << dst.DebugString();
TF_DeleteTensor(t);
TF_DeleteStatus(status);
}
void TFE_OpPrintDebugString(TFE_Op* op) {
VLOG(1) << "TFE_OpPrintDebugString() over " << op;
LOG(INFO) << op->operation.DebugString();
}
struct TFE_ExecuteOpNotification {
TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {}
tensorflow::Notification n;
std::unique_ptr<tensorflow::Thread> thread;
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status;
};
TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
TFE_TensorHandle** retvals,
int* num_retvals,
TF_Status* status) {
TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;
n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread(
tensorflow::ThreadOptions(), "ExecuteOpThread",
[op, retvals, num_retvals, n]() {
TFE_Execute(op, retvals, num_retvals, n->status.get());
n->n.Notify();
}));
return n;
}
void TFE_ExecuteOpNotificationWaitAndDelete(
TFE_ExecuteOpNotification* notification, TF_Status* status) {
if (notification == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"Passed in notification is a nullptr.");
return;
}
if (notification->thread == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"Passed in notification didn't start a thread correctly. Cleaning up "
"this notification. Please re-execute the operation to get a new "
"notification.");
delete notification;
return;
}
notification->n.WaitForNotification();
status->status = notification->status->status;
delete notification;
}
void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
status->status = tensorflow::errors::Internal(errMsg);
}
@ -634,7 +571,7 @@ TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
std::unique_ptr<tensorflow::Tensor> tensor;
reader->GetTensor(name, &tensor, status);
if (!status->status.ok()) return nullptr;
return tensorflow::TF_TensorFromTensor(*tensor, status);
return tensorflow::TF_TensorFromTensor(*tensor, &status->status);
}
void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
@ -747,7 +684,10 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
return TFE_TensorHandle::CreateLocalHandle(tensor, status);
status->status = tensorflow::Status::OK();
return new TFE_TensorHandle{
tensorflow::TensorHandle::CreateLocalHandle(tensor)};
}
namespace {
@ -767,8 +707,10 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
} while (0);
// New server created for new server_def. Unused if updating server_def.
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
if (grpc_server == nullptr) {
std::unique_ptr<tensorflow::ServerInterface> new_server;
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
@ -779,12 +721,12 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
}
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
std::move(new_server), grpc_server->worker_env()->device_mgr,
grpc_server->worker_env()->collective_executor_mgr));
} else {
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
/*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
grpc_server->worker_env()->collective_executor_mgr));
}
@ -880,12 +822,14 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
const int num_inputs = input_shapes->num_items;
NodeDef node_def;
node_def.set_name(tfe_op->operation.Name());
node_def.set_op(tfe_op->operation.Name());
node_def.set_name(tfe_op->operation->Name());
node_def.set_op(tfe_op->operation->Name());
for (int i = 0; i < num_inputs; ++i) {
node_def.add_input("dummy_input");
}
tfe_op->operation.Attrs().FillAttrValueMap(node_def.mutable_attr());
OperationFromInterface(tfe_op->operation)
->Attrs()
.FillAttrValueMap(node_def.mutable_attr());
const tensorflow::OpRegistrationData* op_reg_data;
status->status =

View File

@ -188,31 +188,6 @@ TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session,
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
TF_Session* session, int tensor_id, TF_Status* status);
// Prints `handle` in a human readable format to standard output for debugging.
TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString(
TFE_TensorHandle* handle);
TF_CAPI_EXPORT extern void TFE_OpPrintDebugString(TFE_Op* op);
typedef struct TFE_ExecuteOpNotification TFE_ExecuteOpNotification;
// Allows invoking a kernel asynchronously, and explicitly returns a
// notification that can be waited upon. This always executes the kernel in a
// new thread.
// 1. `retvals` and `num_retvals` can only be consumed after
// `TFE_ExecuteOp` returns successfully. They shouldn't be used
// if the return is unsuccessful
// 2. These new APIs cannot be used together with the TFE context level async
// support.
TF_CAPI_EXPORT extern TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(
TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status);
// Waits to complete the op execution, and cleans up the notification.
// Errors reported by op execution are set in `status`.
TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete(
TFE_ExecuteOpNotification* notification, TF_Status* status);
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
const char* errMsg);

View File

@ -84,127 +84,6 @@ TEST(CAPI_EXPERIMENTAL, IsStateful) {
EXPECT_EQ(id, 0);
}
TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Simple) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* matmul_op = MatMulOp(ctx, m, m);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
auto* r =
TFE_ExecuteOpInNewThread(matmul_op, &retvals[0], &num_retvals, status);
TFE_ExecuteOpNotificationWaitAndDelete(r, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteOp(matmul_op);
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
// Perform a send/recv test. Recv blocks, so they need to be executed
// asynchronously.
TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Blocking) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
// Returns a 2x2 float32 Tensor on the CPU, with data 1., 2., 3., 4.
TFE_TensorHandle* m = TestMatrixTensorHandle();
// Build a send op.
TFE_Op* send_op = TFE_NewOp(ctx, "_Send", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(send_op, m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
string tensor_name = "Tensor";
TFE_OpSetAttrType(send_op, "T", TF_FLOAT);
TFE_OpSetAttrString(send_op, "tensor_name", tensor_name.c_str(),
tensor_name.size());
string send_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_OpSetAttrString(send_op, "send_device", send_device.c_str(),
send_device.size());
TFE_OpSetAttrInt(send_op, "send_device_incarnation", 1234);
string recv_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_OpSetAttrString(send_op, "recv_device", recv_device.c_str(),
recv_device.size());
TFE_OpSetAttrBool(send_op, "client_terminated", true);
// Build a recv op.
TFE_Op* recv_op = TFE_NewOp(ctx, "_Recv", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(recv_op, "tensor_type", TF_FLOAT);
TFE_OpSetAttrString(recv_op, "tensor_name", tensor_name.c_str(),
tensor_name.size());
TFE_OpSetAttrString(recv_op, "send_device", send_device.c_str(),
send_device.size());
TFE_OpSetAttrInt(recv_op, "send_device_incarnation", 1234);
TFE_OpSetAttrString(recv_op, "recv_device", recv_device.c_str(),
recv_device.size());
TFE_OpSetAttrBool(recv_op, "client_terminated", true);
TFE_TensorHandle* send_retvals;
int send_num_retvals = 0;
auto* send_result = TFE_ExecuteOpInNewThread(send_op, &send_retvals,
&send_num_retvals, status);
TFE_TensorHandle* recv_retvals[1] = {nullptr};
int recv_num_retvals = 1;
auto* recv_result = TFE_ExecuteOpInNewThread(recv_op, &recv_retvals[0],
&recv_num_retvals, status);
TFE_ExecuteOpNotificationWaitAndDelete(send_result, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_ExecuteOpNotificationWaitAndDelete(recv_result, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(recv_retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(1, product[0]);
EXPECT_EQ(2, product[1]);
EXPECT_EQ(3, product[2]);
EXPECT_EQ(4, product[3]);
TFE_DeleteOp(send_op);
TFE_DeleteOp(recv_op);
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(recv_retvals[0]);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
class ShapeInferenceTest : public ::testing::Test {
protected:
ShapeInferenceTest()

View File

@ -27,8 +27,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/strings/base64.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/base64.h"
#include "tensorflow/core/platform/strcat.h"
using tensorflow::errors::InvalidArgument;
@ -51,7 +51,7 @@ Status ProcessInputs(
const TF_Graph* fn_body, const char* fn_name, int ninputs,
const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
std::unordered_map<const Node*, std::vector<int>>* input_nodes)
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
input_tensors->reserve(ninputs);
for (int i = 0; i < ninputs; ++i) {
Node* node = &inputs[i].oper->node;
@ -87,7 +87,7 @@ Status ProcessInputs(
Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
int noutputs, const TF_Output* outputs,
std::vector<OutputTensor>* output_tensors)
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
output_tensors->reserve(noutputs);
for (int i = 0; i < noutputs; ++i) {
Node* node = &outputs[i].oper->node;
@ -111,7 +111,7 @@ Status ComputeBodyNodes(
const TF_Operation* const* opers,
const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
std::vector<const Node*>* body_nodes)
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
if (num_opers == -1) {
for (const Node* node : fn_body->graph.op_nodes()) {
const auto& iter = input_nodes.find(node);

View File

@ -14,17 +14,16 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -1260,11 +1259,10 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node3", "v2",
&node3);
TF_Output inputs[] = {};
TF_Output outputs[] = {{node1, 0}, {node2, 0}, {node3, 0}};
func_ = TF_GraphToFunction(
func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
/*opers=*/nullptr, 0, inputs, 3, outputs,
/*opers=*/nullptr, 0, nullptr, 3, outputs,
/*output_names=*/nullptr,
/*opts=*/nullptr, /*description=*/nullptr, s.get());
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
@ -1300,10 +1298,9 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithArgAttr) {
&node);
TF_Output inputs[] = {{node, 0}};
TF_Output outputs[] = {};
func_ = TF_GraphToFunction(
func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
/*opers=*/nullptr, 1, inputs, 0, outputs,
/*opers=*/nullptr, 1, inputs, 0, nullptr,
/*output_names=*/nullptr,
/*opts=*/nullptr, /*description=*/nullptr, s.get());
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
@ -1603,11 +1600,10 @@ void DefineStatefulFunction(const char* name, TF_Function** func) {
TF_Operation* random =
RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get());
TF_Output inputs[] = {};
TF_Output outputs[] = {{random, 0}};
*func = TF_GraphToFunction(func_graph.get(), name,
/*append_hash_to_fn_name=*/false, -1,
/*opers=*/nullptr, 0, inputs, 1, outputs,
/*opers=*/nullptr, 0, nullptr, 1, outputs,
/*output_names=*/nullptr,
/*opts=*/nullptr, "", s.get());
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());

View File

@ -40,8 +40,8 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
@ -71,14 +71,14 @@ struct TF_Graph {
TF_Graph();
tensorflow::mutex mu;
tensorflow::Graph graph GUARDED_BY(mu);
tensorflow::Graph graph TF_GUARDED_BY(mu);
// Runs shape inference.
tensorflow::ShapeRefiner refiner GUARDED_BY(mu);
tensorflow::ShapeRefiner refiner TF_GUARDED_BY(mu);
// Maps from name of an operation to the Node* in 'graph'.
std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
GUARDED_BY(mu);
TF_GUARDED_BY(mu);
// The keys of this map are all the active sessions using this graph. Each
// value records whether the graph has been mutated since the corresponding
@ -94,8 +94,8 @@ struct TF_Graph {
// TODO(b/74949947): mutations currently trigger a warning instead of a bad
// status, this should be reverted when possible.
tensorflow::gtl::FlatMap<TF_Session*, tensorflow::string> sessions
GUARDED_BY(mu);
bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph
TF_GUARDED_BY(mu);
bool delete_requested TF_GUARDED_BY(mu); // set true by TF_DeleteGraph
// Used to link graphs contained in TF_WhileParams to the parent graph that
// will eventually contain the full while loop.
@ -123,7 +123,7 @@ struct TF_Session {
tensorflow::Session* session;
TF_Graph* const graph;
tensorflow::mutex mu ACQUIRED_AFTER(TF_Graph::mu);
tensorflow::mutex mu TF_ACQUIRED_AFTER(TF_Graph::mu);
int last_num_graph_nodes;
// If true, TF_SessionRun and similar methods will call
@ -169,9 +169,9 @@ struct TF_ApiDefMap {
}
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
tensorflow::ApiDefMap api_def_map GUARDED_BY(lock);
tensorflow::ApiDefMap api_def_map TF_GUARDED_BY(lock);
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
bool update_docs_called GUARDED_BY(lock);
bool update_docs_called TF_GUARDED_BY(lock);
tensorflow::mutex lock;
};
@ -188,7 +188,7 @@ namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
TF_Buffer* out);
@ -210,10 +210,10 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
void RecordMutation(TF_Graph* graph, const TF_Operation& op,
const char* mutation_type)
EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status)
LOCKS_EXCLUDED(session->graph->mu, session->mu);
TF_LOCKS_EXCLUDED(session->graph->mu, session->mu);
std::string getTF_OutputDebugString(TF_Output node);

View File

@ -43,15 +43,17 @@ limitations under the License.
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
namespace {
@ -193,8 +195,9 @@ TEST(CAPI, LibraryLoadFunctions) {
{
// Load the library.
TF_Status* status = TF_NewStatus();
TF_Library* lib =
TF_LoadLibrary("tensorflow/c/test_op1.so", status);
string lib_path = tensorflow::GetDataDependencyFilepath(
tensorflow::io::JoinPath("tensorflow", "c", "test_op1.so"));
TF_Library* lib = TF_LoadLibrary(lib_path.c_str(), status);
TF_Code code = TF_GetCode(status);
string status_msg(TF_Message(status));
TF_DeleteStatus(status);
@ -227,7 +230,7 @@ TEST(CAPI, LibraryLoadFunctions) {
void TestEncodeDecode(int line, const std::vector<string>& data) {
const tensorflow::int64 n = data.size();
TF_Status* status = TF_NewStatus();
Status status;
for (const std::vector<tensorflow::int64>& dims :
std::vector<std::vector<tensorflow::int64>>{
{n}, {1, n}, {n, 1}, {n / 2, 2}}) {
@ -236,8 +239,8 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
src.flat<tstring>()(i) = data[i];
}
TF_Tensor* dst = TF_TensorFromTensor(src, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Tensor* dst = TF_TensorFromTensor(src, &status);
ASSERT_TRUE(status.ok()) << status.error_message();
// Convert back to a C++ Tensor and ensure we get expected output.
Tensor output;
@ -249,7 +252,6 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
TF_DeleteTensor(dst);
}
TF_DeleteStatus(status);
}
TEST(CAPI, TensorEncodeDecodeStrings) {
@ -1351,9 +1353,9 @@ TEST_F(CApiColocationTest, ClearViaProto) {
TEST(CAPI, SavedModel) {
// Load the saved model.
const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
const string saved_model_dir = tensorflow::io::JoinPath(
tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"half_plus_two", "00000123"));
TF_SessionOptions* opt = TF_NewSessionOptions();
TF_Buffer* run_options = TF_NewBufferFromString("", 0);
TF_Buffer* metagraph = TF_NewBuffer();
@ -1394,8 +1396,9 @@ TEST(CAPI, SavedModel) {
TF_Operation* input_op =
TF_GraphOperationByName(graph, input_op_name.c_str());
ASSERT_TRUE(input_op != nullptr);
csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}});
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
Status status;
csession.SetInputs({{input_op, TF_TensorFromTensor(input, &status)}});
ASSERT_TRUE(status.ok()) << status.error_message();
const tensorflow::string output_op_name(
tensorflow::ParseTensorName(output_name).first);
@ -1426,9 +1429,9 @@ TEST(CAPI, SavedModel) {
}
TEST(CAPI, SavedModelNullArgsAreValid) {
const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
const string saved_model_dir = tensorflow::io::JoinPath(
tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"half_plus_two", "00000123"));
TF_SessionOptions* opt = TF_NewSessionOptions();
TF_Status* s = TF_NewStatus();
const char* tags[] = {tensorflow::kSavedModelTagServe};
@ -2522,12 +2525,11 @@ TEST(CAPI, TestTensorIsNotAligned) {
// Take an unaligned slice.
Tensor y = x.Slice(1, 13);
TF_Status* status = TF_NewStatus();
TF_Tensor* a = TF_TensorFromTensor(y, status);
Status status;
TF_Tensor* a = TF_TensorFromTensor(y, &status);
if (EIGEN_MAX_ALIGN_BYTES > 0) {
EXPECT_FALSE(TF_TensorIsAligned(a));
}
TF_DeleteStatus(status);
TF_DeleteTensor(a);
}

View File

@ -17,7 +17,7 @@ limitations under the License.
#include <memory.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
#include <time.h>
#include <unistd.h>
#include "tensorflow/c/c_api.h"
@ -58,12 +58,8 @@ int main(int argc, char** argv) {
}
char file_name[100];
struct timeval t;
if (gettimeofday(&t, NULL)) {
perror("gettimeofday failed");
return 1;
}
snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(), t.tv_sec);
time_t t = time(NULL);
snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(), t);
size_t length = 2 + strlen(path) + strlen(file_name);
char* full_path = malloc(length);

View File

@ -19,8 +19,8 @@ limitations under the License.
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/public/session_options.h"
using tensorflow::GraphDef;

View File

@ -18,9 +18,9 @@ limitations under the License.
#include <unordered_set>
#include <utility>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"

View File

@ -2,6 +2,7 @@
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
"tf_copts",
"tf_cuda_cc_test",
"tf_cuda_library",
@ -26,8 +27,8 @@ tf_cuda_library(
"c_api.cc",
"c_api_debug.cc",
"c_api_experimental.h",
"c_api_internal.cc",
"c_api_internal.h",
"c_api_unified_experimental.h",
],
hdrs = ["c_api.h"],
copts = tf_copts() + tfe_xla_copts(),
@ -37,17 +38,23 @@ tf_cuda_library(
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":context_interface",
":operation_interface",
":tensor_handle_interface",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/core:core_cpu",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/common_runtime/eager:copy_to_device_node",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
@ -78,16 +85,34 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/profiler/lib:profiler_lib",
"//tensorflow/core/profiler/lib:profiler_session",
"//tensorflow/core:gpu_runtime",
],
alwayslink = 1,
)
# Headers that the Python pywrap (SWIG/pybind) build needs to compile
# against the eager C API. Exposed as a filegroup so the listed packages
# can include the headers without depending on this package's C++ targets.
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"c_api_experimental.h",
"c_api_internal.h",
"c_api_unified_experimental.h",
"context_interface.h",
"dlpack.h",
"operation_interface.h",
"tensor_handle_interface.h",
],
# Restricted to the two packages that assemble the Python bindings.
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
tf_cuda_library(
name = "c_api_internal",
srcs = ["c_api_experimental.h"],
srcs = [
"c_api_experimental.h",
"c_api_unified_experimental.h",
],
hdrs = ["c_api_internal.h"],
visibility = [
"//learning/deepmind/courier:__subpackages__",
@ -95,6 +120,9 @@ tf_cuda_library(
],
deps = [
":c_api",
":context_interface",
":operation_interface",
":tensor_handle_interface",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu",
@ -105,24 +133,52 @@ tf_cuda_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
"//tensorflow/core/profiler/lib:profiler_lib",
"//tensorflow/core/profiler/lib:profiler_session",
],
)
# Header-only target (hdrs, no srcs) declaring the abstract tensor handle
# interface for the eager C API. Depended on by :operation_interface and
# :context_interface below.
cc_library(
name = "tensor_handle_interface",
hdrs = ["tensor_handle_interface.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c:tensor_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
# Header-only target declaring the abstract operation interface for the
# eager C API. Builds on :tensor_handle_interface (operations consume and
# produce tensor handles); absl::Span is used in the header's signatures.
cc_library(
name = "operation_interface",
hdrs = ["operation_interface.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":tensor_handle_interface",
"//tensorflow/c:tensor_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:span",
],
)
# Header-only target declaring the abstract context interface for the
# eager C API; a context creates operations and tensor handles, hence the
# deps on the two sibling interface targets.
cc_library(
name = "context_interface",
hdrs = ["context_interface.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":operation_interface",
":tensor_handle_interface",
"//tensorflow/c:tensor_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
@ -167,6 +223,8 @@ tf_cuda_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"@com_google_absl//absl/strings",
],
)
@ -193,6 +251,7 @@ tf_cuda_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"@com_google_absl//absl/strings",
],
@ -202,8 +261,12 @@ tf_cuda_library(
name = "c_api_experimental",
srcs = [
"c_api_experimental.cc",
"c_api_unified_experimental.cc",
],
hdrs = [
"c_api_experimental.h",
"c_api_unified_experimental.h",
],
hdrs = ["c_api_experimental.h"],
copts = tf_copts() + tfe_xla_copts(),
visibility = ["//visibility:public"],
deps = select({
@ -219,6 +282,7 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
@ -229,6 +293,7 @@ tf_cuda_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:variant",
],
}) + select({
"//tensorflow:with_xla_support": [
@ -240,7 +305,6 @@ tf_cuda_library(
}) + [
"@com_google_absl//absl/memory",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
@ -251,8 +315,6 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/profiler/rpc:profiler_server",
"//tensorflow/core/profiler/rpc/client:capture_profile",
"//tensorflow/core:gpu_runtime",
],
alwayslink = 1,
@ -282,6 +344,51 @@ tf_cuda_cc_test(
],
)
# Test for the unified (graph/eager) experimental C API surface.
# NOTE(review): --heap_check=local and the "nomac" tag mirror the sibling
# eager C API tests in this file — presumably for leak-checker and macOS
# flakiness reasons; confirm before changing.
tf_cuda_cc_test(
name = "c_api_unified_experimental_test",
size = "small",
srcs = [
"c_api_unified_experimental_test.cc",
],
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [
":c_api",
":c_api_experimental",
":c_api_test_util",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
],
)
# CPU-only test (plain tf_cc_test, not tf_cuda_cc_test) exercising the
# TFE_RegisterCustomDevice / TFE_CustomDevice experimental API.
tf_cc_test(
name = "custom_device_test",
size = "small",
srcs = [
"custom_device_test.cc",
],
deps = [
":c_api",
":c_api_experimental",
":c_api_test_util",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "tape",
hdrs = ["tape.h"],
@ -294,10 +401,38 @@ cc_library(
filegroup(
name = "headers",
srcs = ["c_api.h"],
srcs = [
"c_api.h",
"c_api_experimental.h",
"dlpack.h",
],
visibility = ["//tensorflow:__subpackages__"],
)
# Conversion between TFE_TensorHandle and the DLPack in-memory tensor
# format (the @dlpack external dep).
# NOTE(review): -fexceptions and -use_header_modules deviate from the
# usual TF defaults — presumably required by the dlpack conversion code;
# confirm before removing.
cc_library(
name = "dlpack",
srcs = ["dlpack.cc"],
hdrs = ["dlpack.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
visibility = ["//tensorflow:__subpackages__"],
deps = [
":c_api",
":c_api_experimental",
":c_api_internal",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"@dlpack",
],
# Kept in the link even if no symbol is referenced directly.
alwayslink = 1,
)
# TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime
# right now, remove this public rule when no longer needed (it should be
# replaced by TF Lite)
@ -311,6 +446,7 @@ filegroup(
exclude = [
"c_api_experimental.cc",
"*test*",
"*dlpack*",
],
),
visibility = ["//visibility:public"],

File diff suppressed because it is too large Load Diff

View File

@ -206,14 +206,14 @@ typedef struct TFE_TensorDebugInfo TFE_TensorDebugInfo;
// error and nullptr is returned. This function can block till the operation
// that produces `handle` has completed.
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* handle, TF_Status* status);
TFE_TensorHandle* h, TF_Status* status);
// Deletes `debug_info`.
TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
TFE_TensorDebugInfo* debug_info);
// Returns the number of dimensions used to represent the tensor on its device.
// The number of dimensions used to reprensent the tensor on device can be
// The number of dimensions used to represent the tensor on device can be
// different from the number returned by TFE_TensorHandleNumDims.
// The return value was current at the time of TFE_TensorDebugInfo creation.
TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(

View File

@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api.h"
#include <vector>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/jit/xla_device.h"
#endif // TENSORFLOW_EAGER_USE_XLA
@ -28,19 +28,22 @@ using tensorflow::string;
namespace {
std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
TF_Status* status) {
std::vector<int64> TensorShapeAsVector(const tensorflow::TensorHandle& handle,
tensorflow::Status* status) {
std::vector<int64> shape;
int rank = TFE_TensorHandleNumDims(handle, status);
if (TF_GetCode(status) != TF_OK) {
int rank = -1;
*status = handle.NumDims(&rank);
if (!status->ok()) {
return shape;
}
shape.reserve(rank);
for (int i = 0; i < rank; ++i) {
shape.push_back(TFE_TensorHandleDim(handle, i, status));
if (TF_GetCode(status) != TF_OK) {
tensorflow::int64 dim;
*status = handle.Dim(i, &dim);
if (!status->ok()) {
return shape;
}
shape.push_back(dim);
}
return shape;
}
@ -50,19 +53,19 @@ std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
extern "C" {
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* handle, TF_Status* status) {
TFE_TensorHandle* h, TF_Status* status) {
tensorflow::TensorHandle* handle = TensorHandleFromInterface(h->handle);
const tensorflow::Tensor* tensor;
status->status = handle->handle->Tensor(&tensor);
if (TF_GetCode(status) != TF_OK) {
status->status = handle->Tensor(&tensor);
if (!status->status.ok()) {
return nullptr;
}
#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Device* device = handle->handle->device();
auto* device = absl::get<tensorflow::Device*>(handle->device());
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
tensorflow::XlaDevice* xla_device =
dynamic_cast<tensorflow::XlaDevice*>(device);
auto* xla_device = dynamic_cast<tensorflow::XlaDevice*>(device);
if (xla_device != nullptr) {
tensorflow::XlaDevice::PaddedShapeFn shape_fn =
xla_device->metadata().padded_shape_fn();
@ -72,7 +75,8 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
return nullptr;
}
if (VLOG_IS_ON(3)) {
std::vector<int64> shape_to_log = TensorShapeAsVector(handle, status);
std::vector<int64> shape_to_log =
TensorShapeAsVector(*handle, &status->status);
if (!status->status.ok()) {
// Ignore the status here as we are simply logging.
status->status = tensorflow::Status::OK();
@ -138,8 +142,8 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
// If the tensor is not an XLA tensor, the device shape is
// the same as regular tensor shape.
std::vector<int64> dev_dims = TensorShapeAsVector(handle, status);
if (TF_GetCode(status) != TF_OK) {
std::vector<int64> dev_dims = TensorShapeAsVector(*handle, &status->status);
if (!status->status.ok()) {
return nullptr;
}
return new TFE_TensorDebugInfo(dev_dims);

View File

@ -18,103 +18,39 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
#include "tensorflow/core/profiler/rpc/profiler_server.h"
#include "tensorflow/core/platform/strcat.h"
using tensorflow::string;
void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status, TFE_Op* op_to_reset) {
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) {
if (op_to_reset) {
NewOrResetOp(ctx, op_or_function_name, status, op_to_reset);
op_to_reset->operation->Clear();
status->status =
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
} else {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"op_to_reset should not be nullptr");
}
}
void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
op->operation.ConsumeInput(h->handle);
}
TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); }
bool TFE_ProfilerIsOk(TFE_Profiler* profiler) {
return profiler->profiler->Status().ok();
}
void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; }
void TFE_ProfilerSerializeToString(TFE_Profiler* profiler, TF_Buffer* buf,
TF_Status* status) {
string content;
status->status = profiler->profiler->SerializeToString(&content);
void* data = tensorflow::port::Malloc(content.length());
content.copy(static_cast<char*>(data), content.length(), 0);
buf->data = data;
buf->length = content.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
}
void TFE_StartProfilerServer(int port) {
// Release child thread intentionally. The child thread can be terminated by
// terminating the main thread.
tensorflow::StartProfilerServer(port).release();
}
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(true);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->SetShouldStoreGraphs(true);
}
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(false);
}
bool TFE_ProfilerClientStartTracing(const char* service_addr,
const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms,
int num_tracing_attempts,
TF_Status* status) {
tensorflow::Status s =
tensorflow::profiler::client::ValidateHostPortPair(service_addr);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return false;
}
s = tensorflow::profiler::client::StartTracing(
service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
num_tracing_attempts);
tensorflow::Set_TF_Status_from_Status(status, s);
return s.ok();
}
void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
int monitoring_level, bool display_timestamp,
TF_Buffer* result, TF_Status* status) {
tensorflow::Status s =
tensorflow::profiler::client::ValidateHostPortPair(service_addr);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return;
}
string content;
s = tensorflow::profiler::client::Monitor(
service_addr, duration_ms, monitoring_level, display_timestamp, &content);
void* data = tensorflow::port::Malloc(content.length());
content.copy(static_cast<char*>(data), content.length(), 0);
result->data = data;
result->length = content.length();
result->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
tensorflow::Set_TF_Status_from_Status(status, s);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->SetShouldStoreGraphs(false);
}
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
@ -544,7 +480,9 @@ void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
void TFE_ContextSetThreadLocalMirroringPolicy(
TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
ctx->context->SetThreadLocalMirroringPolicy(
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->SetThreadLocalMirroringPolicy(
static_cast<tensorflow::ContextMirroringPolicy>(policy));
}
@ -553,8 +491,9 @@ void TFE_ContextSetThreadLocalMirroringPolicy(
// safe to call this function from the async EagerExecutor threads.
extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
TFE_Context* ctx) {
return static_cast<TFE_ContextMirroringPolicy>(
ctx->context->GetMirroringPolicy());
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
return static_cast<TFE_ContextMirroringPolicy>(context->GetMirroringPolicy());
}
void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options,
@ -562,6 +501,10 @@ void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options,
options->lazy_remote_inputs_copy = lazy_copy;
}
void TFE_ContextOptionsSetTfrt(TFE_ContextOptions* options, bool use_tfrt) {
options->use_tfrt = use_tfrt;
}
TFE_CancellationManager* TFE_NewCancellationManager() {
return new TFE_CancellationManager;
}
@ -584,8 +527,11 @@ void TFE_DeleteCancellationManager(
void TFE_OpSetCancellationManager(TFE_Op* op,
TFE_CancellationManager* cancellation_manager,
TF_Status* status) {
op->operation.SetCancellationManager(
tensorflow::EagerOperation* operation =
tensorflow::OperationFromInterface(op->operation);
operation->SetCancellationManager(
&cancellation_manager->cancellation_manager);
status->status = tensorflow::Status::OK();
}
TFE_Executor* TFE_NewExecutor(bool is_async) {
@ -608,9 +554,55 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
}
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
ctx->context->SetExecutorForThread(executor->executor());
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->SetExecutorForThread(executor->executor());
}
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
return new TFE_Executor(&ctx->context->Executor());
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
return new TFE_Executor(&context->Executor());
}
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
context->HostCPU()->parsed_name());
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0);
buf->data = data;
buf->length = str.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
}
void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
TF_Status* status) {
h->handle->EnableImplicitMirroring();
status->status = tensorflow::Status::OK();
}
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
TF_Buffer* buf, TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
auto* function_def = context->FindFunctionDef(function_name);
if (function_def == nullptr) {
status->status = tensorflow::errors::NotFound(
"Unable to find FunctionDef with name: ", function_name);
return;
}
string str = function_def->SerializeAsString();
void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0);
buf->data = data;
buf->length = str.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
status->status = tensorflow::Status::OK();
}

View File

@ -22,38 +22,17 @@ limitations under the License.
extern "C" {
#endif
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Context* ctx,
// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This
// is for performance optimization by reusing an existing unused op rather than
// creating a new op every time. If `raw_device_name` is `NULL` or empty, it
// does not set the device name. If it's not `NULL`, then it attempts to parse
// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
// than separately calling it because if the existing op has the same
// `raw_device_name`, it skips parsing and just leaves it as it is.
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
const char* op_or_function_name,
TF_Status* status, TFE_Op* op_to_reset);
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
TF_Status* status);
// A profiler which will start profiling when creating the object and will stop
// when the object is destroyed. It will profile all operations run under the
// given TFE_Context. Multiple instance of it can be created, but at most one
// of them will profile for each TFE_Context.
// Thread-safety: TFE_Profiler is thread-safe.
typedef struct TFE_Profiler TFE_Profiler;
TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler();
TF_CAPI_EXPORT extern bool TFE_ProfilerIsOk(TFE_Profiler* profiler);
TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler);
// The output string is a binary string of tensorflow.tpu.Trace. User can write
// the string to file for offline analysis by tensorboard.
TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Profiler* profiler,
TF_Buffer* buf,
TF_Status* status);
// Start a profiler grpc server which listens to specified port. It will start
// the server on its own thread. It can be shutdown by terminating tensorflow.
// It can be used in both Eager mode and graph mode. Creating multiple profiler
// server is allowed. The service defined in
// tensorflow/contrib/tpu/profiler/tpu_profiler.proto. Please use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture trace file
// following https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern void TFE_StartProfilerServer(int port);
const char* raw_device_name,
TF_Status* status);
// Enables only graph collection in RunMetadata on the functions executed from
// this context.
@ -63,29 +42,6 @@ TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
// this context.
TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx);
// Send a grpc request to profiler server (service_addr) to perform on-demand
// profiling and save the result into logdir which can be visualized by
// TensorBoard. worker_list is the list of worker TPUs separated by ','. Set
// include_dataset_opts to false to profile longer traces. It will block the
// caller thread until receives tracing result.
// This API is designed for TensorBoard, for end user, please use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead following
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
const char* service_addr, const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms, int num_tracing_attempts,
TF_Status* status);
// Send a grpc request to profiler server (service_addr) to perform on-demand
// monitoring and return the result in a string. It will block the
// caller thread until receiving the monitoring result.
// This API is designed for TensorBoard, for end user, please use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead following
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern void TFE_ProfilerClientMonitor(
const char* service_addr, int duration_ms, int monitoring_level,
bool display_timestamp, TF_Buffer* result, TF_Status* status);
// TODO(fishx): Move these monitoring APIs into a separate file.
// -----------------------------------------------------------------------------
// Monitoring Counter APIs.
@ -340,6 +296,10 @@ TF_CAPI_EXPORT extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy(
TFE_ContextOptions*, bool lazy_copy);
// Sets whether to use TFRT
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*,
bool use_tfrt);
// -----------------------------------------------------------------------------
// Cancellation APIs.
@ -426,6 +386,150 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
const char* worker_name,
TF_Status* status);
// Sync pending nodes in local executors (including the context default executor
// and thread executors) and streaming requests to remote executors, and get the
// combined status.
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
TF_Status* status);
// If the TensorHandle is copied to another device as part of an op execution,
// the copy is destroyed after the op has executed. Enabling implicit mirroring
// causes the copy to be held as a mirror for the lifetime of the TensorHandle.
TF_CAPI_EXPORT extern void TFE_TensorHandleEnableImplicitMirroring(
TFE_TensorHandle*, TF_Status*);
// This function will block till the operation that produces `h` has
// completed. This is only valid on local TFE_TensorHandles. The pointer
// returned will be on the device in which the TFE_TensorHandle resides (so e.g.
// for a GPU tensor this will return a pointer to GPU memory). The pointer is
// only guaranteed to be valid until TFE_DeleteTensorHandle is called on this
// TensorHandle. Only supports POD data types.
TF_CAPI_EXPORT extern void* TFE_TensorHandleDevicePointer(TFE_TensorHandle*,
TF_Status*);
// This function will block till the operation that produces `h` has
// completed. This is only valid on local TFE_TensorHandles. Returns the size in
// bytes of the memory pointed to by the device pointer returned above.
TF_CAPI_EXPORT extern size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle*,
TF_Status*);
// Creates a new TensorHandle from memory residing in device_name. Takes
// ownership of the memory, and will call deleter to release it after TF
// no longer needs it or in case of error.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
TFE_Context* ctx, const char* device_name, TF_DataType, const int64_t* dims,
int num_dims, void* data, size_t len,
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg, TF_Status* status);
// Retrieves the address space (i.e. job, replica, task) of the local host and
// saves it in the buffer.
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
TF_Buffer* buf);
// APIs for generically dealing with op attributes (e.g. when forwarding them
// through custom device implementations).
//
// TODO(allenl): Currently these are black boxes, but we should have some way to
// inspect values. This would let people e.g. copy over most attributes and then
// modify some based on their values.
// A reference to an op's name -> attribute mapping
typedef struct TFE_OpAttrs TFE_OpAttrs;
// Fetch a struct with a reference to information about attributes of `op`.
//
// The `attrs` struct does not own any memory, and `op` must outlive it.
TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
// Add attributes in `attrs` to `op`.
//
// Does not overwrite or update existing attributes, but adds new ones.
TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs);
// Serialize `attrs` as a tensorflow::NameAttrList protocol buffer (into `buf`),
// containing the op name and a map of its attributes.
TF_CAPI_EXPORT extern void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs,
TF_Buffer* buf,
TF_Status* status);
// Set an op's attribute from a serialized AttrValue protocol buffer.
//
// Analogous to TF_SetAttrValueProto for building graph operations.
TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
const char* attr_name,
const void* proto,
size_t proto_len,
TF_Status* status);
#define TFE_CUSTOM_DEVICE_VERSION 2
// Table of callbacks a client fills in to register a custom device with
// TFE_RegisterCustomDevice. Every callback receives the opaque
// `device_info` pointer that was supplied at registration time.
typedef struct TFE_CustomDevice {
// Defaults to TFE_CUSTOM_DEVICE_VERSION via the member initializer.
int version = TFE_CUSTOM_DEVICE_VERSION;
// Method to copy a tensor to the custom device. Defaults to nullptr.
TFE_TensorHandle* (*copy_tensor_to_device)(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status,
void* device_info) = nullptr;
// Method to copy a tensor from the custom device to a target device
// named by `target_device_name`.
TFE_TensorHandle* (*copy_tensor_from_device)(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info);
// Method to execute an operation placed on this device; outputs are
// written to `outputs` and their count to `num_outputs`.
void (*execute)(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
// Method to delete a device; also invoked on failed registration (the
// caller does not retain ownership of device_info — see the
// TFE_RegisterCustomDevice comment).
void (*delete_device)(void* device_info);
} TFE_CustomDevice;
// Registers a custom device for use with eager execution.
//
// Eager operations may be placed on this device, e.g. `with
// tf.device("CUSTOM"):` from Python if `device_name` for this call is
// "/job:localhost/replica:0/task:0/device:CUSTOM:0".
//
// The custom device defines copy operations for moving TensorHandles on and
// off, and an execution operation for named operations. Often execution will
// simply wrap op execution on one or more physical devices.
//
// device_info is an opaque caller-defined type stored with the custom device
// which is passed to the functions referenced in the TFE_CustomDevice struct
// `device` (execute, delete_device, etc.). It can for example contain the
// names of wrapped devices.
//
// There are currently no graph semantics implemented for registered custom
// devices, so executing tf.functions which contain operations placed on custom
// devices will fail.
//
// `device_name` must not name an existing physical or custom device. It must
// follow the format:
//
// /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num>
//
// If the device is successfully registered, `status` is set to TF_OK. Otherwise
// the device is not usable. In case of a bad status, `device.delete_device` is
// still called on `device_info` (i.e. the caller does not retain ownership).
//
// This API is highly experimental, and in particular is expected to change when
// it starts supporting operations with attributes and when tf.function support
// is added.
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
const char* function_name,
TF_Buffer* buf,
TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -21,12 +21,11 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/cc/profiler/profiler.h"
#include "tensorflow/core/lib/monitoring/collection_registry.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/trace_events.pb.h"
using tensorflow::string;
@ -39,88 +38,6 @@ static bool HasSubstr(absl::string_view base, absl::string_view substr) {
return ok;
}
// Runs a MatMul through the eager C API while a TFE_Profiler is active, then
// serializes the profile and verifies that the trace contains the expected
// device and op entries. `async` toggles asynchronous eager execution.
void ExecuteWithProfiling(bool async) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status);
  // The profiler is created before the op runs so the execution is traced.
  TFE_Profiler* profiler = TFE_NewProfiler();
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* m = TestMatrixTensorHandle();
  TFE_Op* matmul = MatMulOp(ctx, m, m);
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;

  // Run op on GPU if it is present.
  string gpu_device_name;
  if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
    TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status);
    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  }

  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteOp(matmul);
  TFE_DeleteTensorHandle(m);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);

  TF_Buffer* profiler_result = TF_NewBuffer();
  if (async) {
    // In async mode, drain all pending nodes first so the MatMul is
    // guaranteed to have executed before the profile is serialized.
    TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
    TFE_ExecutorWaitForAllPendingNodes(executor, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteExecutor(executor);
  }
  TFE_ProfilerSerializeToString(profiler, profiler_result, status);
  TFE_DeleteProfiler(profiler);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // The serialized buffer must parse back into a trace_events Trace proto.
  profiler::Trace profile_proto;
  EXPECT_TRUE(profile_proto.ParseFromString(
      {reinterpret_cast<const char*>(profiler_result->data),
       profiler_result->length}));
  string profile_proto_str = profile_proto.DebugString();
#ifndef TENSORFLOW_USE_ROCM
  // TODO(rocm): enable once GPU profiling is supported in ROCm mode
  if (!gpu_device_name.empty()) {
    EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0"));
  }
#endif
  // "/host:CPU" is collected by TraceMe
  EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU"));
  EXPECT_TRUE(HasSubstr(profile_proto_str, "MatMul"));
  TF_DeleteBuffer(profiler_result);

  // Resolve the MatMul result and check its values.
  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteContext(ctx);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  float product[4] = {0};
  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  EXPECT_EQ(7, product[0]);
  EXPECT_EQ(10, product[1]);
  EXPECT_EQ(15, product[2]);
  EXPECT_EQ(22, product[3]);
  TF_DeleteStatus(status);
}
// Profiled-execution test in synchronous and asynchronous modes.
TEST(CAPI, ExecuteWithTracing) { ExecuteWithProfiling(false); }
TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithProfiling(true); }
// Creating a second profiler while the first is still alive is expected to
// yield a profiler that reports not-OK.
TEST(CAPI, MultipleProfilerSession) {
  TFE_Profiler* profiler1 = TFE_NewProfiler();
  EXPECT_TRUE(TFE_ProfilerIsOk(profiler1));

  TFE_Profiler* profiler2 = TFE_NewProfiler();
  EXPECT_FALSE(TFE_ProfilerIsOk(profiler2));

  // Both handles still need to be deleted, even the failed one.
  TFE_DeleteProfiler(profiler1);
  TFE_DeleteProfiler(profiler2);
}
TEST(CAPI, MonitoringCounter0) {
TF_Status* status = TF_NewStatus();
auto* counter =
@ -495,5 +412,55 @@ void Executor_MatMul_CPU(bool async) {
// Executor-based MatMul test in synchronous and asynchronous modes.
TEST(CAPI, Executor_MatMul_CPU) { Executor_MatMul_CPU(false); }
TEST(CAPI, Executor_MatMul_CPUAsync) { Executor_MatMul_CPU(true); }
void Deleter(void* data, size_t unused, void* tensor_handle) {
TFE_DeleteTensorHandle(static_cast<TFE_TensorHandle*>(tensor_handle));
}
// For every available device: copies a host matrix there, wraps the copy's
// raw device buffer in a new handle via TFE_NewTensorHandleFromDeviceMemory,
// copies that aliased handle back to the host, and verifies the round-tripped
// bytes match the original data.
TEST(CAPI, TensorHandleOnDeviceMemory) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* m = TestMatrixTensorHandle();
  TF_Tensor* m_data = TFE_TensorHandleResolve(m, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  float* m_float = static_cast<float*>(TF_TensorData(m_data));
  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  int num_devices = TF_DeviceListCount(devices);
  for (int d = 0; d < num_devices; ++d) {
    const char* name = TF_DeviceListName(devices, d, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_TensorHandle* copy = TFE_TensorHandleCopyToDevice(m, ctx, name, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    // Grab the copy's raw device buffer and size so it can be re-wrapped.
    void* data = TFE_TensorHandleDevicePointer(copy, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    size_t size = TFE_TensorHandleDeviceMemorySize(copy, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    int64_t dims[] = {2, 2};
    // `copy_aliased` shares copy's buffer; `Deleter` + the `copy` argument
    // transfer ownership of `copy` to the new handle.
    TFE_TensorHandle* copy_aliased = TFE_NewTensorHandleFromDeviceMemory(
        ctx, name, TF_FLOAT, dims, 2, data, size, &Deleter, copy, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_TensorHandle* on_host =
        TFE_TensorHandleCopyToDevice(copy_aliased, ctx, "CPU:0", status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TF_Tensor* resolved = TFE_TensorHandleResolve(on_host, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    const float* resolved_data =
        static_cast<const float*>(TF_TensorData(resolved));
    // The data that came back must be byte-identical to the original matrix.
    EXPECT_EQ(0, memcmp(m_float, resolved_data, 4 * sizeof(float)));
    TF_DeleteTensor(resolved);
    TFE_DeleteTensorHandle(copy_aliased);  // Note that this will delete copy.
    TFE_DeleteTensorHandle(on_host);
  }
  TF_DeleteDeviceList(devices);
  TF_DeleteTensor(m_data);
  TFE_DeleteTensorHandle(m);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}
} // namespace
} // namespace tensorflow

View File

@ -1,58 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/platform/host_info.h"
// Returns a TFE_Op for `op_or_function_name`, either by resetting
// `op_to_reset` in place (when non-null) or by allocating a fresh TFE_Op.
// On failure, `status` is set and nullptr is returned.
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
                     TF_Status* status, TFE_Op* op_to_reset) {
  const char* name = op_or_function_name;  // Shorthand
  const tensorflow::AttrTypeMap* types;
  bool is_function = false;
  status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
  if (!status->status.ok()) {
    return nullptr;
  }

  // Reuses `op_to_reset` when one was supplied; otherwise heap-allocates.
  auto create_or_reset = [&op_to_reset, &ctx, &name, &types](
                             bool is_function,
                             TFE_OpInferenceContext* inference_ctx) -> TFE_Op* {
    if (op_to_reset) {
      op_to_reset->Reset(ctx, name, is_function, types, inference_ctx);
      return op_to_reset;
    } else {
      return new TFE_Op(ctx, name, is_function, types, inference_ctx);
    }
  };

  if (!is_function) {
    // Primitive ops get an inference context built from their OpDef so that
    // attributes can be inferred as inputs are added.
    const tensorflow::OpDef* op_def;
    status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
    if (!status->status.ok()) {
      return nullptr;
    }
    return create_or_reset(false, new TFE_OpInferenceContext(op_def));
  }
  if (!ctx->context->FindFunctionByName(name)) {
    status->status = tensorflow::errors::NotFound(
        "'", name,
        "' is neither a type of a primitive operation nor a name "
        "of a function registered in binary running on ",
        tensorflow::port::Hostname(),
        ". Make sure the operation or function is "
        "registered in the binary running in this process.");
    return nullptr;
  }
  // Functions carry no OpDef, hence no inference context.
  return create_or_reset(true, nullptr);
}

View File

@ -27,27 +27,25 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
#include "tensorflow/core/public/version.h"
struct TFE_ContextOptions {
@ -58,51 +56,29 @@ struct TFE_ContextOptions {
TFE_DEVICE_PLACEMENT_SILENT};
TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE};
// If true, lazily copy the remote inputs of a function to the target devices.
bool lazy_remote_inputs_copy = false;
bool lazy_remote_inputs_copy = true;
// If true, use TFRT backend
bool use_tfrt = false;
};
// Wraps a pointer to a context implementation.
//
// WARNING: Since the underlying object could be ref-counted a user of this
// interface cannot destruct the underlying context object. Instead, call
// TFE_DeleteContext who calls Release() on the context pointer and deletes
// the TFE_Context structure.
struct TFE_Context {
TFE_Context(const tensorflow::SessionOptions& opts,
TFE_ContextDevicePlacementPolicy default_device_placement_policy,
TFE_ContextMirroringPolicy default_mirroring_policy, bool async,
const bool lazy_remote_inputs_copy,
const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
tensorflow::Rendezvous* rendezvous,
const tensorflow::CustomKernelCreator* custom_kernel_creator)
: context(new tensorflow::EagerContext(
opts,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
default_device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(
default_mirroring_policy),
async, lazy_remote_inputs_copy, device_mgr, device_mgr_owned,
rendezvous, custom_kernel_creator)) {}
~TFE_Context() {
// TODO(iga): Add a separate API method to shutdown TFE_Context so that we
// don't send RPCs and block in destructor.
context->WaitForAndCloseRemoteContexts();
// context->RefCountIsOne() should be true here.
// TODO(iga): Remove EagerContext refcounting.
context->Unref();
}
tensorflow::EagerContext* context;
tensorflow::AbstractContextInterface* context;
};
// Wraps a pointer to a tensor handle implementation.
//
// WARNING: Since the underlying object could be ref-counted a user of this
// interface cannot destruct the underlying handle object. Instead, call
// TFE_DeleteTensorHandle who calls Release() on the handle pointer and deletes
// the TFE_TensorHandle structure.
struct TFE_TensorHandle {
explicit TFE_TensorHandle(tensorflow::TensorHandle* h) : handle(h) {}
static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t,
TF_Status* s) {
tensorflow::TensorHandle* handle;
s->status = tensorflow::TensorHandle::CreateLocalHandle(t, &handle);
if (!s->status.ok()) {
return nullptr;
}
return new TFE_TensorHandle(handle);
}
tensorflow::TensorHandle* handle;
tensorflow::AbstractTensorHandleInterface* handle;
};
struct TFE_TensorDebugInfo {
@ -113,45 +89,14 @@ struct TFE_TensorDebugInfo {
std::vector<tensorflow::int64> dev_dims;
};
// Bookkeeping used to infer op attributes while inputs are being added:
// tracks the op's definition, the next input argument slot, and which
// attributes have been inferred so far.
struct TFE_OpInferenceContext {
  explicit TFE_OpInferenceContext(const tensorflow::OpDef* op_def)
      : op_def(op_def) {}

  const tensorflow::OpDef* op_def;  // op definition from protobuf
  int input_arg_idx = 0;  // arg definition index for the next input to be added
  tensorflow::gtl::FlatSet<std::string> attrs;  // attributes inferred so far
};
// Wraps a pointer to an operation implementation.
//
// WARNING: Since the underlying object could be ref-counted a user of this
// interface cannot destruct the underlying operation object. Instead, call
// TFE_DeleteOp who calls Release() on the operation pointer and deletes
// the TFE_Op structure.
struct TFE_Op {
TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
TFE_OpInferenceContext* inference_ctx)
: operation(ctx->context, op, is_function, t),
inference_ctx(inference_ctx) {}
void Clear() {
operation.Clear();
inference_ctx.reset();
}
void Reset(TFE_Context* ctx, const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
TFE_OpInferenceContext* infer_ctx) {
operation.Reset(ctx->context, op, is_function, t, nullptr);
inference_ctx.reset(infer_ctx);
}
tensorflow::EagerOperation operation;
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
};
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status, TFE_Op* op_to_reset = nullptr);
struct TFE_Profiler {
explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }
std::unique_ptr<tensorflow::ProfilerSession> profiler;
tensorflow::AbstractOperationInterface* operation;
};
struct TFE_MonitoringCounterCell {
@ -298,4 +243,17 @@ struct TFE_Executor {
tensorflow::EagerExecutor* unowned_executor;
};
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in
// ways that sometimes do not require serialization.
struct TFE_OpAttrs {
  explicit TFE_OpAttrs() = default;
  explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value,
                       const char* op_name)
      : name(op_name), attributes(value) {}

  // Op name; nullptr when default-constructed.
  const char* name = nullptr;
  // Borrowed attribute set; nullptr when default-constructed.
  const tensorflow::AttrBuilder* attributes = nullptr;
};
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_

View File

@ -17,7 +17,9 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
@ -127,7 +129,7 @@ void TestRemoteExecute(bool async) {
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
void TestRemoteExecuteSilentCopies(bool async) {
void TestRemoteExecuteSilentCopies(bool async, bool remote) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
@ -166,10 +168,14 @@ void TestRemoteExecuteSilentCopies(bool async) {
auto* h1_task2 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandleEnableImplicitMirroring(h1_task2, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Handles are on task0 (local), and task2, but op is on task1.
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
TFE_OpSetDevice(matmul, task1_name, status);
if (remote) {
TFE_OpSetDevice(matmul, task1_name, status);
}
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
@ -177,6 +183,15 @@ void TestRemoteExecuteSilentCopies(bool async) {
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TODO(gjn): Add support for waiting on async local mirrors
if (!async) {
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
tensorflow::EagerOperation* op =
tensorflow::OperationFromInterface(matmul->operation);
// The input handles should never change since they have been mirrored.
ASSERT_EQ(op->Inputs()[1], remote_arg);
}
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@ -213,9 +228,17 @@ void TestRemoteExecuteSilentCopies(bool async) {
worker_server2.release();
}
TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); }
// Silent-copy matrix: {sync, async} x {op placed on remote task, op placed
// locally}.
TEST(CAPI, RemoteExecuteSilentCopies) {
  TestRemoteExecuteSilentCopies(false, true);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
  TestRemoteExecuteSilentCopies(true, true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
  TestRemoteExecuteSilentCopies(false, false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
  TestRemoteExecuteSilentCopies(true, false);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {

View File

@ -17,15 +17,20 @@ limitations under the License.
#include <string.h>
#include <string>
#include "absl/strings/match.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
@ -363,34 +368,58 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) {
TensorHandleCopyBetweenTwoGPUDevices(true);
}
void TensorHandleSilentCopy(bool async) {
void TensorHandleSilentCopy(bool async,
TFE_ContextDevicePlacementPolicy global_policy,
TFE_ContextDevicePlacementPolicy thread_policy,
bool cpu_op) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, global_policy);
TFE_Context* ctx = TFE_NewContext(opts, status.get());
if (thread_policy != global_policy) {
TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx, thread_policy);
}
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
// Disable the test if no GPU is present.
string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
if (cpu_op) {
string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status.get());
} else {
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
}
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
// Validate if the input was replaced with a different TensorHandle
auto arg0 = tensorflow::TensorHandleFromInterface(hcpu->handle);
auto arg1 = tensorflow::TensorHandleFromInterface(hgpu->handle);
tensorflow::EagerOperation* op =
tensorflow::OperationFromInterface(matmul->operation);
// The input handles should never change since they have been mirrored.
EXPECT_EQ(op->Inputs()[0], arg0);
EXPECT_EQ(op->Inputs()[1], arg1);
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(hgpu);
@ -404,57 +433,21 @@ void TensorHandleSilentCopy(bool async) {
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
}
TEST(CAPI, TensorHandleSilentCopy) { TensorHandleSilentCopy(false); }
TEST(CAPI, TensorHandleSilentCopyAsync) { TensorHandleSilentCopy(true); }
void TensorHandleSilentCopyLocal(bool async) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx,
TFE_DEVICE_PLACEMENT_SILENT);
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Disable the test if no GPU is present.
string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(hgpu);
}
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
// Synchronous silent copy with the SILENT placement policy set globally.
TEST(CAPI, TensorHandleSilentCopy) {
  TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
                         TFE_DEVICE_PLACEMENT_SILENT, false);
}
TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); }
TEST(CAPI, TensorHandleSilentCopyLocalAsync) {
TensorHandleSilentCopyLocal(true);
// Async silent copy with the SILENT placement policy set globally.
TEST(CAPI, TensorHandleSilentCopyAsync) {
  TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_SILENT,
                         TFE_DEVICE_PLACEMENT_SILENT, false);
}
// EXPLICIT global policy overridden by a SILENT thread-local policy.
TEST(CAPI, TensorHandleSilentCopyLocalPolicy) {
  TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_EXPLICIT,
                         TFE_DEVICE_PLACEMENT_SILENT, false);
}
TEST(CAPI, TensorHandleSilentCopyLocalPolicyAsync) {
  TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_EXPLICIT,
                         TFE_DEVICE_PLACEMENT_SILENT, false);
}
void SetAndGetOpDevices(bool async) {
@ -496,40 +489,35 @@ TEST(CAPI, TensorHandleNullptr) {
TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(t, nullptr);
ASSERT_EQ("The passed in handle is a nullptr",
string(TF_Message(status.get())));
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
TF_SetStatus(status.get(), TF_OK, "");
const char* device_name = TFE_TensorHandleDeviceName(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(device_name, nullptr);
ASSERT_EQ("The passed in handle is a nullptr",
string(TF_Message(status.get())));
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
TF_SetStatus(status.get(), TF_OK, "");
device_name = TFE_TensorHandleBackingDeviceName(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(device_name, nullptr);
ASSERT_EQ("The passed in handle is a nullptr",
string(TF_Message(status.get())));
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
TF_SetStatus(status.get(), TF_OK, "");
int num_dims = TFE_TensorHandleNumDims(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(num_dims, -1);
ASSERT_EQ("The passed in handle is a nullptr",
string(TF_Message(status.get())));
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
TF_SetStatus(status.get(), TF_OK, "");
int dim = TFE_TensorHandleDim(h, 0, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(dim, -1);
ASSERT_EQ("The passed in handle is a nullptr",
string(TF_Message(status.get())));
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
}
TEST(CAPI, TensorHandleDevices) {
@ -590,6 +578,91 @@ TEST(CAPI, TensorHandleDevices) {
TFE_DeleteContext(ctx);
}
// Adds a 100x100 matrix to itself and checks the result, additionally
// verifying the input-buffer forwarding behaviour: when `forward_input` is
// true (input handle released before execution) or when running `async`, the
// output is expected to reuse the input's underlying buffer; otherwise a new
// buffer is expected.
void ExecuteAdd(bool async, bool forward_input) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* n = TestMatrixTensorHandle100x100();
  // If a GPU exists, copy the handle to GPU so that we can exercise
  // unprotecting a mirror.
  std::string gpu_device_name;
  if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
    TFE_TensorHandle* n_gpu =
        TFE_TensorHandleCopyToDevice(n, ctx, gpu_device_name.c_str(), status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_TensorHandleEnableImplicitMirroring(n_gpu, status);
    TFE_DeleteTensorHandle(n);
    n = n_gpu;
  }

  TFE_TensorHandle* m = TestMatrixTensorHandle100x100();

  // Store pointer to raw buffer for validation of forwarding behaviour.
  TF_Tensor* orig = TFE_TensorHandleResolve(n, status);
  void* orig_ptr = TF_TensorData(orig);
  TF_DeleteTensor(orig);

  TFE_Op* add_op = AddOp(ctx, n, m);
  std::string cpu_device_name;
  ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
  TFE_OpSetDevice(add_op, cpu_device_name.c_str(), status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  if (forward_input) {
    // Release our reference so the runtime may forward n's buffer.
    TFE_DeleteTensorHandle(n);
  }

  int num_retvals = 1;
  if (async) {
    // Enqueue dummy ops so we backlog async execution & actually test async.
    for (int i = 0; i < 10000; ++i) {
      TFE_TensorHandle* dummy = nullptr;
      TFE_Execute(add_op, &dummy, &num_retvals, status);
      ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
      TFE_DeleteTensorHandle(dummy);
    }
  }

  TFE_TensorHandle* retval = nullptr;
  TFE_Execute(add_op, &retval, &num_retvals, status);
  EXPECT_EQ(1, num_retvals);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  if (!forward_input) {
    TFE_DeleteTensorHandle(n);
  }
  TFE_DeleteOp(add_op);

  TF_Tensor* t = TFE_TensorHandleResolve(retval, status);
  // Forwarding (or the async backlog) should make the output alias the
  // original input buffer; the plain sync path must allocate a new one.
  if (forward_input || async) {
    EXPECT_EQ(orig_ptr, TF_TensorData(t));
  } else {
    EXPECT_NE(orig_ptr, TF_TensorData(t));
  }

  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteTensorHandle(m);
  TFE_DeleteTensorHandle(retval);
  TFE_DeleteContext(ctx);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Each element of the 100x100 all-ones matrix added to itself is 2.
  float result[100 * 100] = {0};
  EXPECT_EQ(sizeof(result), TF_TensorByteSize(t));
  memcpy(&result[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  for (int i = 0; i < 100 * 100; ++i) {
    EXPECT_EQ(2.0f, result[i]);
  }
  TF_DeleteStatus(status);
}
// Add tests: {sync, async} x {no input forwarding, input forwarding}.
TEST(CAPI, ExecuteAdd) { ExecuteAdd(false, false); }
TEST(CAPI, ExecuteAddAsync) { ExecuteAdd(true, false); }
TEST(CAPI, ExecuteAddForward) { ExecuteAdd(false, true); }
TEST(CAPI, ExecuteAddForwardAsync) { ExecuteAdd(true, true); }
void Execute_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -1228,6 +1301,14 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
TFE_DeleteTensorHandle(h_shares_tensor);
}
// Collects the attribute name -> value map accumulated so far on `op` by
// unwrapping it to the underlying EagerOperation.
tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
  tensorflow::EagerOperation* eager_op =
      tensorflow::OperationFromInterface(op->operation);
  tensorflow::AttrValueMap result;
  eager_op->Attrs().FillAttrValueMap(&result);
  return result;
}
TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -1244,8 +1325,7 @@ TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
TFE_OpAddInput(minOp, axis, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
minOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(minOp);
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
EXPECT_NE(attr_found, attr_values.cend());
EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
@ -1284,8 +1364,7 @@ TEST(CAPI, TestTFE_OpInferSingleTypeInputListAttrs) {
TFE_OpAddInputList(concatOp, inputs, 2, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
EXPECT_NE(attr_found, attr_values.cend());
EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
@ -1325,8 +1404,7 @@ TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) {
TFE_OpAddInputList(assertOp, data, 3, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
assertOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(assertOp);
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
EXPECT_NE(attr_found, attr_values.cend());
EXPECT_EQ(attr_found->second.list().type(0), tensorflow::DataType::DT_BOOL);
@ -1362,15 +1440,15 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
TFE_TensorHandle* inputs[] = {input1, input2};
TFE_OpAddInput(concatOp, dim, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CHECK(concatOp->inference_ctx);
CHECK(concatOp->operation->OpDef());
TFE_OpAddInput(concatOp, inputs[0], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_FALSE(concatOp->inference_ctx) << "Inference context is still present";
EXPECT_FALSE(concatOp->operation->OpDef())
<< "Inference context is still present";
TFE_OpAddInput(concatOp, inputs[1], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
EXPECT_EQ(attr_values.find("T"), attr_values.end());
EXPECT_EQ(attr_values.find("N"), attr_values.end());
@ -1457,4 +1535,88 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
TFE_DeleteContext(ctx);
}
// TFE_OpGetAttrs captures the attributes set on one op, and TFE_OpAddAttrs
// applies them to another. The test checks that copy_op ends up with both
// the "dtype" and "shape" attributes, and that the dtype copy_op set first
// (TF_FLOAT) is the one observed afterwards.
TEST(CAPI, TestTFE_OpGetAttrs) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  // Source op: dtype=TF_INT64 and a scalar shape.
  TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
  TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
  TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
  TFE_OpAttrs attributes;
  TFE_OpGetAttrs(var_op, &attributes);

  // Destination op: sets its own dtype first, then receives var_op's attrs.
  TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
  TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
  TFE_OpAddAttrs(copy_op, &attributes);
  unsigned char is_list = 0;
  ASSERT_EQ(TF_ATTR_TYPE,
            TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(TF_ATTR_SHAPE,
            TFE_OpGetAttrType(copy_op, "shape", &is_list, status));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Inspect the underlying EagerOperation: dtype must still be DT_FLOAT.
  tensorflow::AttrValueMap attr_values;
  tensorflow::EagerOperation* op =
      tensorflow::OperationFromInterface(copy_op->operation);
  op->Attrs().FillAttrValueMap(&attr_values);
  EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());

  TF_DeleteStatus(status);
  TFE_DeleteOp(var_op);
  TFE_DeleteOp(copy_op);
  TFE_DeleteContext(ctx);
}
// Round-trips op attributes through serialization: TFE_OpAttrsSerialize
// produces a NameAttrList proto carrying the op name and attrs, and a single
// serialized AttrValue can be re-applied to a new op via
// TFE_OpSetAttrValueProto.
TEST(CAPI, TestTFE_OpAttrsSerialize) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);
  TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
  TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
  TFE_OpAttrs attributes;
  TFE_OpGetAttrs(var_op, &attributes);

  // Serialize and check the resulting NameAttrList proto.
  TF_Buffer* serialized_attr_values = TF_NewBuffer();
  TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  tensorflow::NameAttrList name_and_attrs;
  ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,
                                            serialized_attr_values->length));
  ASSERT_EQ("VarHandleOp", name_and_attrs.name());
  ASSERT_EQ(tensorflow::DT_INT64,
            name_and_attrs.attr().find("dtype")->second.type());
  TF_DeleteBuffer(serialized_attr_values);

  // Re-apply the serialized "dtype" AttrValue to a fresh op and confirm it
  // lands in the op's attr map.
  TFE_Op* var_op_2 = TFE_NewOp(ctx, "VarHandleOp", status);
  string serialized_dtype;
  ASSERT_TRUE(name_and_attrs.attr().find("dtype")->second.SerializeToString(
      &serialized_dtype));
  TFE_OpSetAttrValueProto(
      var_op_2, "dtype",
      reinterpret_cast<const void*>(serialized_dtype.c_str()),
      serialized_dtype.length(), status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  tensorflow::AttrValueMap attr_values;
  tensorflow::EagerOperation* op =
      tensorflow::OperationFromInterface(var_op_2->operation);
  op->Attrs().FillAttrValueMap(&attr_values);
  EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type());

  TF_DeleteStatus(status);
  TFE_DeleteOp(var_op);
  TFE_DeleteOp(var_op_2);
  TFE_DeleteContext(ctx);
}
} // namespace

View File

@ -131,6 +131,21 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2() {
return th;
}
// Builds an eager "AddV2" op with inputs `a` and `b`. The caller owns the
// returned op. CHECK-fails if op creation or adding an input fails; the "T"
// attribute is taken from `a`'s dtype.
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
  TF_Status* s = TF_NewStatus();
  TFE_Op* add = TFE_NewOp(ctx, "AddV2", s);
  CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  // Small helper so both inputs get identical add-and-check treatment.
  auto add_input = [&](TFE_TensorHandle* input) {
    TFE_OpAddInput(add, input, s);
    CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  };
  add_input(a);
  add_input(b);
  TF_DeleteStatus(s);
  TFE_OpSetAttrType(add, "T", TFE_TensorHandleDataType(a));
  return add;
}
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
TF_Status* status = TF_NewStatus();

View File

@ -42,6 +42,9 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();
// Return a tensor handle containing a 3x2 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle3X2();
// Return an add op computing the elementwise sum of `a` and `b`.
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
// Return a matmul op multiplying `a` by `b`.
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);

View File

@ -0,0 +1,261 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/strcat.h"
using tensorflow::string;
// =============================================================================
// Unified Execution APIs for Eager and tracing backends.
// =============================================================================
// Signature shared by the per-backend execution callbacks (eager vs graph).
typedef void (*ExecuteOperation)(TF_AbstractOp* op, int num_inputs,
                                 TF_AbstractTensor* const* inputs,
                                 TF_OutputList* o, TF_ExecutionContext* ctx,
                                 TF_Status* s);
// Holds either an eager or a graph context, plus the callback that knows how
// to execute operations against that backend.
struct TF_ExecutionContext {
  explicit TF_ExecutionContext() {}
  absl::variant<TFE_Context*, TF_GraphContext*> ctx;
  ExecuteOperation execution_callback;
};
// Union over the two tensor representations: an eager handle or a graph
// output wrapper.
struct TF_AbstractTensor {
  absl::variant<TFE_TensorHandle*, TF_GraphTensor*> t;
};
// Metadata needed to build an operation: its registered op type and the node
// name to use in graph mode.
struct TF_AbstractOp {
  string op_type;
  string op_name;
};
// Constructors / destructors for the opaque API objects. Callers own the
// returned pointers and must release them with the matching TF_Delete* call.
TF_ExecutionContext* TF_NewExecutionContext() {
  return new TF_ExecutionContext();
}
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete c; }
TF_AbstractOp* TF_NewAbstractOp() {
  TF_AbstractOp* op = new TF_AbstractOp;
  return op;
}
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; }
TF_AbstractTensor* TF_NewAbstractTensor() {
  TF_AbstractTensor* t = new TF_AbstractTensor;
  return t;
}
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; }
// Wraps a TF_Graph for use as a tracing backend. Does not own the graph.
struct TF_GraphContext {
  TF_Graph* graph;
  // TODO(srbs): Handle captures.
};
TF_GraphContext* TF_NewGraphContext(TF_Graph* g) {
  auto ctx = new TF_GraphContext;
  ctx->graph = g;
  return ctx;
}
void TF_DeleteGraphContext(TF_GraphContext* ctx) { delete ctx; }
// A graph tensor: a TF_Output plus the graph context it belongs to, so
// cross-graph use can be detected at execution time.
struct TF_GraphTensor {
  TF_Output output;
  TF_GraphContext* ctx;
};
// Note: `s` is accepted for API symmetry but is never set by these functions.
TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* ctx, TF_Output output,
                                  TF_Status* s) {
  TF_GraphTensor* t = new TF_GraphTensor;
  t->output = output;
  t->ctx = ctx;
  return t;
}
TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s) {
  return t->output;
}
void TF_DeleteGraphTensor(TF_GraphTensor* t) { delete t; }
// Stores an eager tensor handle in `at`. `t` must outlive `at`; `s` is unused.
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
                                     TF_Status* s) {
  at->t = t;
}
// Returns the eager handle held by `at`, or sets INVALID_ARGUMENT on `s` and
// returns nullptr if `at` currently holds a graph tensor instead.
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
                                                  TF_Status* s) {
  if (!absl::holds_alternative<TFE_TensorHandle*>(at->t)) {
    // The tensor's address is appended to the message to help identify the
    // offending object in logs.
    string msg = absl::StrCat("Not an eager tensor handle.",
                              reinterpret_cast<uintptr_t>(at));
    TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
    return nullptr;
  }
  return absl::get<TFE_TensorHandle*>(at->t);
}
// Stores a graph tensor in `at`. `t` must outlive `at`; `s` is unused.
void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
                                     TF_Status* s) {
  at->t = t;
}
// Returns the graph tensor held by `at`, or sets INVALID_ARGUMENT on `s` and
// returns nullptr if `at` currently holds an eager tensor handle instead.
TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
                                                TF_Status* s) {
  if (!absl::holds_alternative<TF_GraphTensor*>(at->t)) {
    // Fixed: the message read "Not an graph ..."; also dropped the pointless
    // single-argument StrCat and pass the literal directly.
    TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not a graph tensor handle.");
    return nullptr;
  }
  return absl::get<TF_GraphTensor*>(at->t);
}
// True iff `t` currently holds an eager tensor handle.
bool IsEagerTensor(const TF_AbstractTensor* const t) {
  return absl::holds_alternative<TFE_TensorHandle*>(t->t);
}
// Output container. `expected_num_outputs` of -1 means "unknown"; eager
// execution rejects that (it must size its result buffer up front), while
// graph mode can discover the count from the finished node.
struct TF_OutputList {
  std::vector<TF_AbstractTensor*> outputs;
  int expected_num_outputs = -1;
};
TF_OutputList* TF_NewOutputList() { return new TF_OutputList; }
void TF_DeleteOutputList(TF_OutputList* o) { delete o; }
// `s` is unused; the call always succeeds.
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
                                TF_Status* s) {
  o->expected_num_outputs = num_outputs;
}
int TF_OutputListNumOutputs(TF_OutputList* o) { return o->outputs.size(); }
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
  return o->outputs[i];
}
// Eager-mode execution callback: runs `op` via TFE_Execute. All inputs must be
// eager tensors, and the caller must have set the expected output count on `o`
// (TFE_Execute needs the retval buffer sized up front). On success,
// `o->outputs` is replaced with freshly allocated abstract tensors wrapping
// the results.
//
// Fix: the original leaked `tfe_op` on every early-error return; the op is now
// deleted on all exit paths.
void ExecuteOperationEager(TF_AbstractOp* op, int num_inputs,
                           TF_AbstractTensor* const* inputs, TF_OutputList* o,
                           TF_ExecutionContext* ctx, TF_Status* s) {
  auto* tfe_op =
      TFE_NewOp(absl::get<TFE_Context*>(ctx->ctx), op->op_type.c_str(), s);
  if (TF_GetCode(s) != TF_OK) return;
  for (int i = 0; i < num_inputs; ++i) {
    if (!IsEagerTensor(inputs[i])) {
      TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
      TFE_DeleteOp(tfe_op);
      return;
    }
    TFE_OpAddInput(tfe_op, absl::get<TFE_TensorHandle*>(inputs[i]->t), s);
    if (TF_GetCode(s) != TF_OK) {
      TFE_DeleteOp(tfe_op);
      return;
    }
  }
  if (o->expected_num_outputs == -1) {
    string msg =
        "The number of outputs must be provided in eager mode. Use "
        "TF_OutputListSetNumOutputs.";
    TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
    TFE_DeleteOp(tfe_op);
    return;
  }
  tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
  int num_retvals = o->expected_num_outputs;
  retvals.resize(num_retvals);
  TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
  TFE_DeleteOp(tfe_op);
  if (TF_GetCode(s) != TF_OK) {
    return;
  }
  // Hand the result handles to newly allocated abstract tensors.
  o->outputs.clear();
  o->outputs.reserve(num_retvals);
  for (int i = 0; i < num_retvals; ++i) {
    auto* t = TF_NewAbstractTensor();
    t->t = retvals[i];
    o->outputs.push_back(t);
  }
}
// Returns the graph context `t` belongs to. Precondition: `t` holds a graph
// tensor (absl::get terminates the process otherwise).
TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) {
  return absl::get<TF_GraphTensor*>(t->t)->ctx;
}
// Graph-mode execution callback: adds a node for `op` to the underlying
// TF_Graph. All inputs must be graph tensors from this same graph context;
// capturing eager or cross-graph tensors is not supported yet. On success,
// `o->outputs` is replaced with abstract tensors wrapping the node's outputs.
//
// NOTE(review): `tf_opdesc` cannot be freed on the early-error returns below —
// the C API only releases a TF_OperationDescription via TF_FinishOperation —
// so those paths leak the partially built description.
void ExecuteOperationGraph(TF_AbstractOp* op, int num_inputs,
                           TF_AbstractTensor* const* inputs, TF_OutputList* o,
                           TF_ExecutionContext* ctx, TF_Status* s) {
  TF_GraphContext* graph_ctx = absl::get<TF_GraphContext*>(ctx->ctx);
  TF_Graph* g = graph_ctx->graph;
  auto* tf_opdesc =
      TF_NewOperation(g, op->op_type.c_str(), op->op_name.c_str());
  for (int i = 0; i < num_inputs; ++i) {
    auto* input = inputs[i];
    if (IsEagerTensor(input)) {
      TF_SetStatus(s, TF_INVALID_ARGUMENT,
                   "Capturing eager tensors is not supported yet.");
      return;
    } else {
      if (GetGraphContext(input) != graph_ctx) {
        TF_SetStatus(
            s, TF_INVALID_ARGUMENT,
            "Capturing tensors from other graphs is not supported yet.");
        return;
      }
      TF_AddInput(tf_opdesc, absl::get<TF_GraphTensor*>(input->t)->output);
    }
  }
  auto* operation = TF_FinishOperation(tf_opdesc, s);
  if (TF_GetCode(s) != TF_OK) return;
  int num_outputs = TF_OperationNumOutputs(operation);
  o->outputs.clear();
  o->outputs.reserve(num_outputs);
  for (int i = 0; i < num_outputs; ++i) {
    // Fix: create the graph tensor and check `s` BEFORE allocating the
    // abstract tensor; the original allocated first and leaked it whenever
    // the status check fired.
    TF_GraphTensor* output_t = TF_NewGraphTensor(graph_ctx, {operation, i}, s);
    if (TF_GetCode(s) != TF_OK) {
      return;
    }
    auto* t = TF_NewAbstractTensor();
    t->t = output_t;
    o->outputs.push_back(t);
  }
}
// Points `context` at an eager backend. `eager_context` must outlive
// `context`; `s` is unused.
void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
                                        TFE_Context* eager_context,
                                        TF_Status* s) {
  context->ctx = eager_context;
  context->execution_callback = &ExecuteOperationEager;
}
// Points `context` at a graph (tracing) backend. `graph_context` must outlive
// `context`; `s` is unused.
void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
                                        TF_GraphContext* graph_context,
                                        TF_Status* s) {
  context->ctx = graph_context;
  context->execution_callback = &ExecuteOperationGraph;
}
// `op_type` / `op_name` are copied into the op's own strings; `s` is unused.
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
                            TF_Status* s) {
  op->op_type = op_type;
}
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
                            TF_Status* s) {
  op->op_name = op_name;
}
// Dispatches to whichever backend callback is installed on `ctx`.
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
                         TF_AbstractTensor* const* inputs, TF_OutputList* o,
                         TF_ExecutionContext* ctx, TF_Status* s) {
  ctx->execution_callback(op, num_inputs, inputs, o, ctx, s);
}

View File

@ -0,0 +1,119 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif
// =============================================================================
// Unified Execution APIs for Eager and tracing backends.
// =============================================================================
// -----------------------------------------------------------------------------
// Core APIs
// -----------------------------------------------------------------------------
// A TF_ExecutionContext stores knowledge about how to execute an operation.
// E.g. it could know whether we're in eager mode or in graph mode, keeps track
// of gradient tapes, etc.
typedef struct TF_ExecutionContext TF_ExecutionContext;
// A TF_AbstractTensor is an input to an operation. E.g. it could be a union
// type of eager and graph tensors.
typedef struct TF_AbstractTensor TF_AbstractTensor;
// A TF_AbstractOp is the metadata we need to execute an operation. E.g. this
// could contain the op type and other attributes.
typedef struct TF_AbstractOp TF_AbstractOp;
TF_ExecutionContext* TF_NewExecutionContext();
void TF_DeleteExecutionContext(TF_ExecutionContext*);
TF_AbstractOp* TF_NewAbstractOp();
void TF_DeleteAbstractOp(TF_AbstractOp*);
TF_AbstractTensor* TF_NewAbstractTensor();
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
// -----------------------------------------------------------------------------
// APIs for Eager and graph modes
// -----------------------------------------------------------------------------
// Keeps track of the current graph and other state e.g. captures etc.
typedef struct TF_GraphContext TF_GraphContext;
TF_GraphContext* TF_NewGraphContext(TF_Graph*);
void TF_DeleteGraphContext(TF_GraphContext*);
// `eager_context` must outlive `context`.
void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
TFE_Context* eager_context, TF_Status*);
// `graph_context` must outlive `context`.
void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
TF_GraphContext* graph_context,
TF_Status*);
// TODO(srbs): Add APIs for specifying attrs etc.
// `op_type` must outlive `op`.
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
TF_Status* s);
// `op_name` must outlive `op`.
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
TF_Status* s);
// Wrapper for TF_Output but contains a pointer to TF_GraphContext as well.
typedef struct TF_GraphTensor TF_GraphTensor;
TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* c, TF_Output t,
TF_Status* s);
TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s);
void TF_DeleteGraphTensor(TF_GraphTensor* t);
// `t` must outlive `at`.
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
TF_Status* s);
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s);
// `t` must outlive `at`.
void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
TF_Status* s);
TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
TF_Status* s);
// TF_OutputList just lets us not specify the number of outputs of an operation
// beforehand. This forces a memory allocation in the runtime, which is bad, but
// it allows for generic code.
typedef struct TF_OutputList TF_OutputList;
TF_OutputList* TF_NewOutputList();
void TF_DeleteOutputList(TF_OutputList* o);
void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
int TF_OutputListNumOutputs(TF_OutputList* o);
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
// capture some inputs and then add a node in the graph, and after
// execution/node creation it'll go and record things that happened in any tape
// which happens to be active.
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_

View File

@ -0,0 +1,204 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include <string.h>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/cc/profiler/profiler.h"
#include "tensorflow/core/lib/monitoring/collection_registry.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
using tensorflow::string;
namespace tensorflow {
namespace {
// End-to-end check of the unified API in eager mode: Add(2.0, 2.0) == 4.0.
TEST(UnifedCAPI, TestBasicEager) {
  TF_ExecutionContext* ctx = TF_NewExecutionContext();
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteContextOptions(opts);

  // Enter the eager context.
  TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build an abstract input tensor.
  TFE_TensorHandle* t = TestScalarTensorHandle(2.0f);
  TF_AbstractTensor* at = TF_NewAbstractTensor();
  TF_AbstractTensorSetEagerTensor(at, t, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build an abstract operation.
  auto* op = TF_NewAbstractOp();
  TF_AbstractOpSetOpType(op, "Add", status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build inputs and outputs. Eager mode needs the output count up front.
  TF_AbstractTensor* inputs[2] = {at, at};
  TF_OutputList* o = TF_NewOutputList();
  TF_OutputListSetNumOutputs(o, 1, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Execute.
  TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Clean up operation and inputs.
  TF_DeleteAbstractOp(op);
  TF_DeleteAbstractTensor(at);
  TFE_DeleteTensorHandle(t);

  // Verify the results.
  ASSERT_EQ(1, TF_OutputListNumOutputs(o));
  TF_AbstractTensor* result = TF_OutputListGet(o, 0);
  TFE_TensorHandle* result_t =
      TF_AbstractTensorGetEagerTensor(result, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TF_Tensor* result_tensor = TFE_TensorHandleResolve(result_t, status.get());
  float* result_value = static_cast<float*>(TF_TensorData(result_tensor));
  EXPECT_EQ(*result_value, 4.0);

  TF_DeleteTensor(result_tensor);
  TF_DeleteAbstractTensor(result);
  TFE_DeleteTensorHandle(result_t);
  TF_DeleteOutputList(o);
  TFE_DeleteContext(eager_ctx);
  TF_DeleteExecutionContext(ctx);
}
// End-to-end check of the unified API in graph mode: traces x + x into a
// graph, converts it to a ConcreteFunction named "double", then switches the
// same execution context to eager mode and runs the function on 2.0,
// expecting 4.0.
TEST(UnifedCAPI, TestBasicGraph) {
  TF_ExecutionContext* ctx = TF_NewExecutionContext();
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  // Enter a graph context.
  TF_Graph* g = TF_NewGraph();
  TF_GraphContext* graph_context = TF_NewGraphContext(g);
  TF_ExecutionContextSetGraphContext(ctx, graph_context, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Add a placeholder to the graph.
  auto* placeholder_op = TF_NewOperation(g, "Placeholder", "Placeholder");
  TF_SetAttrType(placeholder_op, "dtype", TF_FLOAT);
  auto* operation = TF_FinishOperation(placeholder_op, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TF_Output placeholder_t = {operation, 0};
  TF_GraphTensor* graph_t =
      TF_NewGraphTensor(graph_context, placeholder_t, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TF_AbstractTensor* t = TF_NewAbstractTensor();
  TF_AbstractTensorSetGraphTensor(t, graph_t, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build an abstract operation.
  auto* op = TF_NewAbstractOp();
  TF_AbstractOpSetOpType(op, "Add", status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TF_AbstractOpSetOpName(op, "my_add", status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build inputs and outputs. Graph mode infers the output count, so no
  // TF_OutputListSetNumOutputs call is needed here.
  TF_AbstractTensor* inputs[2] = {t, t};
  TF_OutputList* o = TF_NewOutputList();

  // Execute.
  TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Clean up operation and inputs.
  TF_DeleteAbstractOp(op);
  TF_DeleteAbstractTensor(t);
  TF_DeleteGraphTensor(graph_t);

  TF_AbstractTensor* result = TF_OutputListGet(o, 0);
  TF_GraphTensor* result_graph_tensor =
      TF_AbstractTensorGetGraphTensor(result, status.get());
  TF_DeleteAbstractTensor(result);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TF_Output result_output =
      TF_GraphTensorToOutput(result_graph_tensor, status.get());
  TF_DeleteGraphTensor(result_graph_tensor);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Wrap the traced graph as a function mapping placeholder -> result.
  string fn_name = "double";
  TF_Function* f = TF_GraphToFunction(
      g, fn_name.c_str(), 0, -1, nullptr, 1, &placeholder_t, 1, &result_output,
      nullptr, nullptr, fn_name.c_str(), status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build an eager context to run the function.
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteContextOptions(opts);

  // Build the abstract op to run the function.
  TFE_ContextAddFunction(eager_ctx, f, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TF_AbstractOp* fn_op = TF_NewAbstractOp();
  TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build an abstract input tensor.
  TFE_TensorHandle* input_eager = TestScalarTensorHandle(2.0f);
  TF_AbstractTensor* input_t = TF_NewAbstractTensor();
  TF_AbstractTensorSetEagerTensor(input_t, input_eager, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Enter the eager context (reusing the same TF_ExecutionContext).
  TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TF_OutputListSetNumOutputs(o, 1, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TF_ExecuteOperation(fn_op, 1, &input_t, o, ctx, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  ASSERT_EQ(1, TF_OutputListNumOutputs(o));
  TF_AbstractTensor* final_result = TF_OutputListGet(o, 0);
  TFE_TensorHandle* final =
      TF_AbstractTensorGetEagerTensor(final_result, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TF_Tensor* f_t = TFE_TensorHandleResolve(final, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  float* f_value = static_cast<float*>(TF_TensorData(f_t));
  ASSERT_EQ(*f_value, 4.0);

  TF_DeleteOutputList(o);
  TF_DeleteAbstractOp(fn_op);
  TF_DeleteAbstractTensor(input_t);
  TFE_DeleteTensorHandle(input_eager);
  TF_DeleteAbstractTensor(final_result);
  TFE_DeleteTensorHandle(final);
  TF_DeleteTensor(f_t);
  TF_DeleteFunction(f);
  TF_DeleteGraphContext(graph_context);
  TF_DeleteGraph(g);
  TFE_DeleteContext(eager_ctx);
  TF_DeleteExecutionContext(ctx);
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,90 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
#define TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
#include <vector>
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h"
namespace tensorflow {
// Abstract interface to a context.
//
// A context is responsible for creating key objects such as Tensors,
// TensorHandles & Operations.
class AbstractContextInterface {
 public:
  // Release any underlying resources, including the interface object.
  //
  // WARNING: The destructor of this class is marked as protected to disallow
  // clients from directly destroying this object since it may manage its own
  // lifetime through ref counting. Thus clients MUST call Release() in order
  // to destroy an instance of this class.
  virtual void Release() = 0;

  // Scalar creation functions
  virtual AbstractTensorInterface* CreateInt64Scalar(int64 value) = 0;
  virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0;
  virtual AbstractTensorInterface* CreateInt32Scalar(int32 value) = 0;
  virtual AbstractTensorInterface* CreateFloatScalar(float value) = 0;
  virtual AbstractTensorInterface* CreateDoubleScalar(double value) = 0;
  virtual AbstractTensorInterface* CreateHalfScalar(Eigen::half value) = 0;
  virtual AbstractTensorInterface* CreateStringScalar(tstring value) = 0;
  virtual AbstractTensorInterface* CreateComplex128Scalar(complex128 value) = 0;
  virtual AbstractTensorInterface* CreateBoolScalar(bool value) = 0;

  // Tensor creation functions. `dim_sizes` gives the shape of the tensor to
  // allocate, one entry per dimension.
  virtual AbstractTensorInterface* CreateInt64Tensor(
      absl::Span<const int64> dim_sizes) = 0;
  virtual AbstractTensorInterface* CreateUint64Tensor(
      absl::Span<const int64> dim_sizes) = 0;
  virtual AbstractTensorInterface* CreateInt32Tensor(
      absl::Span<const int64> dim_sizes) = 0;
  virtual AbstractTensorInterface* CreateFloatTensor(
      absl::Span<const int64> dim_sizes) = 0;
  virtual AbstractTensorInterface* CreateDoubleTensor(
      absl::Span<const int64> dim_sizes) = 0;
  virtual AbstractTensorInterface* CreateHalfTensor(
      absl::Span<const int64> dim_sizes) = 0;
  virtual AbstractTensorInterface* CreateStringTensor(
      absl::Span<const int64> dim_sizes) = 0;
  virtual AbstractTensorInterface* CreateComplex128Tensor(
      absl::Span<const int64> dim_sizes) = 0;
  virtual AbstractTensorInterface* CreateBoolTensor(
      absl::Span<const int64> dim_sizes) = 0;

  // Create a handle to wrap and manage a Tensor
  virtual AbstractTensorHandleInterface* CreateLocalHandle(
      AbstractTensorInterface* t) = 0;
  // Create an operation to perform op execution
  virtual AbstractOperationInterface* CreateOperation() = 0;

  // List attributes of available devices
  virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;

 protected:
  virtual ~AbstractContextInterface() {}
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_

View File

@ -0,0 +1,398 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A simple logging device to test custom device registration.
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
namespace {
// State for the test device: its own name, the real device it forwards work
// to, and flags the test inspects.
struct LoggingDevice {
  tensorflow::string device_name;
  tensorflow::string underlying_device;
  // Set to true whenever a TensorHandle is copied onto the device
  bool* arrived_flag;
  // Set to true whenever an operation is executed
  bool* executed_flag;
};
// Owning wrapper around the handle that lives on the underlying device.
struct LoggedTensor {
  TFE_TensorHandle* tensor;
  LoggedTensor() = delete;
  explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
  ~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};
// Deallocator passed to TFE_NewTensorHandleFromDeviceMemory; `data` is the
// heap-allocated LoggedTensor released via t.release() in
// MakeLoggedTensorHandle.
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
  delete reinterpret_cast<LoggedTensor*>(data);
}
// Wraps `t` (which owns a handle on the underlying device) into a new
// custom-device tensor handle on `logging_device_name`, transferring
// ownership of `t` to the returned handle (freed via LoggedTensorDeallocator).
// Returns nullptr with an error in `status` if shape inspection fails.
TFE_TensorHandle* MakeLoggedTensorHandle(
    TFE_Context* context, const tensorflow::string& logging_device_name,
    std::unique_ptr<LoggedTensor> t, TF_Status* status) {
  std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
  if (TF_GetCode(status) != TF_OK) return nullptr;
  // size_t index avoids the signed/unsigned comparison the original `int i`
  // loop produced against shape.size().
  for (size_t i = 0; i < shape.size(); ++i) {
    shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }
  auto dtype = TFE_TensorHandleDataType(t->tensor);
  return TFE_NewTensorHandleFromDeviceMemory(
      context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
      t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}
// copy_tensor_to_device callback: copies `tensor` onto the wrapped underlying
// device, wraps the result in a LoggedTensor handle, and sets `arrived_flag`.
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
                                      TFE_TensorHandle* tensor,
                                      TF_Status* status, void* device_info) {
  LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
  TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
      tensor, context, dev->underlying_device.c_str(), status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  auto dst = std::make_unique<LoggedTensor>(t);
  *(dev->arrived_flag) = true;
  return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
                                status);
}
// copy_tensor_from_device callback: copying off the logging device is
// deliberately unsupported and always fails with TF_INTERNAL.
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
                                              TFE_TensorHandle* tensor,
                                              const char* target_device_name,
                                              TF_Status* status,
                                              void* device_info) {
  TF_SetStatus(status, TF_INTERNAL,
               "Trying to copy a tensor out of a logging device.");
  return nullptr;
}
// Custom-device execute callback: forwards the op to the wrapped underlying
// device, unwrapping any inputs that live on the logging device, then rewraps
// every output as a LoggedTensor handle and sets `executed_flag`.
//
// Fixes: the original leaked `op` on the early-error returns inside the input
// loop, and made a pointless element-by-element copy of `op_outputs` into a
// second vector before rewrapping.
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
                          TFE_TensorHandle** inputs, const char* operation_name,
                          const TFE_OpAttrs* attributes, int* num_outputs,
                          TFE_TensorHandle** outputs, TF_Status* s,
                          void* device_info) {
  LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
  TFE_Op* op(TFE_NewOp(context, operation_name, s));
  if (TF_GetCode(s) != TF_OK) return;
  TFE_OpAddAttrs(op, attributes);
  TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
  for (int j = 0; j < num_inputs; ++j) {
    TFE_TensorHandle* input = inputs[j];
    const char* input_device = TFE_TensorHandleDeviceName(input, s);
    if (TF_GetCode(s) != TF_OK) {
      TFE_DeleteOp(op);
      return;
    }
    if (dev->device_name == input_device) {
      // Inputs on the logging device wrap the real handle; unwrap it.
      LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
          TFE_TensorHandleDevicePointer(input, s));
      if (TF_GetCode(s) != TF_OK) {
        TFE_DeleteOp(op);
        return;
      }
      TFE_OpAddInput(op, t->tensor, s);
    } else {
      TFE_OpAddInput(op, input, s);
    }
    if (TF_GetCode(s) != TF_OK) {
      TFE_DeleteOp(op);
      return;
    }
  }
  std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
  TFE_Execute(op, op_outputs.data(), num_outputs, s);
  TFE_DeleteOp(op);
  if (TF_GetCode(s) != TF_OK) return;
  // Rewrap each result on the logging device.
  for (int i = 0; i < *num_outputs; ++i) {
    auto logged_tensor = std::make_unique<LoggedTensor>(op_outputs[i]);
    outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
                                        std::move(logged_tensor), s);
  }
  *(dev->executed_flag) = true;
}
// TFE_CustomDevice::delete_device hook: reclaims the heap-allocated
// LoggingDevice created in RegisterLoggingDevice.
void DeleteLoggingDevice(void* device_info) {
  delete static_cast<LoggingDevice*>(device_info);
}
// Registers a logging custom device called `name` with `context`.
//
// `arrived_flag` is set to true whenever a tensor is copied onto the device,
// `executed_flag` whenever an op runs on it; both pointers must outlive the
// context. The device forwards all real work to the local CPU device.
void RegisterLoggingDevice(TFE_Context* context, const char* name,
                           bool* arrived_flag, bool* executed_flag,
                           TF_Status* status) {
  TFE_CustomDevice hooks;
  hooks.execute = &LoggingDeviceExecute;
  hooks.delete_device = &DeleteLoggingDevice;
  hooks.copy_tensor_to_device = &CopyToLoggingDevice;
  hooks.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
  // Ownership passes to the context; released via DeleteLoggingDevice.
  auto* device = new LoggingDevice;
  device->device_name = name;
  device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
  device->arrived_flag = arrived_flag;
  device->executed_flag = executed_flag;
  TFE_RegisterCustomDevice(context, hooks, name, device, status);
}
// Smoke test for the custom-device plumbing: registers the logging device,
// copies a tensor onto it (which must set `arrived`), executes a MatMul on it
// (which must set `executed`), and checks every C API call reports TF_OK.
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* context = TFE_NewContext(opts, status.get());
  TFE_DeleteContextOptions(opts);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  bool arrived = false;
  bool executed = false;
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
  // The flag must only flip once a tensor actually lands on the device.
  ASSERT_FALSE(arrived);
  TFE_TensorHandle* hdevice =
      TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
  ASSERT_TRUE(arrived);
  // Copying alone must not count as executing an op.
  ASSERT_FALSE(executed);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
      MatMulOp(context, hcpu, hdevice), TFE_DeleteOp);
  TFE_OpSetDevice(matmul.get(), name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_TensorHandle* retval;
  int num_retvals = 1;
  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_TRUE(executed);
  TFE_DeleteTensorHandle(retval);
  TFE_DeleteTensorHandle(hcpu);
  TFE_DeleteTensorHandle(hdevice);
  TFE_DeleteContext(context);
}
// Verifies TFE_OpReset works with custom devices: the same TFE_Op can be
// reset onto the custom device and then back onto the physical CPU device,
// and TFE_OpGetDevice reflects each placement.
TEST(CUSTOM_DEVICE, ResetOperation) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts, status.get()), TFE_DeleteContext);
  TFE_DeleteContextOptions(opts);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  bool arrived = false;
  bool executed = false;
  const char* custom_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed,
                        status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> reused_op(
      TFE_NewOp(context.get(), "Identity", status.get()), TFE_DeleteOp);
  // Reset onto the custom device; the op should report that placement.
  TFE_OpReset(reused_op.get(), "Identity", custom_device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
            tensorflow::string(custom_device_name));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  // Reset the same op back onto a physical device.
  TFE_OpReset(reused_op.get(), "Identity",
              "/job:localhost/replica:0/task:0/device:CPU:0", status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
            tensorflow::string("/job:localhost/replica:0/task:0/device:CPU:0"));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
// End-to-end variable lifecycle on the custom device: create a resource
// variable placed on it, assign a scalar, read it back (unwrapping the
// LoggedTensor to verify the stored value), then destroy the resource.
// `executed` is re-checked around each op to confirm the device intercepted
// every execution.
TEST(CUSTOM_DEVICE, MakeVariable) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  bool arrived = false;
  bool executed = false;
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  // Create a variable handle placed on the custom device.
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
  TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
  TFE_OpSetAttrString(op.get(), "container", "", 0);
  TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
  TFE_OpSetDevice(op.get(), name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_TensorHandle* var_handle = nullptr;
  int num_retvals = 1;
  executed = false;
  TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  // VarHandleOp itself must have been routed through the custom device.
  ASSERT_TRUE(executed);
  auto handle_cleaner = tensorflow::gtl::MakeCleanup(
      [var_handle]() { TFE_DeleteTensorHandle(var_handle); });
  // Assign to the variable, copying to the custom device.
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
      TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
  op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
  TFE_OpAddInput(op.get(), var_handle, status.get());
  TFE_OpAddInput(op.get(), one.get(), status.get());
  TFE_OpSetDevice(op.get(), name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  executed = false;
  num_retvals = 0;
  TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_TRUE(executed);
  // Read the variable's value.
  op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
  TFE_OpAddInput(op.get(), var_handle, status.get());
  TFE_OpSetDevice(op.get(), name, status.get());
  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  executed = false;
  num_retvals = 1;
  TFE_TensorHandle* var_value = nullptr;
  TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_TRUE(executed);
  auto value_cleaner = tensorflow::gtl::MakeCleanup(
      [var_value]() { TFE_DeleteTensorHandle(var_value); });
  // The read result should itself live on the custom device.
  ASSERT_EQ(tensorflow::string(name),
            tensorflow::string(
                TFE_TensorHandleBackingDeviceName(var_value, status.get())));
  // Unwrap the LoggedTensor to get at the underlying physical tensor.
  TFE_TensorHandle* var_value_unpacked =
      reinterpret_cast<LoggedTensor*>(
          TFE_TensorHandleDevicePointer(var_value, status.get()))
          ->tensor;
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
      TFE_TensorHandleResolve(var_value_unpacked, status.get()),
      TF_DeleteTensor);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  // Value assigned above must round-trip through the custom device.
  ASSERT_EQ(111., *static_cast<float*>(TF_TensorData(resolved_value.get())));
  // Free the backing buffer for the variable.
  op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
  TFE_OpAddInput(op.get(), var_handle, status.get());
  TFE_OpSetDevice(op.get(), name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  num_retvals = 0;
  TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
// Creates and assigns a variable on the custom device, then reads it WITHOUT
// placing the read on that device. The read must fail (the custom device
// refuses to copy tensors off itself), and cleanup must still succeed.
TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  bool arrived = false;
  bool executed = false;
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  // Create a variable handle placed on the custom device.
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
  TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
  TFE_OpSetAttrString(op.get(), "container", "", 0);
  TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
  TFE_OpSetDevice(op.get(), name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_TensorHandle* var_handle = nullptr;
  int num_retvals = 1;
  executed = false;
  TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_TRUE(executed);
  auto handle_cleaner = tensorflow::gtl::MakeCleanup(
      [var_handle]() { TFE_DeleteTensorHandle(var_handle); });
  // Assign to the variable, copying to the custom device.
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
      TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
  op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
  TFE_OpAddInput(op.get(), var_handle, status.get());
  TFE_OpAddInput(op.get(), one.get(), status.get());
  TFE_OpSetDevice(op.get(), name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  executed = false;
  num_retvals = 0;
  TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_TRUE(executed);
  // Read the variable's value. Note: no TFE_OpSetDevice call here, so the op
  // is NOT placed on the custom device and must fail when it tries to pull
  // the resource's value off that device.
  op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
  TFE_OpAddInput(op.get(), var_handle, status.get());
  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  executed = false;
  num_retvals = 1;
  TFE_TensorHandle* var_value = nullptr;
  TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
  EXPECT_FALSE(TF_GetCode(status.get()) == TF_OK)
      << "Execution should fail because the variable is being used on the "
         "wrong device.";
  // Free the backing buffer for the variable.
  op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
  TFE_OpAddInput(op.get(), var_handle, status.get());
  TFE_OpSetDevice(op.get(), name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  num_retvals = 0;
  TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
// Registration error paths: an unparseable device name must be rejected with
// INVALID_ARGUMENT; re-registering the same custom device name, or shadowing
// an existing physical device name, must fail with ALREADY_EXISTS.
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  bool arrived = false;
  bool executed = false;
  // Incomplete device name (no job/replica/task): invalid argument.
  RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", &arrived, &executed,
                        status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
      << TF_Message(status.get());
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  // Registering the same name twice must fail.
  RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
      << TF_Message(status.get());
  // A custom device may not shadow an existing physical device.
  RegisterLoggingDevice(context.get(),
                        "/job:localhost/replica:0/task:0/device:CPU:0",
                        &arrived, &executed, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
      << TF_Message(status.get());
}
} // namespace

View File

@ -0,0 +1,330 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/dlpack.h"
#include "include/dlpack/dlpack.h" // from @dlpack
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace {
// Managing context for the DLManagedTensor, will manage the lifetime of
// DLManagedTensor. When calling DLManagedTensor::deleter, it will notify the
// original framework of destruction, and this context will be deleted also.
struct TfDlManagedTensorCtx {
  // Pins the underlying TF buffer for as long as the DLPack consumer holds
  // the exported DLManagedTensor; Unref()'d in DLManagedTensorDeleter.
  TensorReference reference;
  // Backing storage for dl_tensor.shape / dl_tensor.strides, which DLPack
  // exposes only as raw int64_t pointers.
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
  // The exported tensor; its manager_ctx points back at this struct.
  DLManagedTensor tensor;
  explicit TfDlManagedTensorCtx(const TensorReference& ref) : reference(ref) {}
};
// Gets tensor from eager tensor handle.
// Extracts the underlying tensorflow::Tensor from an eager tensor handle.
// Rejects null and remote handles; on any failure the error is recorded in
// `status` and nullptr is returned.
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr || h->handle == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  tensorflow::TensorHandle* eager_handle =
      tensorflow::TensorHandleFromInterface(h->handle);
  if (eager_handle->IsRemote()) {
    // Remote tensors have no local buffer to expose through DLPack.
    status->status = tensorflow::errors::InvalidArgument(
        "DLPack doesn't support remote tensor");
    return nullptr;
  }
  const tensorflow::Tensor* tensor = nullptr;
  status->status = eager_handle->Tensor(&tensor);
  return status->status.ok() ? tensor : nullptr;
}
// Deleter for DLManagedTensor
// Deleter installed on every exported DLManagedTensor: releases the reference
// pinning the TF buffer, then frees the owning TfDlManagedTensorCtx.
void DLManagedTensorDeleter(DLManagedTensor* arg) {
  auto* ctx = static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
  ctx->reference.Unref();
  delete ctx;
}
// Converts TF_DATAType to DLPack data type.
// Maps a TF_DataType onto the equivalent DLPack DLDataType (lanes is always
// 1; bits is taken from TF_DataTypeSize). For dtypes with no DLPack
// equivalent an InvalidArgument status is set and the type code is left
// unset.
DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
  DLDataType result;
  result.lanes = 1;
  result.bits = TF_DataTypeSize(data_type) * 8;
  switch (data_type) {
    // Floating point.
    case TF_DataType::TF_HALF:
    case TF_DataType::TF_FLOAT:
    case TF_DataType::TF_DOUBLE:
      result.code = DLDataTypeCode::kDLFloat;
      break;
    // Signed integers.
    case TF_DataType::TF_INT8:
    case TF_DataType::TF_INT16:
    case TF_DataType::TF_INT32:
    case TF_DataType::TF_INT64:
      result.code = DLDataTypeCode::kDLInt;
      break;
    // Unsigned integers (bool is exported as a 1-byte unsigned value).
    case TF_DataType::TF_BOOL:
    case TF_DataType::TF_UINT8:
    case TF_DataType::TF_UINT16:
    case TF_DataType::TF_UINT32:
    case TF_DataType::TF_UINT64:
      result.code = DLDataTypeCode::kDLUInt;
      break;
    case TF_DataType::TF_BFLOAT16:
      result.code = DLDataTypeCode::kDLBfloat;
      break;
    default:
      status->status = tensorflow::errors::InvalidArgument(
          DataType_Name(static_cast<DataType>(data_type)),
          " is not supported by dlpack");
      break;
  }
  return result;
}
// Gets DLPack's DLContext from eager tensor handle.
// Builds the DLPack DLContext (device type + id) for the device that holds
// `h`. Only CPU and GPU devices are supported; anything else sets an
// InvalidArgument status. On any error the returned context is a
// deterministic default (CPU:0) rather than uninitialized memory — the
// original left `ctx.device_type` uninitialized on the unsupported-device
// path and ignored a failing DeviceName() call.
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
  DLContext ctx;
  ctx.device_id = 0;
  ctx.device_type = DLDeviceType::kDLCPU;
  const char* device_name = h->handle->DeviceName(&status->status);
  if (!status->status.ok()) {
    return ctx;
  }
  DeviceNameUtils::ParsedName parsed_name;
  tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
  std::string device_type = parsed_name.type;
  if (parsed_name.has_id) {
    ctx.device_id = parsed_name.id;
  }
  if (device_type == "CPU") {
    ctx.device_type = DLDeviceType::kDLCPU;
  } else if (device_type == "GPU") {
    ctx.device_type = DLDeviceType::kDLGPU;
  } else {
    status->status = tensorflow::errors::InvalidArgument(
        "Unsupported Device Type for dlpack");
  }
  return ctx;
}
// Converts DLContext to TF device name.
// Maps a DLPack context back to a partial TF device name ("CPU:0" or
// "GPU:<id>"), or nullopt when the DLPack device type has no TF equivalent.
absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
                                                    TF_Status* status) {
  if (ctx.device_type == DLDeviceType::kDLCPU) {
    return "CPU:0";
  }
  if (ctx.device_type == DLDeviceType::kDLGPU) {
    return absl::StrCat("GPU:", ctx.device_id);
  }
  return absl::nullopt;
}
// Converts DLPack data type to TF_DATATYPE.
// Maps a DLPack data type onto the matching TF_DataType, written through
// `tf_dtype`. Returns InvalidArgument for any code/bit-width combination
// with no TensorFlow equivalent. (The function name — "Form" rather than
// "From" — is kept as-is for compatibility with existing callers.)
//
// Dead code removed relative to the original: every case of each inner
// switch returns, so the trailing `return Status::OK();` / `break;`
// statements after the inner switches were unreachable.
Status TfDataTypeFormDlDataType(const DLDataType& dtype,
                                TF_DataType* tf_dtype) {
  switch (dtype.code) {
    case DLDataTypeCode::kDLUInt:
      switch (dtype.bits) {
        case 8:
          *tf_dtype = TF_DataType::TF_UINT8;
          return Status::OK();
        case 16:
          *tf_dtype = TF_DataType::TF_UINT16;
          return Status::OK();
        case 32:
          *tf_dtype = TF_DataType::TF_UINT32;
          return Status::OK();
        case 64:
          *tf_dtype = TF_DataType::TF_UINT64;
          return Status::OK();
        default:
          return tensorflow::errors::InvalidArgument("Unsupported UInt bits: ",
                                                     dtype.bits);
      }
    case DLDataTypeCode::kDLInt:
      switch (dtype.bits) {
        case 8:
          *tf_dtype = TF_DataType::TF_INT8;
          return Status::OK();
        case 16:
          *tf_dtype = TF_DataType::TF_INT16;
          return Status::OK();
        case 32:
          *tf_dtype = TF_DataType::TF_INT32;
          return Status::OK();
        case 64:
          *tf_dtype = TF_DataType::TF_INT64;
          return Status::OK();
        default:
          return tensorflow::errors::InvalidArgument("Unsupported Int bits: ",
                                                     dtype.bits);
      }
    case DLDataTypeCode::kDLFloat:
      switch (dtype.bits) {
        case 16:
          *tf_dtype = TF_DataType::TF_HALF;
          return Status::OK();
        case 32:
          *tf_dtype = TF_DataType::TF_FLOAT;
          return Status::OK();
        case 64:
          *tf_dtype = TF_DataType::TF_DOUBLE;
          return Status::OK();
        default:
          return tensorflow::errors::InvalidArgument("Unsupported Float bits: ",
                                                     dtype.bits);
      }
    case DLDataTypeCode::kDLBfloat:
      switch (dtype.bits) {
        case 16:
          *tf_dtype = TF_DataType::TF_BFLOAT16;
          return Status::OK();
        default:
          return tensorflow::errors::InvalidArgument(
              "Unsupported BFloat bits: ", dtype.bits);
      }
    default:
      return tensorflow::errors::InvalidArgument("Unsupported Type Codes: ",
                                                 dtype.code);
  }
}
// Wraps the deleter function of DLManagedTensor to match the function signature
// TFE_NewTensorHandleFromDeviceMemory.
// Adapts DLManagedTensor::deleter to the deallocator signature expected by
// TFE_NewTensorHandleFromDeviceMemory. `dlmt_vptr` must be the
// DLManagedTensor that owns `data`; `data` and `len` are unused because the
// producer's deleter performs the actual release.
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
  DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr);
  // The DLPack spec allows a null deleter (nothing to free); guard it, as
  // TFE_CallDLManagedTensorDeleter already does. The original also applied a
  // no-op const_cast to an already non-const pointer.
  if (dlmt->deleter != nullptr) {
    dlmt->deleter(dlmt);
  }
}
// Checks whether the stride array matches the layout of compact, row-majored
// data.
// Returns true iff `stride_arr` is exactly the stride layout of a dense
// (compact) row-major tensor whose extents are `shape_arr[0..ndim)`.
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
                                      int ndim) {
  // The innermost dimension of a compact row-major layout has stride 1.
  if (ndim >= 1 && stride_arr[ndim - 1] != 1) return false;
  // Walking outward, each stride must be the product of all inner extents.
  for (int d = ndim - 2; d >= 0; --d) {
    if (stride_arr[d] != shape_arr[d + 1] * stride_arr[d + 1]) return false;
  }
  return true;
}
} // namespace
// Invokes the producer's deleter on a DLManagedTensor, if one was supplied.
// Used as the destructor of the Python "dltensor" PyCapsule.
void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
  auto* managed = static_cast<DLManagedTensor*>(dlm_ptr);
  if (managed->deleter != nullptr) {
    managed->deleter(managed);
  }
}
// Exports an eager tensor handle as a DLPack DLManagedTensor (returned as
// void* for PyCapsule construction). The returned tensor pins the TF buffer
// until its deleter runs. Returns nullptr with `status` set on failure.
void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
  const Tensor* tensor = GetTensorFromHandle(h, status);
  if (!status->status.ok()) {
    // The original dereferenced `tensor` without this check, crashing on
    // invalid or remote handles.
    return nullptr;
  }
  TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
  TensorReference tensor_ref(*tensor);  // This will call buf_->Ref()
  // The constructor already stores `tensor_ref`; the original's extra
  // `tf_dlm_tensor_ctx->reference = tensor_ref;` assignment was redundant.
  auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
  DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
  dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
  dlm_tensor->deleter = &DLManagedTensorDeleter;
  dlm_tensor->dl_tensor.ctx = GetDlContext(h, status);
  int ndim = tensor->dims();
  dlm_tensor->dl_tensor.ndim = ndim;
  dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
  dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status);
  std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
  std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
  shape_arr->resize(ndim);
  stride_arr->resize(ndim, 1);
  for (int i = 0; i < ndim; i++) {
    (*shape_arr)[i] = tensor->dim_size(i);
  }
  // Row-major strides: each stride is the product of all inner extents.
  for (int i = ndim - 2; i >= 0; --i) {
    (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
  }
  // Use data() rather than &(*vec)[0]: indexing element 0 of an empty vector
  // (rank-0 tensors) is undefined behavior.
  dlm_tensor->dl_tensor.shape = shape_arr->data();
  // There are two ways to represent compact row-major data
  // 1) nullptr indicates tensor is compact and row-majored.
  // 2) fill in the strides array as the real case for compact row-major data.
  // Here we choose option 2, since some frameworks didn't handle the strides
  // argument properly.
  dlm_tensor->dl_tensor.strides = stride_arr->data();
  dlm_tensor->dl_tensor.byte_offset =
      0;  // TF doesn't handle the strides and byte_offsets here
  return static_cast<void*>(dlm_tensor);
}
// Imports a DLPack DLManagedTensor as an eager tensor handle, without
// copying. The consumer's memory is released through DeallocatorWrapperFunc
// when the handle is destroyed. Only dense row-major layouts on supported
// devices/dtypes are accepted; otherwise sets `status` and returns nullptr.
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status,
                                       TFE_Context* ctx) {
  DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
  DLTensor* dl_tensor = &dlmt->dl_tensor;
  absl::optional<std::string> device_name =
      DeviceNameFromDlContext(dl_tensor->ctx, status);
  if (!device_name.has_value()) {
    status->status =
        tensorflow::errors::InvalidArgument("Unsupported Device Type");
    return nullptr;
  }
  TF_DataType dtype;
  Status s = TfDataTypeFormDlDataType(dl_tensor->dtype, &dtype);
  if (!s.ok()) {
    status->status = std::move(s);
    return nullptr;
  }
  int num_dims = dl_tensor->ndim;
  const int64_t* dims = dl_tensor->shape;
  void* data = dl_tensor->data;
  size_t total_bytes = dl_tensor->dtype.bits / 8;
  for (int i = 0; i < num_dims; i++) {
    total_bytes *= dims[i];
  }
  // Per the DLPack spec a null strides array means compact row-major; a
  // non-null one must describe exactly that layout for zero-copy import.
  if (dl_tensor->strides != nullptr &&
      !IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
                                        num_dims)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid strides array from DLPack");
    return nullptr;
  }
  // Pass `dlmt` itself as the deallocator argument. The original passed
  // `&dlmt` -- the address of this stack-local pointer -- so
  // DeallocatorWrapperFunc would later dereference a dangling stack address
  // instead of the DLManagedTensor.
  TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
      ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
      total_bytes, &DeallocatorWrapperFunc, dlmt, status);
  return handle;
}
} // namespace tensorflow

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_DLPACK_H_
#define TENSORFLOW_C_EAGER_DLPACK_H_
#include "tensorflow/c/eager/c_api.h"
namespace tensorflow {
// Name used for the PyCapsule that carries a DLPack tensor across the
// Python/C boundary ("dltensor" is the name mandated by the DLPack protocol).
const char* const kDlTensorCapsuleName = "dltensor";
// Converts eager tensor handle to DLPack (DLManagedTensor*), and return the
// void* for further PyCapsule construction. On failure, sets `status` and
// returns nullptr.
TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
                                               TF_Status* status);
// Converts DLPack (DLManagedTensor*) to eager tensor handle, zero-copy.
// On failure (unsupported device, dtype, or strides), sets `status` and
// returns nullptr.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
                                                             TF_Status* status,
                                                             TFE_Context* ctx);
// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule.
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_DLPACK_H_

View File

@ -0,0 +1,99 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
#include "absl/types/span.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
struct TFE_Op;
namespace tensorflow {
// Abstract interface to an operation.
class AbstractOperationInterface {
 public:
  // Release any underlying resources, including the interface object.
  //
  // WARNING: The destructor of this class is marked as protected to disallow
  // clients from directly destroying this object since it may manage it's own
  // lifetime through ref counting. Thus this must be allocated on the heap and
  // clients MUST call Release() in order to destroy an instance of this class.
  virtual void Release() = 0;

  // Resets the operation's transient state so the object can be reused.
  virtual void Clear() = 0;
  // Re-initializes this operation for op type `op` on device
  // `raw_device_name`, as an alternative to building a new operation.
  virtual Status Reset(const char* op, const char* raw_device_name) = 0;

  virtual const string& Name() const = 0;
  virtual const string& DeviceName() const = 0;
  virtual Status SetDeviceName(const char* name) = 0;

  // Input registration; inputs must be added before Execute.
  virtual Status AddInput(AbstractTensorHandleInterface* input) = 0;
  virtual Status AddInputList(
      absl::Span<AbstractTensorHandleInterface*> inputs) = 0;
  // Runs the operation. `retvals` must have room for the expected number of
  // outputs; `num_retvals` is updated to the number actually produced.
  virtual Status Execute(absl::Span<AbstractTensorHandleInterface*> retvals,
                         int* num_retvals) = 0;
  virtual const tensorflow::OpDef* OpDef() const = 0;

  // Scalar attribute setters. Attributes must be set before Execute.
  virtual Status SetAttrString(const char* attr_name, const char* data,
                               size_t length) = 0;
  virtual Status SetAttrInt(const char* attr_name, int64_t value) = 0;
  virtual Status SetAttrFloat(const char* attr_name, float value) = 0;
  virtual Status SetAttrBool(const char* attr_name, bool value) = 0;
  virtual Status SetAttrType(const char* attr_name, DataType value) = 0;
  virtual Status SetAttrShape(const char* attr_name, const int64_t* dims,
                              const int num_dims) = 0;
  virtual Status SetAttrFunction(const char* attr_name,
                                 const AbstractOperationInterface* value) = 0;
  virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
                                     size_t length) = 0;
  virtual Status SetAttrTensor(const char* attr_name,
                               AbstractTensorInterface* tensor) = 0;

  // List attribute setters.
  virtual Status SetAttrStringList(const char* attr_name,
                                   const void* const* values,
                                   const size_t* lengths, int num_values) = 0;
  virtual Status SetAttrFloatList(const char* attr_name, const float* values,
                                  int num_values) = 0;
  virtual Status SetAttrIntList(const char* attr_name, const int64_t* values,
                                int num_values) = 0;
  virtual Status SetAttrTypeList(const char* attr_name, const DataType* values,
                                 int num_values) = 0;
  virtual Status SetAttrBoolList(const char* attr_name,
                                 const unsigned char* values,
                                 int num_values) = 0;
  virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
                                  const int* num_dims, int num_values) = 0;
  virtual Status SetAttrFunctionList(
      const char* attr_name,
      absl::Span<const AbstractOperationInterface*> values) = 0;

  // Writes the number of tensors bound to the named input/output argument.
  virtual Status InputLength(const char* input_name, int* length) = 0;
  virtual Status OutputLength(const char* output_name, int* length) = 0;

  // Experimental
  virtual Status SetUseXla(bool enable) = 0;

 protected:
  // Protected: clients must destroy instances via Release(), never delete.
  virtual ~AbstractOperationInterface() {}
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_

View File

@ -284,7 +284,7 @@ class ForwardAccumulator {
// Temporarily push or pop transient state for this accumulator.
//
// Allows an accumulator which is currently processing an operation to
// temporarily reset its state. Without pushing and poping, accumulators
// temporarily reset its state. Without pushing and popping, accumulators
// ignore operations executed as a direct result of their own jvp
// computations.
void PushState() { call_state_.emplace(nullptr, false); }

View File

@ -0,0 +1,76 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
// Abstract interface to a TensorHandle.
//
// A TensorHandle is management class around a Tensor which may track additional
// metadata and synchronization.
//
// This allows us to hide concrete implementations of TensorHandle from header
// files. The interface lists the common functionality that must be provided by
// any concrete implementation. However, in cases where the true concrete class
// is needed a static_cast can be applied.
class AbstractTensorHandleInterface {
 public:
  // Release any underlying resources, including the interface object.
  //
  // WARNING: The destructor of this class is marked as protected to disallow
  // clients from directly destroying this object since it may manage it's own
  // lifetime through ref counting. Thus this must be allocated on the heap and
  // clients MUST call Release() in order to destroy an instance of this class.
  virtual void Release() = 0;

  // Returns tensor dtype.
  virtual tensorflow::DataType DataType() const = 0;
  // Returns number of dimensions.
  virtual Status NumDims(int* num_dims) const = 0;
  // Returns number of elements across all dimensions.
  virtual Status NumElements(int64* num_elements) const = 0;
  // Returns size of specified dimension
  virtual Status Dim(int dim_index, int64* dim) const = 0;

  // Returns the device which created the handle.
  virtual const char* DeviceName(Status* status) const = 0;
  // Returns the device where the tensor was placed.
  virtual const char* BackingDeviceName(Status* status) const = 0;
  // Returns a tensor for the handle. If tensor is remote, it will be copied.
  virtual AbstractTensorInterface* Resolve(Status* status) = 0;
  // Return a copy of the handle.
  virtual AbstractTensorHandleInterface* Copy() = 0;

  // Maintain mirror tensors for any implicit copies to local devices. This
  // setting is offered on a per tensor handle basis to avoid potential memory
  // over utilization due to holding on to mirrors as well as the original
  // tensor. Note this setting overrides the context mirroring policy whereby if
  // the mirroring policy is MIRRORING_NONE, we will still continue to mirror
  // this tensor.
  virtual void EnableImplicitMirroring() = 0;

 protected:
  // Protected: clients must destroy instances via Release(), never delete.
  virtual ~AbstractTensorHandleInterface() {}
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_

View File

@ -18,37 +18,28 @@ cc_library(
],
)
# Core TensorFlow depends on this, this will be included in main library
cc_library(
name = "filesystem_interface_impl",
srcs = ["filesystem_interface.cc"],
hdrs = ["filesystem_interface.h"],
deps = [
":modular_filesystem",
"//tensorflow/c:tf_file_statistics",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_internal",
"//tensorflow/core:ptr_util",
"//tensorflow/core/platform:env",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/platform:stringpiece",
],
alwayslink = 1,
)
# Core TensorFlow depends on this, will be included in main library
cc_library(
name = "modular_filesystem",
srcs = ["modular_filesystem.cc"],
hdrs = ["modular_filesystem.h"],
srcs = [
"modular_filesystem.cc",
"modular_filesystem_registration.cc",
],
hdrs = [
"modular_filesystem.h",
"modular_filesystem_registration.h",
],
# TODO(mihaimaruseac): Visibility should be more restrictive once we
# convert to modular filesystems everywhere
visibility = ["//visibility:public"],
deps = [
":filesystem_interface",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core:lib",
"//tensorflow/c:tf_status_internal",
"//tensorflow/core:ptr_util",
"//tensorflow/core/platform:env",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
],
)
@ -63,16 +54,12 @@ tf_cc_test(
"notap", # b/139060984, requires implementing modular support for Google filesystem
],
deps = [
":filesystem_interface_impl",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_internal",
":modular_filesystem",
"//tensorflow/core:framework_internal",
"//tensorflow/core/lib/io:path",
"//tensorflow/core/platform:env",
"//tensorflow/core/platform:error",
"//tensorflow/core/platform:stacktrace_handler",
"//tensorflow/core/platform:str_util",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/platform:test",
],
)

View File

@ -1,366 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/util/ptr_util.h"
/// This translation unit is linked in core TensorFlow and provides the
/// functionality needed for plugin registration to check ABI/API compatibility,
/// to ensure required methods are present, to ensure plugins are not allowed to
/// change functionality after being loaded and to register the filesystems
/// provided by a plugin. Consult the header file for more information about
/// how this is achieved.
namespace tensorflow {
namespace {
// Checks if the plugin and core ABI numbers match, filling in `status`.
//
// If the numbers don't match, plugin cannot be loaded.
// Verifies that the plugin's ABI number for the `where` operation table is the
// one core TensorFlow expects. On mismatch, fills `status` with a
// TF_FAILED_PRECONDITION error and returns false: the plugin must not load.
static bool CheckABIHelper(int pluginABI, int coreABI, StringPiece where,
                           TF_Status* status) {
  if (pluginABI == coreABI) return true;
  TF_SetStatus(
      status, TF_FAILED_PRECONDITION,
      strings::StrCat("Plugin ABI (", pluginABI, ") for ", where,
                      " operations doesn't match expected core ABI (", coreABI,
                      "). Plugin cannot be loaded.")
          .c_str());
  return false;
}
// Checks if the plugin and core ABI numbers match, for all operations.
//
// If the numbers don't match, plugin cannot be loaded.
//
// Uses the simpler `CheckABIHelper(int, int, StringPiece, TF_Status*)`
static bool CheckABI(
    int plugin_filesystem_ops_ABI,
    const TF_RandomAccessFileOps* plugin_random_access_file_ops,
    int plugin_random_access_file_ops_ABI,
    const TF_WritableFileOps* plugin_writable_file_ops,
    int plugin_writable_file_ops_ABI,
    const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
    int plugin_read_only_memory_region_ops_ABI, TF_Status* status) {
  // The filesystem table is mandatory, so its ABI number is always checked.
  if (!CheckABIHelper(plugin_filesystem_ops_ABI, TF_FILESYSTEM_OPS_ABI,
                      "filesystem", status))
    return false;
  // The per-file-type tables are optional: their ABI numbers are only checked
  // when the plugin actually supplies the corresponding table.
  if (plugin_random_access_file_ops != nullptr &&
      !CheckABIHelper(plugin_random_access_file_ops_ABI,
                      TF_RANDOM_ACCESS_FILE_OPS_ABI, "random access file",
                      status))
    return false;
  if (plugin_writable_file_ops != nullptr &&
      !CheckABIHelper(plugin_writable_file_ops_ABI, TF_WRITABLE_FILE_OPS_ABI,
                      "writable file", status))
    return false;
  if (plugin_read_only_memory_region_ops != nullptr &&
      !CheckABIHelper(plugin_read_only_memory_region_ops_ABI,
                      TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
                      "read only memory region", status))
    return false;
  return true;
}
// Checks if the plugin and core API numbers match, logging mismatches.
// Logs a warning if the plugin's API number for the `where` operation table
// differs from the core's. Unlike an ABI mismatch, this does not prevent the
// plugin from loading; some functionality may simply be unavailable.
static void CheckAPIHelper(int plugin_API, int core_API, StringPiece where) {
  if (plugin_API == core_API) return;
  VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
          << " operations doesn't match expected core API (" << core_API
          << "). Plugin will be loaded but functionality might be missing.";
}
// Checks if the plugin and core API numbers match, for all operations.
//
// Uses the simpler `CheckAPIHelper(int, int, StringPiece)`.
static void CheckAPI(
    int plugin_filesystem_ops_API,
    const TF_RandomAccessFileOps* plugin_random_access_file_ops,
    int plugin_random_access_file_ops_API,
    const TF_WritableFileOps* plugin_writable_file_ops,
    int plugin_writable_file_ops_API,
    const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
    int plugin_read_only_memory_region_ops_API) {
  // The filesystem table is mandatory, so its API number is always checked.
  CheckAPIHelper(plugin_filesystem_ops_API, TF_FILESYSTEM_OPS_API,
                 "filesystem");
  // The remaining tables are optional; only check the ones the plugin
  // actually supplies.
  if (plugin_random_access_file_ops != nullptr)
    CheckAPIHelper(plugin_random_access_file_ops_API,
                   TF_RANDOM_ACCESS_FILE_OPS_API, "random access file");
  if (plugin_writable_file_ops != nullptr)
    CheckAPIHelper(plugin_writable_file_ops_API, TF_WRITABLE_FILE_OPS_API,
                   "writable file");
  if (plugin_read_only_memory_region_ops != nullptr)
    CheckAPIHelper(plugin_read_only_memory_region_ops_API,
                   TF_READ_ONLY_MEMORY_REGION_OPS_API,
                   "read only memory region");
}
// Validates the filesystem operations supplied by the plugin.
// Validates the filesystem operations supplied by the plugin. The table itself
// and its `init`/`cleanup` entries are mandatory; on any missing piece,
// `status` is set to TF_FAILED_PRECONDITION and false is returned.
static bool ValidateHelper(const TF_FilesystemOps* ops, TF_Status* status) {
  const auto fail = [status](const char* message) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION, message);
    return false;
  };
  if (ops == nullptr)
    return fail("Trying to register filesystem without operations");
  if (ops->init == nullptr)
    return fail("Trying to register filesystem without `init` operation");
  if (ops->cleanup == nullptr)
    return fail("Trying to register filesystem without `cleanup` operation");
  return true;
}
// Validates the random access file operations supplied by the plugin.
// Validates the random access file operations supplied by the plugin. The
// whole table may be absent (write-only filesystems are allowed), but if it is
// present it must provide `cleanup`.
static bool ValidateHelper(const TF_RandomAccessFileOps* ops,
                           TF_Status* status) {
  if (ops == nullptr) return true;
  if (ops->cleanup != nullptr) return true;
  TF_SetStatus(status, TF_FAILED_PRECONDITION,
               "Trying to register filesystem without `cleanup` operation on "
               "random access files");
  return false;
}
// Validates the writable file operations supplied by the plugin.
// Validates the writable file operations supplied by the plugin. The whole
// table may be absent (read-only filesystems are allowed), but if it is
// present it must provide `cleanup`.
static bool ValidateHelper(const TF_WritableFileOps* ops, TF_Status* status) {
  if (ops == nullptr) return true;
  if (ops->cleanup != nullptr) return true;
  TF_SetStatus(status, TF_FAILED_PRECONDITION,
               "Trying to register filesystem without `cleanup` operation on "
               "writable files");
  return false;
}
// Validates the read only memory region operations given by the plugin.
// Validates the read only memory region operations given by the plugin. The
// whole table is optional, but a supplied table must provide `cleanup`,
// `data` and `length`.
static bool ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops,
                           TF_Status* status) {
  if (ops == nullptr) return true;
  const auto fail = [status](const char* message) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION, message);
    return false;
  };
  if (ops->cleanup == nullptr)
    return fail(
        "Trying to register filesystem without `cleanup` operation on "
        "read only memory regions");
  if (ops->data == nullptr)
    return fail(
        "Trying to register filesystem without `data` operation on "
        "read only memory regions");
  if (ops->length == nullptr)
    return fail(
        "Trying to register filesystem without `length` operation on "
        "read only memory regions");
  return true;
}
// Validates the operations supplied by the plugin.
//
// Uses the 4 simpler `ValidateHelper(const TF_..., TF_Status*)` to validate
// each individual function table and then checks that the function table for a
// specific file type exists if the plugin offers support for creating that
// type of files.
static bool Validate(
    const TF_FilesystemOps* plugin_filesystem_ops,
    const TF_RandomAccessFileOps* plugin_random_access_file_ops,
    const TF_WritableFileOps* plugin_writable_file_ops,
    const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
    TF_Status* status) {
  // First validate each function table in isolation.
  if (!ValidateHelper(plugin_filesystem_ops, status)) return false;
  if (!ValidateHelper(plugin_random_access_file_ops, status)) return false;
  if (!ValidateHelper(plugin_writable_file_ops, status)) return false;
  if (!ValidateHelper(plugin_read_only_memory_region_ops, status)) return false;
  // Then check cross-table consistency: if the filesystem advertises creation
  // of a given file type, the operation table for that file type must exist.
  if (plugin_filesystem_ops->new_random_access_file != nullptr &&
      plugin_random_access_file_ops == nullptr) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION,
                 "Filesystem allows creation of random access files but no "
                 "operations on them have been supplied.");
    return false;
  }
  // Both `new_writable_file` and `new_appendable_file` produce writable files,
  // so either one requires the writable file table.
  if ((plugin_filesystem_ops->new_writable_file != nullptr ||
       plugin_filesystem_ops->new_appendable_file != nullptr) &&
      plugin_writable_file_ops == nullptr) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION,
                 "Filesystem allows creation of writable files but no "
                 "operations on them have been supplied.");
    return false;
  }
  if (plugin_filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
      plugin_read_only_memory_region_ops == nullptr) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION,
                 "Filesystem allows creation of readonly memory regions but no "
                 "operations on them have been supplied.");
    return false;
  }
  return true;
}
// Copies a function table from plugin memory space to core memory space.
//
// This has three benefits:
// * allows having newer plugins than the current core TensorFlow: the
// additional entries in the plugin's table are just discarded;
// * allows having older plugins than the current core TensorFlow (though
// we are still warning users): the entries that core TensorFlow expects
// but plugins didn't provide will be set to `nullptr` values and core
// TensorFlow will know to not call these on behalf of users;
// * increased security as plugins will not be able to alter function table
// after loading up. Thus, malicious plugins can't alter functionality to
// probe for gadgets inside core TensorFlow. We can even protect the area
// of memory where the copies reside to not allow any more writes to it
// after all copies are created.
// Copies an operation table from plugin memory into core-owned memory.
//
// Returns `nullptr` when the plugin supplied no table. At most `sizeof(T)`
// bytes are copied: extra entries from a newer plugin are discarded, and a
// smaller (older) plugin table leaves the remaining entries of the
// freshly-created `T` untouched by the copy.
template <typename T>
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
                                           size_t plugin_size) {
  if (plugin_ops == nullptr) return nullptr;
  const size_t copy_size = plugin_size < sizeof(T) ? plugin_size : sizeof(T);
  auto core_ops = tensorflow::MakeUnique<T>();
  memcpy(const_cast<T*>(core_ops.get()), plugin_ops, copy_size);
  return core_ops;
}
} // namespace
} // namespace tensorflow
// Registers the filesystems supplied by a plugin for `scheme`, in order:
//   1. reject a `nullptr` scheme,
//   2. enforce matching ABI numbers (mismatch aborts the load),
//   3. warn on mismatched API numbers (load continues),
//   4. validate the supplied operation tables,
//   5. copy the tables into core-owned memory,
//   6. initialize the plugin's filesystem state and register the filesystem
//      with `Env::Default()`.
// On any failure, `status` is set and nothing is registered.
void RegisterFilesystemPlugin(
    int plugin_filesystem_ops_ABI, int plugin_filesystem_ops_API,
    size_t plugin_filesystem_ops_size, int plugin_random_access_file_ops_ABI,
    int plugin_random_access_file_ops_API,
    size_t plugin_random_access_file_ops_size, int plugin_writable_file_ops_ABI,
    int plugin_writable_file_ops_API, size_t plugin_writable_file_ops_size,
    int plugin_read_only_memory_region_ops_ABI,
    int plugin_read_only_memory_region_ops_API,
    size_t plugin_read_only_memory_region_ops_size, const char* scheme,
    const TF_FilesystemOps* plugin_filesystem_ops,
    const TF_RandomAccessFileOps* plugin_random_access_file_ops,
    const TF_WritableFileOps* plugin_writable_file_ops,
    const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
    TF_Status* status) {
  if (scheme == nullptr) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "`scheme` argument must not be `nullptr`.");
    return;
  }
  // ABI numbers must match exactly for plugin to be loaded
  if (!tensorflow::CheckABI(
          plugin_filesystem_ops_ABI, plugin_random_access_file_ops,
          plugin_random_access_file_ops_ABI, plugin_writable_file_ops,
          plugin_writable_file_ops_ABI, plugin_read_only_memory_region_ops,
          plugin_read_only_memory_region_ops_ABI, status)) {
    return;
  }
  // API numbers should match but mismatch doesn't block plugin load
  tensorflow::CheckAPI(plugin_filesystem_ops_API, plugin_random_access_file_ops,
                       plugin_random_access_file_ops_API,
                       plugin_writable_file_ops, plugin_writable_file_ops_API,
                       plugin_read_only_memory_region_ops,
                       plugin_read_only_memory_region_ops_API);
  // Plugin can only be loaded if all supplied ops are valid
  if (!tensorflow::Validate(plugin_filesystem_ops,
                            plugin_random_access_file_ops,
                            plugin_writable_file_ops,
                            plugin_read_only_memory_region_ops, status)) {
    return;
  }
  // Copy all the function tables to core TensorFlow memory space, so plugins
  // cannot alter them after the checks above have passed.
  auto core_filesystem_ops = tensorflow::CopyToCore<TF_FilesystemOps>(
      plugin_filesystem_ops, plugin_filesystem_ops_size);
  auto core_random_access_file_ops =
      tensorflow::CopyToCore<TF_RandomAccessFileOps>(
          plugin_random_access_file_ops, plugin_random_access_file_ops_size);
  auto core_writable_file_ops = tensorflow::CopyToCore<TF_WritableFileOps>(
      plugin_writable_file_ops, plugin_writable_file_ops_size);
  auto core_read_only_memory_region_ops =
      tensorflow::CopyToCore<TF_ReadOnlyMemoryRegionOps>(
          plugin_read_only_memory_region_ops,
          plugin_read_only_memory_region_ops_size);
  // Initialize the opaque filesystem structure
  auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
  core_filesystem_ops->init(filesystem.get(), status);
  if (!status->status.ok()) {
    // `init` failed: let the plugin release any partial state before bailing.
    core_filesystem_ops->cleanup(filesystem.get());
    return;
  }
  // Register new filesystem; the ModularFileSystem takes ownership of the
  // opaque state and of all the copied operation tables.
  status->status = tensorflow::Env::Default()->RegisterFileSystem(
      scheme, tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
                  std::move(filesystem), std::move(core_filesystem_ops),
                  std::move(core_random_access_file_ops),
                  std::move(core_writable_file_ops),
                  std::move(core_read_only_memory_region_ops)));
}

View File

@ -56,7 +56,7 @@ extern "C" {
/// Lifetime: The wrapper data structures are owned by core TensorFlow. The data
/// pointed to by the `void*` members is always owned by the plugin. The plugin
/// will provide functions to call to allocate and deallocate this data (see
/// next section) and core TensorFlow ensures to call these at the proper time.
/// next sections) and core TensorFlow ensures to call these at the proper time.
///
/// Plugins will never receive a `TF_*` pointer that is `nullptr`. Core
/// TensorFlow will never touch the `void*` wrapped by these structures, except
@ -529,7 +529,7 @@ typedef struct TF_FilesystemOps {
/// If `statuses` is not null, plugins must fill each element with detailed
/// status for each file, as if calling `path_exists` on each one. Core
/// TensorFlow initializes the `statuses` array and plugins must use
/// `TF_SetStatus` to set each element instead of dirrectly assigning.
/// `TF_SetStatus` to set each element instead of directly assigning.
///
/// DEFAULT IMPLEMENTATION: Checks existence of every file. Needs
/// `path_exists`.
@ -601,6 +601,10 @@ typedef struct TF_FilesystemOps {
///
/// Plugins must not return `nullptr`. Returning empty strings is allowed.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// This function will be called by core TensorFlow to clean up all path
/// arguments for all other methods in the filesystem API.
///
@ -618,6 +622,10 @@ typedef struct TF_FilesystemOps {
/// In case of error, plugins must set `status` to a value different than
/// `TF_OK`, free memory allocated for `entries` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if all children were returned.
/// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to a
@ -654,6 +662,10 @@ typedef struct TF_FilesystemOps {
/// different than `TF_OK`, free any memory that might have been allocated for
/// `entries` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if all matches were returned.
/// * Might use any other error value for `status` to signal other errors.
@ -736,95 +748,132 @@ constexpr size_t TF_FILESYSTEM_OPS_SIZE = sizeof(TF_FilesystemOps);
/// SECTION 4. Plugin registration and initialization
/// ----------------------------------------------------------------------------
///
/// In this section we define two functions:
/// * `TF_InitPlugin`: must be present in the plugin shared object as it will
/// be called by core TensorFlow when the filesystem plugin is loaded;
/// * `RegisterFilesystemPlugin`: it is implemented by core TensorFlow but
/// plugins must call it in their `TF_InitPlugin`, usually using the macro
/// `TF_REGISTER_FILESYSTEM_PLUGIN`.
/// In this section we define the API used by core TensorFlow to initialize a
/// filesystem provided by a plugin. That is, we define the following:
/// * `TF_InitPlugin` function: must be present in the plugin shared object as
/// it will be called by core TensorFlow when the filesystem plugin is
/// loaded;
/// * `TF_FilesystemPluginOps` struct: used to transfer information between
/// plugins and core TensorFlow about the operations provided and metadata;
/// * `TF_FilesystemPluginInfo` struct: similar to the above structure, but
/// collects information about all the file schemes that the plugin provides
/// support for, as well as about the plugin's memory handling routines;
/// * `TF_SetFilesystemVersionMetadata` function: must be called by plugins in
/// their `TF_InitPlugin` to record the versioning information the plugins
/// are compiled against.
///
/// The `TF_InitPlugin` function is used by plugins to set up the data
/// structures that implement this interface, as presented in Section 2.
///
/// The `RegisterFilesystemPlugin` is used by core TensorFlow to check that
/// plugins satisfy the requirements expected by core TensorFlow, as follows:
/// 1. If ABI numbers don't match we don't load the plugin, else we continue.
/// 2. If the API numbers are mismatched, we warn the user and continue
/// loading the plugin.
/// 3. If any required operation is missing, we stop loading the plugin.
///
/// If all these checks succeed, we copy the plugin operations to a different
/// memory location so that core TensorFlow has the guarantee that they won't be
/// changed by plugins at a later time. Finally, we initialize the opaque
/// pointer of `TF_Filesystem` by calling the required `init` function of
/// `TF_FilesystemOps` and if that succeeds we register the filesystem.
/// structures that implement this interface, as presented in Section 2. In
/// order to not have plugin shared objects call back symbols defined in core
/// TensorFlow, `TF_InitPlugin` has a `TF_FilesystemPluginInfo` argument which
/// the plugin must fill (using the `TF_SetFilesystemVersionMetadata` for the
/// metadata and setting up all the supported operations and the URI schemes
/// that are supported).
// Initializes a TensorFlow plugin.
//
// Must be implemented by the plugin DSO. It is called by TensorFlow runtime.
//
// Filesystem plugins can be loaded on demand by users via
// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
// paths (although this has a security risk if two plugins register for the
// same filesystem and the malicious one loads before the legitimate one -
// but we consider this to be something that users should care about and
// manage themselves). In both of these cases, core TensorFlow looks for
// the `TF_InitPlugin` symbol and calls that function.
//
// A plugin is loaded only if this `status` is `TF_OK` after the call.
TF_CAPI_EXPORT extern void TF_InitPlugin(TF_Status* status);
/// This structure incorporates the operations defined in Section 2 and the
/// metadata defined in section 3, allowing plugins to define different ops
/// for different URI schemes.
///
/// Every URI scheme is of the form "fs" for URIs of form "fs:///path/to/file".
/// For local filesystems (i.e., when the URI is "/path/to/file"), the scheme
/// must be "". The scheme must never be `nullptr`.
///
/// Every plugin fills this in `TF_InitPlugin`, using the allocator passed as
/// argument to allocate memory. After `TF_InitPlugin` finishes, core
/// TensorFlow uses the information present in this to initialize filesystems
/// for the URI schemes that the plugin requests.
///
/// All pointers defined in this structure point to memory allocated by the DSO
/// using an allocator provided by core TensorFlow when calling `TF_InitPlugin`.
///
/// IMPORTANT: To maintain binary compatibility, the layout of this structure
/// must not change! In the unlikely case that a new type of file needs to be
/// supported, add the new ops and metadata at the end of the structure.
typedef struct TF_FilesystemPluginOps {
  // URI scheme these operations serve ("" for the local filesystem).
  char* scheme;
  // Versioning metadata (ABI/API numbers, struct size) and operation table
  // for the filesystem itself.
  int filesystem_ops_abi;
  int filesystem_ops_api;
  size_t filesystem_ops_size;
  TF_FilesystemOps* filesystem_ops;
  // Same, for random access files (may be left unset if unsupported).
  int random_access_file_ops_abi;
  int random_access_file_ops_api;
  size_t random_access_file_ops_size;
  TF_RandomAccessFileOps* random_access_file_ops;
  // Same, for writable files (may be left unset if unsupported).
  int writable_file_ops_abi;
  int writable_file_ops_api;
  size_t writable_file_ops_size;
  TF_WritableFileOps* writable_file_ops;
  // Same, for read only memory regions (may be left unset if unsupported).
  int read_only_memory_region_ops_abi;
  int read_only_memory_region_ops_api;
  size_t read_only_memory_region_ops_size;
  TF_ReadOnlyMemoryRegionOps* read_only_memory_region_ops;
} TF_FilesystemPluginOps;
/// Registers a filesystem plugin so that core TensorFlow can use it.
/// This structure gathers together all the operations provided by the plugin.
///
/// Must be called by the plugin during `TF_InitPlugin`, usually by using the
/// convenience `TF_REGISTER_FILESYSTEM_PLUGIN` macro.
/// Plugins must provide exactly `num_schemes` elements in the `ops` array.
///
/// Arguments (grouped by category):
/// * `..ABI`: ABI compatibility numbers (see Section 3.).
/// * `..API`: API compatibility numbers (see Section 3.).
/// * `..Size`: Sizes of the operation tables (see Section 3.).
/// * `scheme`: The URI scheme that plugin is registering filesystems for.
/// Must be of the form "fs" for URIs of form "fs:///path/to/file". For
/// local filesystems (i.e., when the URI is "/path/to/file"), `scheme`
/// must be "". Must never be `nullptr`.
/// * `..Ops`: The function tables provided by the plugin. Owned by the
/// plugin, but core TensorFlow makes a copy of these.
/// * `status`: The output variable for representing success/failure.
/// Since memory that is allocated by the DSO gets transferred to core
/// TensorFlow, we need to provide a way for the allocation and deallocation to
/// match. This is why this structure also defines `plugin_memory_allocate` and
/// `plugin_memory_free` members.
///
/// Sets `status` to `TF_OK` if plugin was registered and filesystem operations
/// can be invoked from anywhere during TensorFlow's runtime. Any other value of
/// `status` means that plugin failed to load properly and as such the
/// operations it provides cannot be used at all (i.e., core TensorFlow will
/// never run them, returning early with `TF_UNIMPLEMENTED` or similar error
/// values).
TF_CAPI_EXPORT extern void RegisterFilesystemPlugin(
int pluginFilesystemOpsABI, int pluginFilesystemOpsAPI,
size_t pluginFilesystemOpsSize, int pluginRandomAccessFileOpsABI,
int pluginRandomAccessFileOpsAPI, size_t pluginRandomAccessFileOpsSize,
int pluginWritableFileOpsABI, int pluginWritableFileOpsAPI,
size_t pluginWritableFileOpsSize, int pluginReadOnlyMemoryRegionOpsABI,
int pluginReadOnlyMemoryRegionOpsAPI,
size_t pluginReadOnlyMemoryRegionOpsSize, const char* scheme,
const TF_FilesystemOps* pluginFilesystemOps,
const TF_RandomAccessFileOps* pluginRandomAccessFileOps,
const TF_WritableFileOps* pluginWritableFileOps,
const TF_ReadOnlyMemoryRegionOps* pluginReadOnlyMemoryRegionOps,
TF_Status* status);
/// All memory allocated by the plugin that will be owned by core TensorFlow
/// must be allocated using the allocator in this structure. Core TensorFlow
/// will use the deallocator to free this memory once it no longer needs it.
///
/// IMPORTANT: To maintain binary compatibility, the layout of this structure
/// must not change! In the unlikely case that new global operations must be
/// provided, add them at the end of the structure.
typedef struct TF_FilesystemPluginInfo {
  // Number of entries in the `ops` array below.
  size_t num_schemes;
  // One entry per URI scheme the plugin supports.
  TF_FilesystemPluginOps* ops;
  // Allocator for any memory the plugin hands over to core TensorFlow.
  void* (*plugin_memory_allocate)(size_t size);
  // Deallocator core TensorFlow uses to free that memory when done.
  void (*plugin_memory_free)(void* ptr);
} TF_FilesystemPluginInfo;
/// This macro is just a convenience wrapper around `RegisterFilesystemPlugin`.
/// Plugins should prefer using this macro instead of a direct call.
#define TF_REGISTER_FILESYSTEM_PLUGIN( \
scheme, pluginFilesystemOps, pluginRandomAccessFileOps, \
pluginWritableFileOps, pluginReadOnlyMemoryRegionOps, status) \
RegisterFilesystemPlugin( \
TF_FILESYSTEM_OPS_ABI, TF_FILESYSTEM_OPS_API, TF_FILESYSTEM_OPS_SIZE, \
TF_RANDOM_ACCESS_FILE_OPS_ABI, TF_RANDOM_ACCESS_FILE_OPS_API, \
TF_RANDOM_ACCESS_FILE_OPS_SIZE, TF_WRITABLE_FILE_OPS_ABI, \
TF_WRITABLE_FILE_OPS_API, TF_WRITABLE_FILE_OPS_SIZE, \
TF_READ_ONLY_MEMORY_REGION_OPS_ABI, TF_READ_ONLY_MEMORY_REGION_OPS_API, \
TF_READ_ONLY_MEMORY_REGION_OPS_SIZE, scheme, pluginFilesystemOps, \
pluginRandomAccessFileOps, pluginWritableFileOps, \
pluginReadOnlyMemoryRegionOps, status)
/// Convenience function for setting the versioning metadata.
///
/// The argument is guaranteed to not be `nullptr`.
///
/// We want this to be defined in the plugin's memory space and we guarantee
/// that core TensorFlow will never call this.
static inline void TF_SetFilesystemVersionMetadata(
    TF_FilesystemPluginOps* ops) {
  // Record, for each of the four operation tables, the ABI/API numbers and
  // struct sizes the plugin was compiled against, so core TensorFlow can
  // check compatibility at registration time.
  ops->filesystem_ops_abi = TF_FILESYSTEM_OPS_ABI;
  ops->filesystem_ops_api = TF_FILESYSTEM_OPS_API;
  ops->filesystem_ops_size = TF_FILESYSTEM_OPS_SIZE;
  ops->random_access_file_ops_abi = TF_RANDOM_ACCESS_FILE_OPS_ABI;
  ops->random_access_file_ops_api = TF_RANDOM_ACCESS_FILE_OPS_API;
  ops->random_access_file_ops_size = TF_RANDOM_ACCESS_FILE_OPS_SIZE;
  ops->writable_file_ops_abi = TF_WRITABLE_FILE_OPS_ABI;
  ops->writable_file_ops_api = TF_WRITABLE_FILE_OPS_API;
  ops->writable_file_ops_size = TF_WRITABLE_FILE_OPS_SIZE;
  ops->read_only_memory_region_ops_abi = TF_READ_ONLY_MEMORY_REGION_OPS_ABI;
  ops->read_only_memory_region_ops_api = TF_READ_ONLY_MEMORY_REGION_OPS_API;
  ops->read_only_memory_region_ops_size = TF_READ_ONLY_MEMORY_REGION_OPS_SIZE;
}
/// Initializes a TensorFlow plugin.
///
/// Must be implemented by the plugin DSO. It is called by TensorFlow runtime.
///
/// Filesystem plugins can be loaded on demand by users via
/// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
/// paths (although this has a security risk if two plugins register for the
/// same filesystem and the malicious one loads before the legitimate one -
/// but we consider this to be something that users should care about and
/// manage themselves). In both of these cases, core TensorFlow looks for
/// the `TF_InitPlugin` symbol and calls this function.
///
/// For every filesystem URI scheme that this plugin supports, the plugin must
/// add one `TF_FilesystemPluginOps` entry in `plugin_info->ops` and call
/// `TF_SetFilesystemVersionMetadata` for that entry.
///
/// Plugins must also initialize `plugin_info->plugin_memory_allocate` and
/// `plugin_info->plugin_memory_free` to ensure memory allocated by plugin is
/// freed in a compatible way.
TF_CAPI_EXPORT extern void TF_InitPlugin(TF_FilesystemPluginInfo* plugin_info);
#ifdef __cplusplus
} // end extern "C"

View File

@ -18,11 +18,10 @@ limitations under the License.
#include <string>
#include <utility>
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system_helper.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/util/ptr_util.h"
// TODO(mihaimaruseac): After all filesystems are converted, all calls to
@ -165,16 +164,18 @@ Status ModularFileSystem::GetChildren(const std::string& dir,
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
std::string translated_name = TranslateName(dir);
char** children;
// Note that `children` is allocated by the plugin and freed by core
// TensorFlow, so we need to use `plugin_memory_free_` here.
char** children = nullptr;
const int num_children =
ops_->get_children(filesystem_.get(), translated_name.c_str(), &children,
plugin_status.get());
if (num_children >= 0) {
for (int i = 0; i < num_children; i++) {
result->push_back(std::string(children[i]));
free(children[i]);
plugin_memory_free_(children[i]);
}
free(children);
plugin_memory_free_(children);
}
return StatusFromTF_Status(plugin_status.get());
@ -186,15 +187,17 @@ Status ModularFileSystem::GetMatchingPaths(const std::string& pattern,
return internal::GetMatchingPaths(this, Env::Default(), pattern, result);
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
char** matches;
// Note that `matches` is allocated by the plugin and freed by core
// TensorFlow, so we need to use `plugin_memory_free_` here.
char** matches = nullptr;
const int num_matches = ops_->get_matching_paths(
filesystem_.get(), pattern.c_str(), &matches, plugin_status.get());
if (num_matches >= 0) {
for (int i = 0; i < num_matches; i++) {
result->push_back(std::string(matches[i]));
free(matches[i]);
plugin_memory_free_(matches[i]);
}
free(matches);
plugin_memory_free_(matches);
}
return StatusFromTF_Status(plugin_status.get());
@ -215,9 +218,24 @@ Status ModularFileSystem::DeleteFile(const std::string& fname) {
Status ModularFileSystem::DeleteRecursively(const std::string& dirname,
int64* undeleted_files,
int64* undeleted_dirs) {
// TODO(mihaimaruseac): Implementation to come in a new change
return Status(error::UNIMPLEMENTED,
"Modular filesystem stub not implemented yet");
if (undeleted_files == nullptr || undeleted_dirs == nullptr)
return errors::FailedPrecondition(
"DeleteRecursively must not be called with `undeleted_files` or "
"`undeleted_dirs` set to NULL");
if (ops_->delete_recursively == nullptr)
return FileSystem::DeleteRecursively(dirname, undeleted_files,
undeleted_dirs);
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
std::string translated_name = TranslateName(dirname);
uint64_t plugin_undeleted_files, plugin_undeleted_dirs;
ops_->delete_recursively(filesystem_.get(), translated_name.c_str(),
&plugin_undeleted_files, &plugin_undeleted_dirs,
plugin_status.get());
*undeleted_files = plugin_undeleted_files;
*undeleted_dirs = plugin_undeleted_dirs;
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::DeleteDir(const std::string& dirname) {
@ -233,9 +251,14 @@ Status ModularFileSystem::DeleteDir(const std::string& dirname) {
}
Status ModularFileSystem::RecursivelyCreateDir(const std::string& dirname) {
// TODO(mihaimaruseac): Implementation to come in a new change
return Status(error::UNIMPLEMENTED,
"Modular filesystem stub not implemented yet");
if (ops_->recursively_create_dir == nullptr)
return FileSystem::RecursivelyCreateDir(dirname);
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
std::string translated_name = TranslateName(dirname);
ops_->recursively_create_dir(filesystem_.get(), translated_name.c_str(),
plugin_status.get());
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::CreateDir(const std::string& dirname) {
@ -324,8 +347,8 @@ Status ModularFileSystem::CopyFile(const std::string& src,
if (ops_->copy_file == nullptr) return FileSystem::CopyFile(src, target);
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
const std::string& translated_src = TranslateName(src);
const std::string& translated_target = TranslateName(target);
std::string translated_src = TranslateName(src);
std::string translated_target = TranslateName(target);
ops_->copy_file(filesystem_.get(), translated_src.c_str(),
translated_target.c_str(), plugin_status.get());
return StatusFromTF_Status(plugin_status.get());
@ -338,7 +361,8 @@ std::string ModularFileSystem::TranslateName(const std::string& name) const {
CHECK(p != nullptr) << "TranslateName(" << name << ") returned nullptr";
std::string ret(p);
free(p);
// Since `p` is allocated by plugin, free it using plugin's method.
plugin_memory_free_(p);
return ret;
}
@ -415,4 +439,26 @@ Status ModularWritableFile::Tell(int64* position) {
return StatusFromTF_Status(plugin_status.get());
}
// Loads a filesystem plugin DSO and registers the filesystems it provides.
//
// Returns an error if the DSO cannot be loaded, if it does not export the
// `TF_InitPlugin` symbol, or if any registration step fails.
Status RegisterFilesystemPlugin(const std::string& dso_path) {
  // Step 1: Load plugin
  Env* env = Env::Default();
  void* dso_handle;
  TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle));

  // Step 2: Load symbol for `TF_InitPlugin`
  void* dso_symbol;
  TF_RETURN_IF_ERROR(
      env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol));

  // Step 3: Call `TF_InitPlugin`
  TF_FilesystemPluginInfo info;
  memset(&info, 0, sizeof(info));
  // The plugin entry point is declared as `void TF_InitPlugin(
  // TF_FilesystemPluginInfo*)` (see posix_filesystem.h); cast to the matching
  // type, as calling through an incompatible function pointer type is
  // undefined behavior.
  auto TF_InitPlugin =
      reinterpret_cast<void (*)(TF_FilesystemPluginInfo*)>(dso_symbol);
  TF_InitPlugin(&info);

  // Step 4: Do the actual registration
  return filesystem_registration::RegisterFilesystemPluginImpl(&info);
}
} // namespace tensorflow

View File

@ -32,7 +32,7 @@ namespace tensorflow {
// TODO(b/143949615): After all filesystems are converted, this file will be
// moved to core/platform, and this class can become a singleton and replace the
// need for `Env::Default()`. At that time, we might decide to remove the need
// for `Env::Default()` altoghether, but that's a different project, not in
// for `Env::Default()` altogether, but that's a different project, not in
// scope for now. I'm just mentioning this here as that transition will mean
// removal of the registration part from `Env` and adding it here instead: we
// will need tables to hold for each scheme the function tables that implement
@ -46,12 +46,16 @@ class ModularFileSystem final : public FileSystem {
std::unique_ptr<const TF_RandomAccessFileOps> random_access_file_ops,
std::unique_ptr<const TF_WritableFileOps> writable_file_ops,
std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
read_only_memory_region_ops)
read_only_memory_region_ops,
std::function<void*(size_t)> plugin_memory_allocate,
std::function<void(void*)> plugin_memory_free)
: filesystem_(std::move(filesystem)),
ops_(std::move(filesystem_ops)),
random_access_file_ops_(std::move(random_access_file_ops)),
writable_file_ops_(std::move(writable_file_ops)),
read_only_memory_region_ops_(std::move(read_only_memory_region_ops)) {}
read_only_memory_region_ops_(std::move(read_only_memory_region_ops)),
plugin_memory_allocate_(std::move(plugin_memory_allocate)),
plugin_memory_free_(std::move(plugin_memory_free)) {}
~ModularFileSystem() override { ops_->cleanup(filesystem_.get()); }
@ -93,6 +97,8 @@ class ModularFileSystem final : public FileSystem {
std::unique_ptr<const TF_WritableFileOps> writable_file_ops_;
std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
read_only_memory_region_ops_;
std::function<void*(size_t)> plugin_memory_allocate_;
std::function<void(void*)> plugin_memory_free_;
TF_DISALLOW_COPY_AND_ASSIGN(ModularFileSystem);
};
@ -156,6 +162,9 @@ class ModularReadOnlyMemoryRegion final : public ReadOnlyMemoryRegion {
TF_DISALLOW_COPY_AND_ASSIGN(ModularReadOnlyMemoryRegion);
};
// Registers a filesystem plugin so that core TensorFlow can use it.
Status RegisterFilesystemPlugin(const std::string& dso_path);
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_H_

View File

@ -0,0 +1,327 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
// Verifies that a URI scheme supplied by a plugin is usable.
//
// TODO(mihaimaruseac): More validation could be done here, based on supported
// charset, maximum length, etc. Punting it for later.
static Status ValidateScheme(const char* scheme) {
  if (scheme != nullptr) return Status::OK();
  return errors::InvalidArgument(
      "Attempted to register filesystem with `nullptr` URI scheme");
}
// Checks if the plugin and core ABI numbers match.
//
// If the numbers don't match, plugin cannot be loaded.
//
// `where` names the function table being checked, for the error message.
static Status CheckABI(int pluginABI, int coreABI, StringPiece where) {
  // `errors::FailedPrecondition` already concatenates its arguments, so the
  // extra `strings::StrCat` wrapper is unnecessary; pass the pieces directly.
  if (pluginABI != coreABI)
    return errors::FailedPrecondition(
        "Plugin ABI (", pluginABI, ") for ", where,
        " operations doesn't match expected core ABI (", coreABI,
        "). Plugin cannot be loaded.");
  return Status::OK();
}
// Checks if the plugin and core ABI numbers match, for all operations.
//
// If the numbers don't match, plugin cannot be loaded.
//
// Uses the simpler `CheckABI(int, int, StringPiece)`.
static Status ValidateABI(const TF_FilesystemPluginOps* ops) {
  // The filesystem table is mandatory, so its ABI is checked unconditionally.
  TF_RETURN_IF_ERROR(
      CheckABI(ops->filesystem_ops_abi, TF_FILESYSTEM_OPS_ABI, "filesystem"));
  // The per-file-type tables are optional; only validate the ones supplied.
  if (ops->random_access_file_ops != nullptr)
    TF_RETURN_IF_ERROR(CheckABI(ops->random_access_file_ops_abi,
                                TF_RANDOM_ACCESS_FILE_OPS_ABI,
                                "random access file"));
  if (ops->writable_file_ops != nullptr)
    TF_RETURN_IF_ERROR(CheckABI(ops->writable_file_ops_abi,
                                TF_WRITABLE_FILE_OPS_ABI, "writable file"));
  if (ops->read_only_memory_region_ops != nullptr)
    TF_RETURN_IF_ERROR(CheckABI(ops->read_only_memory_region_ops_abi,
                                TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
                                "read only memory region"));
  return Status::OK();
}
// Checks if the plugin and core API numbers match, logging mismatches.
//
// Unlike an ABI mismatch, an API mismatch is not fatal: the plugin is still
// loaded, only a warning is emitted.
static void CheckAPI(int plugin_API, int core_API, StringPiece where) {
  if (plugin_API == core_API) return;
  VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
          << " operations doesn't match expected core API (" << core_API
          << "). Plugin will be loaded but functionality might be missing.";
}
// Checks if the plugin and core API numbers match, for all operations.
//
// Uses the simpler `CheckAPI(int, int, StringPiece)`. Mismatches only produce
// a warning; the plugin is loaded regardless.
static void ValidateAPI(const TF_FilesystemPluginOps* ops) {
  CheckAPI(ops->filesystem_ops_api, TF_FILESYSTEM_OPS_API, "filesystem");
  // Optional tables are only checked when the plugin actually supplies them.
  if (ops->random_access_file_ops != nullptr)
    CheckAPI(ops->random_access_file_ops_api, TF_RANDOM_ACCESS_FILE_OPS_API,
             "random access file");
  if (ops->writable_file_ops != nullptr)
    CheckAPI(ops->writable_file_ops_api, TF_WRITABLE_FILE_OPS_API,
             "writable file");
  if (ops->read_only_memory_region_ops != nullptr)
    CheckAPI(ops->read_only_memory_region_ops_api,
             TF_READ_ONLY_MEMORY_REGION_OPS_API, "read only memory region");
}
// Validates the filesystem operations supplied by the plugin.
//
// The filesystem table is mandatory and must provide at least the `init` and
// `cleanup` lifecycle hooks.
static Status ValidateHelper(const TF_FilesystemOps* ops) {
  if (ops == nullptr) {
    return errors::FailedPrecondition(
        "Trying to register filesystem without operations");
  }
  if (ops->init == nullptr) {
    return errors::FailedPrecondition(
        "Trying to register filesystem without `init` operation");
  }
  if (ops->cleanup == nullptr) {
    return errors::FailedPrecondition(
        "Trying to register filesystem without `cleanup` operation");
  }
  return Status::OK();
}
// Validates the random access file operations supplied by the plugin.
//
// A missing table is allowed: filesystems may only support writing from the
// TF side. When the table is present, `cleanup` must be provided.
static Status ValidateHelper(const TF_RandomAccessFileOps* ops) {
  if (ops == nullptr) return Status::OK();
  if (ops->cleanup != nullptr) return Status::OK();
  return errors::FailedPrecondition(
      "Trying to register filesystem without `cleanup` operation on random "
      "access files");
}
// Validates the writable file operations supplied by the plugin.
//
// A missing table is allowed: read-only filesystems are permitted. When the
// table is present, `cleanup` must be provided.
static Status ValidateHelper(const TF_WritableFileOps* ops) {
  if (ops == nullptr) return Status::OK();
  if (ops->cleanup != nullptr) return Status::OK();
  return errors::FailedPrecondition(
      "Trying to register filesystem without `cleanup` operation on writable "
      "files");
}
// Validates the read only memory region operations given by the plugin.
//
// Read only memory region support is always optional, so a null table is
// accepted. A supplied table, however, must be complete: `cleanup`, `data`
// and `length` are all required.
static Status ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops) {
  if (ops == nullptr) return Status::OK();
  if (ops->cleanup == nullptr) {
    return errors::FailedPrecondition(
        "Trying to register filesystem without `cleanup` operation on read "
        "only memory regions");
  }
  if (ops->data == nullptr) {
    return errors::FailedPrecondition(
        "Trying to register filesystem without `data` operation on read only "
        "memory regions");
  }
  if (ops->length == nullptr) {
    return errors::FailedPrecondition(
        "Trying to register filesystem without `length` operation on read only "
        "memory regions");
  }
  return Status::OK();
}
// Validates the operations supplied by the plugin.
//
// Uses the 4 simpler `ValidateHelper(const TF_...*)` overloads to validate
// each individual function table and then checks that the function table for
// a specific file type exists if the plugin offers support for creating that
// type of files.
static Status ValidateOperations(const TF_FilesystemPluginOps* ops) {
  TF_RETURN_IF_ERROR(ValidateHelper(ops->filesystem_ops));
  TF_RETURN_IF_ERROR(ValidateHelper(ops->random_access_file_ops));
  TF_RETURN_IF_ERROR(ValidateHelper(ops->writable_file_ops));
  TF_RETURN_IF_ERROR(ValidateHelper(ops->read_only_memory_region_ops));
  // `filesystem_ops` is guaranteed non-null by the first check above, so the
  // dereferences below are safe.
  if (ops->filesystem_ops->new_random_access_file != nullptr &&
      ops->random_access_file_ops == nullptr)
    return errors::FailedPrecondition(
        "Filesystem allows creation of random access files but no "
        "operations on them have been supplied.");
  if ((ops->filesystem_ops->new_writable_file != nullptr ||
       ops->filesystem_ops->new_appendable_file != nullptr) &&
      ops->writable_file_ops == nullptr)
    return errors::FailedPrecondition(
        "Filesystem allows creation of writable files but no "
        "operations on them have been supplied.");
  if (ops->filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
      ops->read_only_memory_region_ops == nullptr)
    return errors::FailedPrecondition(
        "Filesystem allows creation of readonly memory regions but no "
        "operations on them have been supplied.");
  return Status::OK();
}
// Copies a function table from plugin memory space to core memory space.
//
// Why copy instead of using the plugin's table directly:
// * newer plugins work with older cores: extra entries in the plugin's table
//   are simply discarded;
// * older plugins work with newer cores (with a warning): entries the core
//   expects but the plugin didn't supply remain `nullptr` after the zero-fill
//   below, and core TensorFlow knows not to call them on behalf of users;
// * better security: the plugin cannot mutate the table after loading, so it
//   cannot redirect core TensorFlow to arbitrary code. The memory holding the
//   copies could even be write-protected once all copies are made.
template <typename T>
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
                                           size_t plugin_size) {
  if (plugin_ops == nullptr) return nullptr;

  auto core_ops = tensorflow::MakeUnique<T>();
  // Zero-fill so entries the plugin didn't provide read as `nullptr`.
  memset(core_ops.get(), 0, sizeof(T));
  // Copy only as much as both sides agree on; anything extra is discarded.
  memcpy(core_ops.get(), plugin_ops, std::min(plugin_size, sizeof(T)));
  return core_ops;
}
// Registers one filesystem from the plugin.
//
// Must be called only with `index` a valid index in `info->ops`.
static Status RegisterFileSystem(const TF_FilesystemPluginInfo* info,
                                 int index) {
  // Step 1: Copy all the function tables to core TensorFlow memory space, so
  // that the plugin can no longer change them (see `CopyToCore` above). Tables
  // the plugin did not supply come back as null unique_ptrs.
  auto core_filesystem_ops = CopyToCore<TF_FilesystemOps>(
      info->ops[index].filesystem_ops, info->ops[index].filesystem_ops_size);
  auto core_random_access_file_ops = CopyToCore<TF_RandomAccessFileOps>(
      info->ops[index].random_access_file_ops,
      info->ops[index].random_access_file_ops_size);
  auto core_writable_file_ops =
      CopyToCore<TF_WritableFileOps>(info->ops[index].writable_file_ops,
                                     info->ops[index].writable_file_ops_size);
  auto core_read_only_memory_region_ops =
      CopyToCore<TF_ReadOnlyMemoryRegionOps>(
          info->ops[index].read_only_memory_region_ops,
          info->ops[index].read_only_memory_region_ops_size);

  // Step 2: Initialize the opaque filesystem structure. The raw `TF_Status`
  // is converted and deleted before the early return so it cannot leak.
  auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
  TF_Status* c_status = TF_NewStatus();
  Status status = Status::OK();
  core_filesystem_ops->init(filesystem.get(), c_status);
  status = Status(c_status->status);
  TF_DeleteStatus(c_status);
  if (!status.ok()) return status;

  // Step 3: Actual registration. Ownership of the filesystem object and of
  // all copied function tables moves into the `ModularFileSystem` held by
  // `Env`.
  return Env::Default()->RegisterFileSystem(
      info->ops[index].scheme,
      tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
          std::move(filesystem), std::move(core_filesystem_ops),
          std::move(core_random_access_file_ops),
          std::move(core_writable_file_ops),
          std::move(core_read_only_memory_region_ops),
          info->plugin_memory_allocate, info->plugin_memory_free));
}
// Registers filesystem at `index`, if plugin is providing valid information.
//
// Extracted to a separate function so that pointers inside `info` are freed
// by the caller regardless of whether validation/registration failed or not.
//
// Must be called only with `index` a valid index in `info->ops`.
static Status ValidateAndRegisterFilesystems(
    const TF_FilesystemPluginInfo* info, int index) {
  const auto& ops = info->ops[index];
  TF_RETURN_IF_ERROR(ValidateScheme(ops.scheme));
  TF_RETURN_IF_ERROR(ValidateABI(&ops));
  ValidateAPI(&ops);  // API mismatches only produce a warning
  TF_RETURN_IF_ERROR(ValidateOperations(&ops));
  return RegisterFileSystem(info, index);
}
// Ensures that the plugin provides the required memory management operations.
//
// Core TensorFlow releases plugin-allocated buffers (schemes, ops tables), so
// both routines must be present before any registration is attempted.
static Status ValidatePluginMemoryRoutines(
    const TF_FilesystemPluginInfo* info) {
  if (info->plugin_memory_allocate == nullptr) {
    return errors::FailedPrecondition(
        "Cannot load filesystem plugin which does not provide "
        "`plugin_memory_allocate`");
  }
  if (info->plugin_memory_free == nullptr) {
    return errors::FailedPrecondition(
        "Cannot load filesystem plugin which does not provide "
        "`plugin_memory_free`");
  }
  return Status::OK();
}
namespace filesystem_registration {
// Validates a plugin's information and registers every filesystem it offers.
//
// Registration is best-effort: every scheme is attempted even if an earlier
// one failed, with errors aggregated via `Status::Update`.
Status RegisterFilesystemPluginImpl(const TF_FilesystemPluginInfo* info) {
  // The memory routines are needed below to free plugin-owned buffers, so
  // their absence is fatal.
  TF_RETURN_IF_ERROR(ValidatePluginMemoryRoutines(info));

  // Validate and register all filesystems
  // Try to register as many filesystems as possible.
  // Free memory once we no longer need it. This is safe even on registration
  // success because `RegisterFileSystem` copies the tables it keeps into core
  // memory (see `CopyToCore`).
  Status status;
  for (int i = 0; i < info->num_schemes; i++) {
    status.Update(ValidateAndRegisterFilesystems(info, i));
    info->plugin_memory_free(info->ops[i].scheme);
    info->plugin_memory_free(info->ops[i].filesystem_ops);
    info->plugin_memory_free(info->ops[i].random_access_file_ops);
    info->plugin_memory_free(info->ops[i].writable_file_ops);
    info->plugin_memory_free(info->ops[i].read_only_memory_region_ops);
  }
  info->plugin_memory_free(info->ops);
  return status;
}
} // namespace filesystem_registration
} // namespace tensorflow

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
namespace filesystem_registration {
// Implementation for filesystem registration
//
// Don't call this directly. Instead call `RegisterFilesystemPlugin`.
// Exposed only for static registration of local filesystems.
Status RegisterFilesystemPluginImpl(const TF_FilesystemPluginInfo* info);
} // namespace filesystem_registration
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_

View File

@ -1,35 +1,62 @@
# Experimental posix filesystem plugin.
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
package(
default_visibility = ["//visibility:private"],
licenses = ["notice"], # Apache 2.0
)
# Although this target results in a shared object that will be loaded at
# runtime, this target must be a `cc_library` instead of a `cc_binary`. Making
# it a `cc_binary` requires `linkshared = True`. In turn, this brings in several
# TensorFlow symbols under `tensorflow::` namespace, for which we have no ABI
# guarantees. Hence, in order to maintain ABI compatibility, this is marked as a
# `cc_library` for now and we will revisit in the future.
# TODO(mihaimaruseac): Determine if `cc_binary` makes more sense (when all
# filesystems are converted and BUILD files are refactored to be modular).
# TODO(b/144585140): The helpers should be separated into a different BUILD target
# but doing that would result in symbols not being visible when loading plugin.
# Revisit this once POSIX filesystem completely lands. See also the other TODO.
# This also has the unfortunate effect that both versions of copy_file get
# compiled, regardless of which one actually gets used!
# Filesystem implementation for POSIX environments: Linux, MacOS, Android, etc.
tf_cc_shared_object(
name = "libposix_filesystem.so",
framework_so = [],
linkstatic = False,
visibility = ["//visibility:public"],
deps = [":posix_filesystem_impl"],
)
# The real implementation of the filesystem.
cc_library(
name = "posix_filesystem",
srcs = [
"posix_filesystem.cc",
"posix_filesystem_helper.cc",
"posix_filesystem_helper.h",
"copy_file.h",
] + select({
"//tensorflow:linux_x86_64": ["copy_file_linux.cc"],
"//conditions:default": ["copy_file_portable.cc"],
}),
name = "posix_filesystem_impl",
srcs = ["posix_filesystem.cc"],
hdrs = ["posix_filesystem.h"],
deps = [
":posix_filesystem_helper",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
],
)
# Since building pip package and API tests require a filesystem, we provide a
# static registration target that they should link against.
cc_library(
name = "posix_filesystem_static",
srcs = ["posix_filesystem_static.cc"],
visibility = ["//visibility:public"],
deps = [
":posix_filesystem_impl",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"//tensorflow/c/experimental/filesystem:modular_filesystem",
],
alwayslink = 1,
)
# Library implementing helper functionality, so that the above only contains
# the API implementation for modular filesystems.
cc_library(
name = "posix_filesystem_helper",
srcs = ["posix_filesystem_helper.cc"],
hdrs = ["posix_filesystem_helper.h"],
deps = [":copy_file"],
)
# On Linux, we can copy files faster using `sendfile`. But not elsewhere.
# Hence, this private library to select which implementation to use.
cc_library(
name = "copy_file",
srcs = select({
"//tensorflow:linux_x86_64": ["copy_file_linux.cc"],
"//conditions:default": ["copy_file_portable.cc"],
}),
hdrs = ["copy_file.h"],
)

View File

@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h"
#include <dirent.h>
#include <errno.h>
#include <fcntl.h>
@ -24,15 +26,15 @@ limitations under the License.
#include <sys/stat.h>
#include <unistd.h>
#include <vector>
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for POSIX environments.
// This filesystem will support `file://` and empty (local) URI schemes.
// Allocator for buffers this plugin hands to core TensorFlow. `calloc` is
// used so the returned memory is zero-initialized, giving partially-filled
// structures deterministic contents.
static void* plugin_memory_allocate(size_t size) {
  return calloc(1, size);
}

// Deallocator matching `plugin_memory_allocate`; core TensorFlow calls this
// to release plugin-allocated buffers.
static void plugin_memory_free(void* ptr) {
  free(ptr);
}
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
@ -45,7 +47,9 @@ typedef struct PosixFile {
static void Cleanup(TF_RandomAccessFile* file) {
auto posix_file = static_cast<PosixFile*>(file->plugin_file);
close(posix_file->fd);
free(const_cast<char*>(posix_file->filename));
// This would be safe to free using `free` directly as it is only opaque.
// However, it is better to be consistent everywhere.
plugin_memory_free(const_cast<char*>(posix_file->filename));
delete posix_file;
}
@ -100,7 +104,7 @@ typedef struct PosixFile {
static void Cleanup(TF_WritableFile* file) {
auto posix_file = static_cast<PosixFile*>(file->plugin_file);
free(const_cast<char*>(posix_file->filename));
plugin_memory_free(const_cast<char*>(posix_file->filename));
delete posix_file;
}
@ -383,12 +387,13 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
if (num_entries < 0) {
TF_SetStatusFromIOError(status, errno, path);
} else {
*entries = static_cast<char**>(calloc(num_entries, sizeof((*entries)[0])));
*entries = static_cast<char**>(
plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
for (int i = 0; i < num_entries; i++) {
(*entries)[i] = strdup(dir_entries[i]->d_name);
free(dir_entries[i]);
plugin_memory_free(dir_entries[i]);
}
free(dir_entries);
plugin_memory_free(dir_entries);
}
return num_entries;
@ -396,48 +401,59 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
} // namespace tf_posix_filesystem
void TF_InitPlugin(TF_Status* status) {
TF_RandomAccessFileOps random_access_file_ops = {
tf_random_access_file::Cleanup,
tf_random_access_file::Read,
};
TF_WritableFileOps writable_file_ops = {
tf_writable_file::Cleanup, tf_writable_file::Append,
tf_writable_file::Tell, tf_writable_file::Flush,
tf_writable_file::Sync, tf_writable_file::Close,
};
TF_ReadOnlyMemoryRegionOps read_only_memory_region_ops = {
tf_read_only_memory_region::Cleanup,
tf_read_only_memory_region::Data,
tf_read_only_memory_region::Length,
};
TF_FilesystemOps filesystem_ops = {
tf_posix_filesystem::Init,
tf_posix_filesystem::Cleanup,
tf_posix_filesystem::NewRandomAccessFile,
tf_posix_filesystem::NewWritableFile,
tf_posix_filesystem::NewAppendableFile,
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile,
tf_posix_filesystem::CreateDir,
/*recursively_create_dir=*/nullptr,
tf_posix_filesystem::DeleteFile,
tf_posix_filesystem::DeleteDir,
/*delete_recursively=*/nullptr,
tf_posix_filesystem::RenameFile,
tf_posix_filesystem::CopyFile,
tf_posix_filesystem::PathExists,
/*paths_exist=*/nullptr,
tf_posix_filesystem::Stat,
/*is_directory=*/nullptr,
/*get_file_size=*/nullptr,
/*translate_name=*/nullptr,
tf_posix_filesystem::GetChildren,
/*get_matching_paths=*/nullptr,
/*flush_caches=*/nullptr,
};
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
for (const char* scheme : {"", "file"})
TF_REGISTER_FILESYSTEM_PLUGIN(scheme, &filesystem_ops,
&random_access_file_ops, &writable_file_ops,
&read_only_memory_region_ops, status);
ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
ops->random_access_file_ops->read = tf_random_access_file::Read;
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->writable_file_ops->append = tf_writable_file::Append;
ops->writable_file_ops->tell = tf_writable_file::Tell;
ops->writable_file_ops->flush = tf_writable_file::Flush;
ops->writable_file_ops->sync = tf_writable_file::Sync;
ops->writable_file_ops->close = tf_writable_file::Close;
ops->read_only_memory_region_ops = static_cast<TF_ReadOnlyMemoryRegionOps*>(
plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
ops->read_only_memory_region_ops->cleanup =
tf_read_only_memory_region::Cleanup;
ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data;
ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_posix_filesystem::Init;
ops->filesystem_ops->cleanup = tf_posix_filesystem::Cleanup;
ops->filesystem_ops->new_random_access_file =
tf_posix_filesystem::NewRandomAccessFile;
ops->filesystem_ops->new_writable_file = tf_posix_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file =
tf_posix_filesystem::NewAppendableFile;
ops->filesystem_ops->new_read_only_memory_region_from_file =
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile;
ops->filesystem_ops->create_dir = tf_posix_filesystem::CreateDir;
ops->filesystem_ops->delete_file = tf_posix_filesystem::DeleteFile;
ops->filesystem_ops->delete_dir = tf_posix_filesystem::DeleteDir;
ops->filesystem_ops->rename_file = tf_posix_filesystem::RenameFile;
ops->filesystem_ops->copy_file = tf_posix_filesystem::CopyFile;
ops->filesystem_ops->path_exists = tf_posix_filesystem::PathExists;
ops->filesystem_ops->stat = tf_posix_filesystem::Stat;
ops->filesystem_ops->get_children = tf_posix_filesystem::GetChildren;
}
// Plugin entry point: fills `info` with this plugin's memory routines and
// one ops table per supported scheme.
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
  info->plugin_memory_allocate = plugin_memory_allocate;
  info->plugin_memory_free = plugin_memory_free;
  // This filesystem serves the empty (local) scheme and `file`.
  info->num_schemes = 2;
  info->ops = static_cast<TF_FilesystemPluginOps*>(
      plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
  const char* schemes[] = {"", "file"};
  for (int i = 0; i < 2; i++)
    ProvideFilesystemSupportFor(&info->ops[i], schemes[i]);
}

View File

@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
// Initialize the POSIX filesystem.
//
// In general, the `TF_InitPlugin` symbol doesn't need to be exposed in a header
// file, since the plugin registration will look for the symbol in the DSO file
// that provides the filesystem functionality. However, the POSIX filesystem
// needs to be statically registered in some tests and utilities for building
// the API files at the time of creating the pip package. Hence, we need to
// expose this function so that this filesystem can be statically registered
// when needed.
void TF_InitPlugin(TF_FilesystemPluginInfo* info);
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_

View File

@ -44,7 +44,7 @@ int TransferFileContents(const char* src, const char* dst, mode_t mode,
}
// Both files have been opened, do the transfer.
// Since errno would be overriden by `close` below, save it here.
// Since errno would be overridden by `close` below, save it here.
int error_code = 0;
if (CopyFileContents(dst_fd, src_fd, size) < 0) error_code = errno;

View File

@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h"
namespace tensorflow {
// Registers the POSIX filesystems statically.
//
// The boolean result exists only so registration can be forced at static
// initialization time; no caller inspects it.
bool StaticallyRegisterLocalFilesystems() {
  TF_FilesystemPluginInfo info;
  TF_InitPlugin(&info);
  Status status = filesystem_registration::RegisterFilesystemPluginImpl(&info);
  if (status.ok()) return true;
  VLOG(0) << "Static POSIX filesystem could not be registered: " << status;
  return false;
}

// Trigger registration when this translation unit is linked in.
static bool unused = StaticallyRegisterLocalFilesystems();
} // namespace tensorflow

View File

@ -0,0 +1,36 @@
# Experimental windows filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")

package(
    licenses = ["notice"],  # Apache 2.0
)

# Filesystem implementation for Windows environment
# The DSO target below packages the implementation library into a shared
# object that the modular filesystem layer can load at runtime.
# NOTE(review): tagged "manual"/"nobuilder"/"notap" so it is only built on
# demand and excluded from automated test runs — confirm against CI config.
tf_cc_shared_object(
    name = "windows_filesystem.dll",
    framework_so = [],
    linkstatic = False,
    tags = [
        "manual",
        "nobuilder",
        "notap",
    ],
    visibility = ["//visibility:public"],
    deps = [":windows_filesystem_impl"],
)

# The real implementation of the filesystem.
cc_library(
    name = "windows_filesystem_impl",
    srcs = ["windows_filesystem.cc"],
    copts = get_win_copts(),
    tags = [
        "manual",
        "nobuilder",
        "notap",
    ],
    deps = [
        "//tensorflow/c:tf_status",
        "//tensorflow/c/experimental/filesystem:filesystem_interface",
    ],
)

View File

@ -0,0 +1,73 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdlib.h>
#include <string.h>
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for POSIX environments.
// This filesystem will support `file://` and empty (local) URI schemes.
// Allocates `size` bytes of zero-initialized memory on behalf of the plugin.
// The core runtime releases this memory via `plugin_memory_free`.
static void* plugin_memory_allocate(size_t size) {
  void* buffer = calloc(1, size);
  return buffer;
}

// Releases memory previously obtained from `plugin_memory_allocate`.
static void plugin_memory_free(void* ptr) { free(ptr); }
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
// TODO(mihaimaruseac): Implement later
} // namespace tf_random_access_file
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
// TODO(mihaimaruseac): Implement later
} // namespace tf_writable_file
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
// ----------------------------------------------------------------------------
namespace tf_read_only_memory_region {
// TODO(mihaimaruseac): Implement later
} // namespace tf_read_only_memory_region
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_windows_filesystem {
// TODO(mihaimaruseac): Implement later
} // namespace tf_windows_filesystem
// Fills `ops` with the version metadata expected by the core runtime and
// records the URI scheme this plugin serves. The scheme string is duplicated
// with `strdup` because the core takes ownership of `ops->scheme` and frees
// it through the plugin's memory-management callbacks.
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
                                        const char* uri) {
  TF_SetFilesystemVersionMetadata(ops);
  ops->scheme = strdup(uri);
}
// Entry point called by the modular filesystem layer to discover what this
// plugin provides. Registers the allocator pair used for all cross-boundary
// allocations and declares support for two URI schemes: the empty scheme
// (local paths) and "file".
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
  info->plugin_memory_allocate = plugin_memory_allocate;
  info->plugin_memory_free = plugin_memory_free;
  info->num_schemes = 2;
  // The ops table is allocated with the plugin's own allocator so the core
  // can free it with `plugin_memory_free`.
  info->ops = static_cast<TF_FilesystemPluginOps*>(
      plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
  ProvideFilesystemSupportFor(&info->ops[0], "");
  ProvideFilesystemSupportFor(&info->ops[1], "file");
}

View File

@ -24,8 +24,8 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
using tensorflow::ServerFactory;

View File

@ -22,8 +22,8 @@ limitations under the License.
#include "tensorflow/c/experimental/rendezvous.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace tensorflow {

View File

@ -37,9 +37,9 @@ limitations under the License.
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

View File

@ -24,9 +24,9 @@ limitations under the License.
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
namespace tensorflow {

View File

@ -181,7 +181,8 @@ void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
return;
}
const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status);
TF_Tensor* result =
::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status);
if (TF_GetCode(status) == TF_OK) {
*tensor = result;
}

View File

@ -27,14 +27,10 @@ namespace {
class DummyDevice : public DeviceBase {
public:
DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
bool RequiresRecordingAccessedTensors() const override { return save_; }
explicit DummyDevice(Env* env) : DeviceBase(env) {}
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
return cpu_allocator();
}
private:
bool save_;
};
void TestBitcastOp(Tensor* input_tensor, DataType out_type,
@ -61,7 +57,7 @@ void TestBitcastOp(Tensor* input_tensor, DataType out_type,
ASSERT_TRUE(status.ok()) << status.ToString();
OpKernelContext::Params params;
DummyDevice dummy_device(nullptr, false);
DummyDevice dummy_device(nullptr);
params.device = &dummy_device;
params.op_kernel = kernel.get();
gtl::InlinedVector<TensorValue, 4> inputs;

View File

@ -18,19 +18,40 @@ limitations under the License.
#include "tensorflow/c/kernels.h"
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <memory>
#include <string>
#include <utility>
#include "absl/container/inlined_vector.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
struct MyCustomKernel {
bool created;
@ -134,14 +155,10 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
class DummyDevice : public DeviceBase {
public:
DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
bool RequiresRecordingAccessedTensors() const override { return save_; }
explicit DummyDevice(Env* env) : DeviceBase(env) {}
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
return cpu_allocator();
}
private:
bool save_;
};
TEST(TestKernel, TestInputAndOutputCount) {
@ -202,7 +219,7 @@ TEST(TestKernel, TestInputAndOutputCount) {
{
OpKernelContext::Params p;
DummyDevice dummy_device(nullptr, false);
DummyDevice dummy_device(nullptr);
p.device = &dummy_device;
p.step_id = 43;

View File

@ -133,7 +133,7 @@ TEST(OpsTest, TestShapeInference_VectorizeFunction) {
TEST(OpsTest, AttributeAccessors) {
TF_OpDefinitionBuilder* builder =
TF_NewOpDefinitionBuilder("AttributeAccesorsOp");
TF_NewOpDefinitionBuilder("AttributeAccessorsOp");
TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2");
TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\"");
TF_OpDefinitionBuilderSetIsCommutative(builder, true);
@ -151,7 +151,7 @@ TEST(OpsTest, AttributeAccessors) {
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
bool found = false;
for (const auto& op : op_list.op()) {
if (op.name() == "AttributeAccesorsOp") {
if (op.name() == "AttributeAccessorsOp") {
ASSERT_TRUE(op.is_commutative());
ASSERT_TRUE(op.is_aggregate());
ASSERT_TRUE(op.allows_uninitialized_input());

View File

@ -0,0 +1,59 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_TENSOR_INTERFACE_H_
#define TENSORFLOW_C_TENSOR_INTERFACE_H_
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {

// Abstract interface to a Tensor.
//
// This allows us to hide concrete implementations of Tensor from header
// files. The interface lists the common functionality that must be provided by
// any concrete implementation. However, in cases where the true concrete class
// is needed a static_cast can be applied.
class AbstractTensorInterface {
 public:
  // Release any underlying resources, including the interface object.
  virtual void Release() = 0;

  // Returns tensor dtype.
  virtual DataType Type() const = 0;
  // Returns number of dimensions.
  virtual int NumDims() const = 0;
  // Returns size of specified dimension.
  virtual int64_t Dim(int dim_index) const = 0;
  // Returns number of elements across all dimensions.
  virtual int64_t NumElements() const = 0;
  // Return size in bytes of the Tensor.
  virtual size_t ByteSize() const = 0;
  // Returns a pointer to tensor data.
  virtual void* Data() const = 0;

  // Returns if the tensor is aligned.
  virtual bool IsAligned() const = 0;
  // Returns true if there is sole ownership of this Tensor and thus it can be
  // moved.
  virtual bool CanMove() const = 0;

 protected:
  // Destruction is only reachable through Release(); subclasses delete
  // themselves there.
  virtual ~AbstractTensorInterface() {}
};

}  // namespace tensorflow
#endif // TENSORFLOW_C_TENSOR_INTERFACE_H_

View File

@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/error.h"
#include "tensorflow/core/platform/status.h"
using ::tensorflow::IOError;
using ::tensorflow::Status;

View File

@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_C_TF_STATUS_HELPER_H_
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {

View File

@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {

View File

@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_C_TF_STATUS_INTERNAL_H_
#define TENSORFLOW_C_TF_STATUS_INTERNAL_H_
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/status.h"
// Internal structures used by the status C API. These are likely to change
// and should not be depended on.

View File

@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h"
#include <memory>
#include <vector>
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor_internal.h"
@ -24,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/platform/casts.h"
using tensorflow::Status;
using tensorflow::Tensor;
@ -63,57 +67,40 @@ void deallocate_buffer(void* data, size_t len, void* arg) {
} // namespace tensorflow
namespace {
class TF_ManagedBuffer : public TensorBuffer {
public:
TF_ManagedBuffer(void* data, size_t len,
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg)
: TensorBuffer(data),
len_(len),
deallocator_(deallocator),
deallocator_arg_(deallocator_arg) {}
const size_t len_;
void (*const deallocator_)(void* data, size_t len, void* arg);
void* const deallocator_arg_;
~TF_ManagedBuffer() override {
(*deallocator_)(data(), len_, deallocator_arg_);
TF_Tensor* CreateTensor(TF_ManagedBuffer* buf, TF_DataType dtype,
const int64_t* dims, int num_dims, size_t len) {
std::vector<tensorflow::int64> dimvec(num_dims);
for (int i = 0; i < num_dims; ++i) {
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
size_t size() const override { return len_; }
TensorBuffer* root_buffer() override { return this; }
void FillAllocationDescription(
tensorflow::AllocationDescription* proto) const override {
tensorflow::int64 rb = size();
proto->set_requested_bytes(rb);
proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
// TODO(gjn): Make the choice of interface a compile-time configuration.
tensorflow::TensorInterface ret(
Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf));
buf->Unref();
size_t elem_size = TF_DataTypeSize(dtype);
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
return nullptr;
}
// Prevents input forwarding from mutating this buffer.
bool OwnsMemory() const override { return false; }
};
return new TF_Tensor{new tensorflow::TensorInterface(ret)};
}
} // namespace
TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
int num_dims, size_t len) {
void* data = tensorflow::allocate_tensor("TF_AllocateTensor", len,
tensorflow::cpu_allocator());
return TF_NewTensor(dtype, dims, num_dims, data, len,
tensorflow::deallocate_buffer,
tensorflow::cpu_allocator());
TF_ManagedBuffer* buf =
new TF_ManagedBuffer(data, len, tensorflow::deallocate_buffer,
tensorflow::cpu_allocator(), /*owns_memory=*/true);
return CreateTensor(buf, dtype, dims, num_dims, len);
}
TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
void* data, size_t len,
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg) {
std::vector<tensorflow::int64> dimvec(num_dims);
for (int i = 0; i < num_dims; ++i) {
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
TF_ManagedBuffer* buf = nullptr;
if (dtype != TF_STRING && dtype != TF_RESOURCE &&
tensorflow::DataTypeCanUseMemcpy(
@ -128,57 +115,48 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
// Other types have the same representation, so copy only if it is safe to
// do so.
buf = new TF_ManagedBuffer(tensorflow::allocate_tensor("TF_NewTensor", len),
len, tensorflow::deallocate_buffer, nullptr);
len, tensorflow::deallocate_buffer, nullptr,
/*owns_memory=*/true);
std::memcpy(buf->data(), data, len);
// Free the original buffer.
deallocator(data, len, deallocator_arg);
} else {
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
/*owns_memory=*/false);
}
TF_Tensor* ret =
new TF_Tensor{Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf)};
buf->Unref();
size_t elem_size = TF_DataTypeSize(dtype);
if (elem_size > 0 && len < (elem_size * ret->tensor.NumElements())) {
delete ret;
return nullptr;
}
return ret;
return CreateTensor(buf, dtype, dims, num_dims, len);
}
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) {
// It is safe to move the Tensor if and only if we own the unique reference to
// it. In that case, we might as well not delete and reallocate, but a future
// implementation might need to do so.
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor->tensor);
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
buf->OwnsMemory()) {
return tensor;
}
return nullptr;
// Returns `t` if its underlying buffer is uniquely owned and may therefore be
// moved; returns nullptr otherwise. Ownership of `t` is not transferred.
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {
  return t->tensor->CanMove() ? t : nullptr;
}
void TF_DeleteTensor(TF_Tensor* t) { delete t; }
// Destroys a TF_Tensor. Safe to call with nullptr. The wrapped interface
// object owns its resources and is released first, then the wrapper struct
// itself is deleted.
void TF_DeleteTensor(TF_Tensor* t) {
  if (t == nullptr) {
    return;
  }

  if (t->tensor) {
    t->tensor->Release();
  }

  delete t;
}
TF_DataType TF_TensorType(const TF_Tensor* t) {
return static_cast<TF_DataType>(t->tensor.dtype());
return static_cast<TF_DataType>(t->tensor->Type());
}
int TF_NumDims(const TF_Tensor* t) { return t->tensor.dims(); }
int TF_NumDims(const TF_Tensor* t) { return t->tensor->NumDims(); }
int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
return static_cast<int64_t>(t->tensor.dim_size(dim_index));
return t->tensor->Dim(dim_index);
}
size_t TF_TensorByteSize(const TF_Tensor* t) {
return tensorflow::TensorCApi::Buffer(t->tensor)->size();
}
size_t TF_TensorByteSize(const TF_Tensor* t) { return t->tensor->ByteSize(); }
void* TF_TensorData(const TF_Tensor* t) {
return tensorflow::TensorCApi::Buffer(t->tensor)->data();
}
void* TF_TensorData(const TF_Tensor* t) { return t->tensor->Data(); }
int64_t TF_TensorElementCount(const TF_Tensor* t) {
int64_t result = 1;
@ -193,16 +171,68 @@ void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type,
TF_Tensor* to, const int64_t* new_dims,
int num_new_dims, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
Status cc_status(
tensorflow::down_cast<tensorflow::TensorInterface*>(to->tensor)
->BitcastFrom(
*tensorflow::down_cast<const tensorflow::TensorInterface*>(
from->tensor),
static_cast<tensorflow::DataType>(type), new_dims, num_new_dims));
Set_TF_Status_from_Status(status, cc_status);
}
namespace tensorflow {

// Self-deletion is the designated way to destroy a TensorInterface; callers
// go through the abstract Release() rather than deleting the pointer.
void TensorInterface::Release() { delete this; }

bool TensorInterface::CanMove() const {
  // It is safe to move the Tensor if and only if we own the unique reference to
  // it. In that case, we might as well not delete and reallocate, but a future
  // implementation might need to do so.
  TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor_);
  if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
      buf->OwnsMemory()) {
    return true;
  }
  return false;
}

// Thin forwarding accessors over the wrapped tensorflow::Tensor.
DataType TensorInterface::Type() const { return tensor_.dtype(); }

int TensorInterface::NumDims() const { return tensor_.dims(); }

int64_t TensorInterface::Dim(int dim_index) const {
  return static_cast<int64_t>(tensor_.dim_size(dim_index));
}

int64_t TensorInterface::NumElements() const {
  return static_cast<int64_t>(tensor_.NumElements());
}

// Size/data are taken from the underlying buffer rather than the Tensor so
// they reflect the actual allocation backing this tensor.
size_t TensorInterface::ByteSize() const {
  return tensorflow::TensorCApi::Buffer(tensor_)->size();
}

void* TensorInterface::Data() const {
  return tensorflow::TensorCApi::Buffer(tensor_)->data();
}
Status TensorInterface::BitcastFrom(const TensorInterface& from, DataType type,
const int64_t* new_dims, int num_new_dims) {
tensorflow::TensorShape s;
for (int i = 0; i < num_new_dims; ++i) {
s.AddDim(new_dims[i]);
}
Status cc_status(to->tensor.BitcastFrom(
from->tensor, static_cast<tensorflow::DataType>(type), s));
Set_TF_Status_from_Status(status, cc_status);
return tensor_.BitcastFrom(from.tensor_, type, s);
}
} // namespace tensorflow
// --------------------------------------------------------------------------
// Writes the TF string encoding of (`src`, `src_len`) into `dst`: a varint64
// length prefix followed by the raw bytes. The caller must ensure `dst` has
// at least TF_StringEncodedSize(src_len) bytes available.
void StringEncode(const char* src, size_t src_len, char* dst) {
  // EncodeVarint64 returns the position just past the length prefix; the
  // local reassignment of `dst` points memcpy at the payload area.
  dst = tensorflow::core::EncodeVarint64(dst, src_len);
  memcpy(dst, src, src_len);
}
size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
size_t dst_len, TF_Status* status) {
const size_t sz = TF_StringEncodedSize(src_len);
@ -218,8 +248,7 @@ size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
src_len, "-byte string"));
return 0;
}
dst = tensorflow::core::EncodeVarint64(dst, src_len);
memcpy(dst, src, src_len);
StringEncode(src, src_len, dst);
return sz;
}
@ -278,13 +307,11 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype,
namespace tensorflow {
// Non-static for testing.
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
*status = tensorflow::Status::OK();
if (!src.IsInitialized()) {
Set_TF_Status_from_Status(
status, FailedPrecondition(
"attempt to use a tensor with an uninitialized value"));
*status = FailedPrecondition(
"attempt to use a tensor with an uninitialized value");
return nullptr;
}
if (src.NumElements() == 0) {
@ -292,14 +319,13 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
}
if (src.dtype() == tensorflow::DT_RESOURCE) {
if (src.shape().dims() != 0) {
Set_TF_Status_from_Status(
status, InvalidArgument(
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
src.shape().DebugString(),
"). Please file a bug at "
"https://github.com/tensorflow/tensorflow/issues/new, "
"ideally with a "
"short code snippet that reproduces this error."));
*status = InvalidArgument(
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
src.shape().DebugString(),
"). Please file a bug at "
"https://github.com/tensorflow/tensorflow/issues/new, "
"ideally with a "
"short code snippet that reproduces this error.");
return nullptr;
}
const string str =
@ -309,12 +335,11 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
return t;
}
if (src.dtype() != tensorflow::DT_STRING) {
auto* result = new TF_Tensor();
if (!result->tensor.CopyFrom(src, src.shape())) {
delete result;
Tensor tensor;
if (!tensor.CopyFrom(src, src.shape())) {
return nullptr;
}
return result;
return new TF_Tensor{new tensorflow::TensorInterface(tensor)};
}
// DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly
// encoded sequence of strings.
@ -338,23 +363,15 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
*offsets = (dst - data_start);
offsets++;
const string& s = srcarray(i);
size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
if (TF_GetCode(status) != TF_OK) {
Set_TF_Status_from_Status(
status,
InvalidArgument("invalid string tensor encoding (string #", i, " of ",
srcarray.size(), "): ", TF_Message(status)));
delete[] base;
return nullptr;
}
const size_t consumed = TF_StringEncodedSize(s.size());
StringEncode(s.data(), s.size(), dst);
dst += consumed;
dst_len -= consumed;
}
if (dst != base + size) {
Set_TF_Status_from_Status(
status, InvalidArgument(
"invalid string tensor encoding (decoded ", (dst - base),
" bytes, but the tensor is encoded in ", size, " bytes"));
*status = InvalidArgument(
"invalid string tensor encoding (decoded ", (dst - base),
" bytes, but the tensor is encoded in ", size, " bytes");
delete[] base;
return nullptr;
}
@ -372,31 +389,35 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
}
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
if (src->tensor.dtype() == DT_RESOURCE) {
if (src->tensor.dims() != 0) {
return tensorflow::down_cast<const tensorflow::TensorInterface*>(src->tensor)
->ToTensor(dst);
}
Status TensorInterface::ToTensor(tensorflow::Tensor* dst) const {
if (tensor_.dtype() == DT_RESOURCE) {
if (tensor_.dims() != 0) {
return InvalidArgument(
"Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
"shape ",
src->tensor.shape().DebugString());
tensor_.shape().DebugString());
}
*dst = Tensor(tensorflow::DT_RESOURCE, src->tensor.shape());
*dst = tensorflow::Tensor(tensorflow::DT_RESOURCE, tensor_.shape());
if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
string(static_cast<const char*>(TF_TensorData(src)),
TF_TensorByteSize(src)))) {
string(static_cast<const char*>(Data()), ByteSize()))) {
return InvalidArgument(
"Malformed TF_RESOUCE tensor: unable to parse resource handle");
"Malformed TF_RESOURCE tensor: unable to parse resource handle");
}
return Status::OK();
}
if (src->tensor.dtype() != DT_STRING) {
*dst = src->tensor;
if (tensor_.dtype() != DT_STRING) {
*dst = tensor_;
return Status::OK();
}
// TF_STRING tensors require copying since Tensor class expects a sequence of
// string objects.
const tensorflow::int64 num_elements = src->tensor.NumElements();
const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
const size_t src_size = TF_TensorByteSize(src);
const tensorflow::int64 num_elements = tensor_.NumElements();
const char* input = reinterpret_cast<const char*>(Data());
const size_t src_size = ByteSize();
if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
num_elements) {
return InvalidArgument(
@ -405,7 +426,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
const char* limit = input + src_size;
*dst = Tensor(src->tensor.dtype(), src->tensor.shape());
*dst = tensorflow::Tensor(tensor_.dtype(), tensor_.shape());
auto dstarray = dst->flat<tstring>();
for (tensorflow::int64 i = 0; i < num_elements; ++i) {
tensorflow::uint64 offset =
@ -424,8 +445,8 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
return Status::OK();
}
bool TensorInterface::IsAligned() const { return tensor_.IsAligned(); }
} // namespace tensorflow
bool TF_TensorIsAligned(const TF_Tensor* tensor) {
return tensor->tensor.IsAligned();
}
bool TF_TensorIsAligned(const TF_Tensor* t) { return t->tensor->IsAligned(); }

View File

@ -58,9 +58,9 @@ extern "C" {
// start_offset: array[uint64]
// data: byte[...]
//
// The string length (as a varint), followed by the contents of the string
// is encoded at data[start_offset[i]]]. TF_StringEncode and TF_StringDecode
// facilitate this encoding.
// The string length (as a varint, start_offset[i + 1] - start_offset[i]),
// followed by the contents of the string is encoded at data[start_offset[i]].
// TF_StringEncode and TF_StringDecode facilitate this encoding.
typedef struct TF_Tensor TF_Tensor;

View File

@ -16,9 +16,14 @@ limitations under the License.
#ifndef TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
#define TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
#include <memory>
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/casts.h"
// Internal structures used by the C API. These are likely to change and should
// not be depended on.
@ -27,9 +32,42 @@ limitations under the License.
// passed to or returned from C functions *by pointer*. Otherwise, changes to
// its internal structure will break the C API's binary interface.
typedef struct TF_Tensor {
::tensorflow::Tensor tensor;
tensorflow::AbstractTensorInterface* tensor;
} TF_Tensor;
// TensorBuffer backed by caller-provided memory plus a deallocator callback.
// `owns_memory` records whether this buffer is the sole owner of the bytes;
// OwnsMemory() feeds the CanMove() check used by TF_TensorMaybeMove.
class TF_ManagedBuffer : public tensorflow::TensorBuffer {
 public:
  TF_ManagedBuffer(void* data, size_t len,
                   void (*deallocator)(void* data, size_t len, void* arg),
                   void* deallocator_arg, bool owns_memory)
      : TensorBuffer(data),
        len_(len),
        deallocator_(deallocator),
        deallocator_arg_(deallocator_arg),
        owns_memory_(owns_memory) {}

  // The deallocator is always invoked, regardless of ownership; `arg` is the
  // opaque pointer the creator supplied.
  ~TF_ManagedBuffer() override {
    (*deallocator_)(data(), len_, deallocator_arg_);
  }

  size_t size() const override { return len_; }
  TensorBuffer* root_buffer() override { return this; }
  void FillAllocationDescription(
      tensorflow::AllocationDescription* proto) const override {
    tensorflow::int64 rb = size();
    proto->set_requested_bytes(rb);
    proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
  }

  bool OwnsMemory() const override { return owns_memory_; }

 private:
  const size_t len_;
  void (*const deallocator_)(void* data, size_t len, void* arg);
  void* const deallocator_arg_;
  bool owns_memory_;
};
namespace tensorflow {
class TensorCApi {
@ -49,5 +87,38 @@ void* allocate_tensor(const char* operation, size_t len, Allocator* allocator);
// Defaults to deallocating using CPU allocator. You can pass pointer to
// a different Allocator as `arg`.
void deallocate_buffer(void* data, size_t len, void* arg);
// Concrete AbstractTensorInterface that wraps a tensorflow::Tensor by value.
// This is the implementation stored inside TF_Tensor for the C API.
class TensorInterface : public AbstractTensorInterface {
 public:
  TensorInterface() {}
  explicit TensorInterface(tensorflow::Tensor t) : tensor_(std::move(t)) {}
  ~TensorInterface() override {}

  void Release() override;

  DataType Type() const override;
  int NumDims() const override;
  int64_t Dim(int dim_index) const override;
  int64_t NumElements() const override;
  size_t ByteSize() const override;
  void* Data() const override;
  bool IsAligned() const override;
  bool CanMove() const override;

  // Copies/decodes this tensor into `dst` (DT_STRING and DT_RESOURCE need
  // format conversion; other dtypes are assigned directly).
  Status ToTensor(tensorflow::Tensor* dst) const;
  // Reinterprets `from`'s buffer as `type` with the given shape.
  Status BitcastFrom(const TensorInterface& from, DataType type,
                     const int64_t* new_dims, int num_new_dims);

  // Direct access to the wrapped Tensor for internal callers.
  tensorflow::Tensor& Tensor() { return tensor_; }

 private:
  tensorflow::Tensor tensor_;
};

// Downcasts the abstract interface to TensorInterface and returns the wrapped
// Tensor. Only valid when `tensor` actually is a TensorInterface.
inline Tensor& TensorFromInterface(AbstractTensorInterface* tensor) {
  return down_cast<TensorInterface*>(tensor)->Tensor();
}
} // namespace tensorflow
#endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_

View File

@ -14,10 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
using tensorflow::GraphDef;

View File

@ -41,6 +41,16 @@ filegroup(
],
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"training/coordinator.h",
],
visibility = [
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "gradients",
srcs = [
@ -233,6 +243,7 @@ cc_library_with_android_deps(
deps = [
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
"//tensorflow/core:lib_experimental",
"//tensorflow/core:protos_all_cc",
],
)
@ -621,6 +632,7 @@ tf_gen_op_wrappers_cc(
"tpu_configuration_ops",
"tpu_cross_replica_ops",
"tpu_embedding_ops",
"tpu_embedding_load_retrieve_ops",
"tpu_functional_ops",
"tpu_heartbeat_ops",
"tpu_host_compute_ops",

View File

@ -41,7 +41,7 @@ class ClientSession::Impl {
std::shared_ptr<Graph> graph_;
mutable mutex mu_;
mutable int last_num_graph_nodes_ GUARDED_BY(mu_) = 0;
mutable int last_num_graph_nodes_ TF_GUARDED_BY(mu_) = 0;
};
ClientSession::ClientSession(const Scope& scope, const string& target)
@ -127,6 +127,33 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs,
target_node_names, outputs, run_metadata);
}
// Runs the requested fetches/targets, forwarding the caller-supplied
// inter-op/intra-op threadpools to the underlying Session.
//
// Returns the first failed feed status, any graph-extension error, or the
// status of the underlying Session::Run call.
Status ClientSession::Run(
    const RunOptions& run_options, const FeedType& inputs,
    const std::vector<Output>& fetch_outputs,
    const std::vector<Operation>& run_outputs, std::vector<Tensor>* outputs,
    RunMetadata* run_metadata,
    const thread::ThreadPoolOptions& threadpool_options) const {
  // Translate the typed feed map into (tensor name, value) pairs, failing
  // fast on any feed whose construction already produced an error.
  std::vector<std::pair<string, Tensor>> feed_pairs;
  for (const auto& entry : inputs) {
    TF_RETURN_IF_ERROR(entry.second.status);
    feed_pairs.emplace_back(entry.first.name(), entry.second.tensor);
  }
  // Names of the tensors whose values the caller wants back.
  std::vector<string> fetch_names;
  fetch_names.reserve(fetch_outputs.size());
  for (const auto& fetch : fetch_outputs) {
    fetch_names.push_back(fetch.name());
  }
  // Names of the nodes to run purely for their side effects.
  std::vector<string> target_names;
  target_names.reserve(run_outputs.size());
  for (const auto& target : run_outputs) {
    target_names.push_back(target.node()->name());
  }
  // Sync any graph nodes added since the last call into the session, then
  // delegate to the Session overload that accepts custom threadpools.
  TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
  return impl()->session_->Run(run_options, feed_pairs, fetch_names,
                               target_names, outputs, run_metadata,
                               threadpool_options);
}
Status ClientSession::MakeCallable(const CallableOptions& callable_options,
CallableHandle* out_handle) {
TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());

View File

@ -93,6 +93,14 @@ class ClientSession {
const std::vector<Operation>& run_outputs,
std::vector<Tensor>* outputs, RunMetadata* run_metadata) const;
/// Same as above. Additionally allows user to provide custom threadpool
/// implementation via ThreadPoolOptions.
Status Run(const RunOptions& run_options, const FeedType& inputs,
const std::vector<Output>& fetch_outputs,
const std::vector<Operation>& run_outputs,
std::vector<Tensor>* outputs, RunMetadata* run_metadata,
const thread::ThreadPoolOptions& threadpool_options) const;
/// \brief A handle to a subgraph, created with
/// `ClientSession::MakeCallable()`.
typedef int64 CallableHandle;

View File

@ -112,7 +112,7 @@ TEST(ClientSessionTest, Extend) {
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({31, 42}, {2}));
}
TEST(ClientSessionTest, MultiThreaded) {
TEST(ClientSessionTest, MultiThreadedWithDefaultThreadpool) {
Scope root = Scope::NewRootScope();
auto a = Add(root, {1, 2}, {3, 4});
auto b = Mul(root, {1, 2}, {3, 4});
@ -138,6 +138,49 @@ TEST(ClientSessionTest, MultiThreaded) {
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2}));
}
// Verifies ClientSession::Run with caller-provided threadpools from multiple
// client threads. Fix: the original body built `threadPoolOptions` wired to
// the custom pools but then passed a default-constructed ThreadPoolOptions to
// every Run call, so the custom pools were never exercised.
TEST(ClientSessionTest, MultiThreadedWithCustomThreadpool) {
  Scope root = Scope::NewRootScope();
  int num_threads = 3;
  auto a = Add(root, {1, 2}, {3, 4});
  auto b = Mul(root, {1, 2}, {3, 4});
  ClientSession session(root);

  // Custom inter-op and intra-op pools; both must be untouched before Run.
  auto inter_op_threadpool =
      absl::make_unique<CustomThreadPoolImpl>(num_threads);
  ASSERT_EQ(inter_op_threadpool->GetNumScheduleCalled(), 0);

  auto intra_op_threadpool =
      absl::make_unique<CustomThreadPoolImpl>(num_threads);
  ASSERT_EQ(intra_op_threadpool->GetNumScheduleCalled(), 0);

  tensorflow::thread::ThreadPoolOptions threadPoolOptions;
  threadPoolOptions.inter_op_threadpool = inter_op_threadpool.get();
  threadPoolOptions.intra_op_threadpool = intra_op_threadpool.get();

  {
    // Capturing threadPoolOptions by reference is safe: ThreadPool's
    // destructor joins the scheduled closures before this scope exits, and
    // the options (and both pools) are declared outside the scope.
    thread::ThreadPool thread_pool(Env::Default(), "pool", 2);
    thread_pool.Schedule([&session, &threadPoolOptions, a]() {
      std::vector<Tensor> outputs;
      TF_EXPECT_OK(session.Run(RunOptions(), ClientSession::FeedType{}, {a}, {},
                               &outputs, nullptr, threadPoolOptions));
      test::ExpectTensorEqual<int>(outputs[0],
                                   test::AsTensor<int>({4, 6}, {2}));
    });
    thread_pool.Schedule([&session, &threadPoolOptions, b]() {
      std::vector<Tensor> outputs;
      TF_EXPECT_OK(session.Run(RunOptions(), ClientSession::FeedType{}, {b}, {},
                               &outputs, nullptr, threadPoolOptions));
      test::ExpectTensorEqual<int>(outputs[0],
                                   test::AsTensor<int>({3, 8}, {2}));
    });
  }
  auto c = Sub(root, b, a);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run(RunOptions(), ClientSession::FeedType{}, {c}, {},
                           &outputs, nullptr, threadPoolOptions));
  test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2}));
}
TEST(ClientSessionTest, CallableWithDefaultThreadPool) {
Scope root = Scope::NewRootScope();
auto a = Placeholder(root, DT_INT32);

View File

@ -96,7 +96,7 @@ class SymbolicGradientBuilder {
// Used to identify nodes at which to stop backprop.
std::unordered_set<int> GetStopBackpropNodes(
const std::vector<bool>& reachable_nodes,
const std::unordered_set<int>& output_nodes);
const std::unordered_set<int>& output_nodes) const;
const Scope& scope_;
const ops::GradOpRegistry* registry_;
@ -190,7 +190,7 @@ std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
std::unordered_set<int> SymbolicGradientBuilder::GetStopBackpropNodes(
const std::vector<bool>& reachable_nodes,
const std::unordered_set<int>& output_nodes) {
const std::unordered_set<int>& output_nodes) const {
// Output nodes that get transitively consumed by other `outputs_` are stored
// in `internal_outputs`.
std::unordered_set<int> internal_outputs;
@ -346,8 +346,8 @@ Status SymbolicGradientBuilder::SumGradients(const Output& src, Output* grad) {
"Unable to find backprop list for node.id ", src.node()->name());
}
const auto& grads = iter->second;
// Filter any backproped 'NoGradient' Outputs from 'grads' (if needed).
// Return any valid backproped gradients that remain after filtering,
// Filter any backpropped 'NoGradient' Outputs from 'grads' (if needed).
// Return any valid backpropped gradients that remain after filtering,
// or 'NoGradient' otherwise.
std::vector<Output> grads_to_keep;
for (const Output& o : grads) {
@ -519,17 +519,17 @@ Status SymbolicGradientBuilder::AddGradients() {
// Backprop along the in edges.
// TODO(andydavis) Find cleaner way to map each grad output returned by
// gradient function to the src node/output to which it should be
// backproped. Maybe grad functions can return a vector of Output pairs to
// backpropped. Maybe grad functions can return a vector of Output pairs to
// make this association explicit.
size_t dx_index = 0;
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) continue;
if (dx_index == dx.size()) {
int dx_index = e->dst_input();
if (dx_index >= dx.size()) {
return errors::Internal(
"Invalid gradient output index: ", dx_index, " size: ", dx.size());
}
TF_RETURN_IF_ERROR(
BackpropAlongEdge(dx[dx_index++], {e->src(), e->src_output()}));
BackpropAlongEdge(dx[dx_index], {e->src(), e->src_output()}));
}
}

View File

@ -503,6 +503,42 @@ TEST_F(GradientsTest, MultiOutputNodeDependentOutputs) {
EXPECT_EQ(grad_result[0].flat<float>()(0), 17610.0f);
}
// Stress-tests AddSymbolicGradients through a Stack (pack) op: the gradient
// of stacking N inputs should route row i of the incoming gradient back to
// input i. Repeated with randomly sized graphs to cover several input counts.
TEST_F(GradientsTest, AddSymbolicGradientsTest) {
  Scope scope = Scope::NewRootScope();
  for (int trial = 0; trial < 100; ++trial) {
    const int num_inputs = 5 + rand() % 10;
    // Forward graph: stack `num_inputs` length-1 vector constants.
    OutputList inputs;
    for (int i = 0; i < num_inputs; ++i) {
      inputs.push_back(Const(scope, i, {1}));
    }
    auto pack = Stack(scope, inputs);
    TF_ASSERT_OK(scope.status());
    // Incoming gradient for the stacked output: an {N, 1} column [0..N-1].
    Tensor grad_values(DT_INT32, {num_inputs, 1});
    auto grad_matrix = grad_values.matrix<int32>();
    for (int i = 0; i < num_inputs; ++i) {
      grad_matrix(i, 0) = i;
    }
    OutputList output_grads;
    output_grads.push_back(Const(scope, grad_values));
    // Backprop through the Stack op.
    std::vector<Output> grad_outputs;
    TF_ASSERT_OK(AddSymbolicGradients(scope, {pack.output}, inputs,
                                      output_grads, &grad_outputs));
    // Evaluate and check that input i receives gradient [i].
    ClientSession session(scope);
    std::vector<Tensor> in_grad;
    TF_ASSERT_OK(session.Run(grad_outputs, &in_grad));
    for (int i = 0; i < num_inputs; ++i) {
      test::ExpectTensorEqual<int>(in_grad[i], test::AsTensor<int>({i}, {1}));
    }
  }
}
// StopGradientSingleOutputMultiEdgeTest tests combinations of valid and
// 'NoGradient' (induced by StopGradient op) returned along multiple edges from
// a single nodes output.

View File

@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define _USE_MATH_DEFINES
#include <cmath>
#include "tensorflow/cc/ops/array_ops_internal.h"

Some files were not shown because too many files have changed in this diff Show More