Merge branch 'master' into export_SetNumThreads_to_tflite_python
This commit is contained in:
commit
969b77defb
67
.bazelrc
67
.bazelrc
@ -143,6 +143,11 @@ build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
|
||||
build:mkl --define=build_with_mkl_dnn_v1_only=true
|
||||
build:mkl -c opt
|
||||
|
||||
# config to build OneDNN backend with a user specified threadpool.
|
||||
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
|
||||
build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
|
||||
build:mkl_threadpool --define=build_with_mkldnn_threadpool=true
|
||||
build:mkl_threadpool -c opt
|
||||
# This config refers to building with CUDA available. It does not necessarily
|
||||
# mean that we build CUDA op kernels.
|
||||
build:using_cuda --define=using_cuda=true
|
||||
@ -163,6 +168,8 @@ build:cuda_clang --action_env TF_CUDA_CLANG=1
|
||||
build:dbg --config=opt -c dbg
|
||||
# for now, disable arm_neon. see: https://github.com/tensorflow/tensorflow/issues/33360
|
||||
build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
|
||||
# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
|
||||
build:dbg --copt -DDEBUG_BUILD
|
||||
|
||||
build:tensorrt --action_env TF_NEED_TENSORRT=1
|
||||
|
||||
@ -233,10 +240,15 @@ build:c++17 --cxxopt=-std=c++1z
|
||||
build:c++17 --cxxopt=-stdlib=libc++
|
||||
build:c++1z --config=c++17
|
||||
|
||||
# Enable using platform specific build settings
|
||||
# Enable using platform specific build settings, except when cross-compiling for
|
||||
# mobile platforms.
|
||||
build --enable_platform_specific_config
|
||||
build:android --noenable_platform_specific_config
|
||||
build:ios --noenable_platform_specific_config
|
||||
|
||||
# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
|
||||
build:android --copt=-w
|
||||
build:ios --copt=-w
|
||||
build:linux --copt=-w
|
||||
build:macos --copt=-w
|
||||
build:windows --copt=/w
|
||||
@ -256,6 +268,10 @@ build:macos --define=INCLUDEDIR=$(PREFIX)/include
|
||||
# TF_SYSTEM_LIBS do not work on windows.
|
||||
|
||||
# By default, build TF in C++ 14 mode.
|
||||
build:android --cxxopt=-std=c++14
|
||||
build:android --host_cxxopt=-std=c++14
|
||||
build:ios --cxxopt=-std=c++14
|
||||
build:ios --host_cxxopt=-std=c++14
|
||||
build:linux --cxxopt=-std=c++14
|
||||
build:linux --host_cxxopt=-std=c++14
|
||||
build:macos --cxxopt=-std=c++14
|
||||
@ -356,9 +372,10 @@ build:rbe_linux --linkopt=-lm
|
||||
build:rbe_cpu_linux --config=rbe_linux
|
||||
build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
|
||||
build:rbe_cpu_linux --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:cc-toolchain-k8"
|
||||
build:rbe_cpu_linux --extra_execution_platforms"=@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
|
||||
build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
|
||||
build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
|
||||
build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
|
||||
build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
|
||||
build:rbe_cpu_linux --host_platform="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
|
||||
build:rbe_cpu_linux --platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
|
||||
|
||||
build:rbe_linux_cuda_base --config=rbe_linux
|
||||
build:rbe_linux_cuda_base --repo_env=TF_NEED_TENSORRT=1
|
||||
@ -380,17 +397,37 @@ build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_
|
||||
build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
|
||||
test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
|
||||
|
||||
build:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_clang --crosstool_top="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_clang --extra_toolchains="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_clang --extra_execution_platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang --host_platform="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang --platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda_clang --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda_clang --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda_clang --define=using_cuda_clang=true
|
||||
test:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_nvcc_base --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda_nvcc_base --define=using_cuda_nvcc=true
|
||||
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
|
||||
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
|
||||
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
|
||||
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
|
||||
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
|
||||
|
||||
build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_clang_base --extra_toolchains="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_clang_base --extra_execution_platforms="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang_base --host_platform="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang_base --platforms="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda_clang_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda_clang_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda_clang_base --define=using_cuda_clang=true
|
||||
build:rbe_linux_cuda_clang_py27 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
|
||||
build:rbe_linux_cuda_clang_py35 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
|
||||
build:rbe_linux_cuda_clang_py36 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
|
||||
build:rbe_linux_cuda_clang_py37 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
|
||||
build:rbe_linux_cuda_clang_py38 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
|
||||
|
||||
common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
|
||||
|
||||
|
@ -1 +1 @@
|
||||
2.0.0
|
||||
3.0.0
|
||||
|
30
.github/ISSUE_TEMPLATE/00-bug-issue.md
vendored
30
.github/ISSUE_TEMPLATE/00-bug-issue.md
vendored
@ -11,25 +11,23 @@ we only address code/doc bugs, performance issues, feature requests and
|
||||
build/installation issues on GitHub. tag:bug_template</em>
|
||||
|
||||
**System information**
|
||||
- Have I written custom code (as opposed to using a stock
|
||||
example script provided in TensorFlow):
|
||||
- OS Platform and Distribution (e.g.,
|
||||
Linux Ubuntu 16.04):
|
||||
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
|
||||
the issue happens on mobile device:
|
||||
- TensorFlow installed from (source or
|
||||
binary): - TensorFlow version (use command below):
|
||||
- Python version: - Bazel
|
||||
version (if compiling from source):
|
||||
- GCC/Compiler version (if compiling from
|
||||
source):
|
||||
- CUDA/cuDNN version: - GPU model and memory:
|
||||
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
|
||||
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
|
||||
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
|
||||
- TensorFlow installed from (source or binary):
|
||||
- TensorFlow version (use command below):
|
||||
- Python version:
|
||||
- Bazel version (if compiling from source):
|
||||
- GCC/Compiler version (if compiling from source):
|
||||
- CUDA/cuDNN version:
|
||||
- GPU model and memory:
|
||||
|
||||
You can collect some of this information using our environment capture
|
||||
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
|
||||
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
|
||||
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
|
||||
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
|
||||
You can also obtain the TensorFlow version with:
|
||||
1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
|
||||
2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
|
||||
|
||||
|
||||
**Describe the current behavior**
|
||||
|
||||
|
@ -38,6 +38,9 @@ state what is wrong:
|
||||
- Producing correct results, but the model is slower than expected (model generated from old converter)
|
||||
|
||||
|
||||
**RNN conversion support**
|
||||
If converting TF RNN to TFLite fused RNN ops, please prefix [RNN] in the title.
|
||||
|
||||
**Any other info / logs**
|
||||
|
||||
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
|
||||
|
29
.github/ISSUE_TEMPLATE/80-performance-issue.md
vendored
29
.github/ISSUE_TEMPLATE/80-performance-issue.md
vendored
@ -12,25 +12,22 @@ we only address code/doc bugs, performance issues, feature requests and
|
||||
build/installation issues on GitHub. tag:performance_template</em>
|
||||
|
||||
**System information**
|
||||
- Have I written custom code (as opposed to using a stock
|
||||
example script provided in TensorFlow):
|
||||
- OS Platform and Distribution (e.g.,
|
||||
Linux Ubuntu 16.04):
|
||||
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
|
||||
the issue happens on mobile device:
|
||||
- TensorFlow installed from (source or
|
||||
binary): - TensorFlow version (use command below):
|
||||
- Python version: - Bazel
|
||||
version (if compiling from source):
|
||||
- GCC/Compiler version (if compiling from
|
||||
source):
|
||||
- CUDA/cuDNN version: - GPU model and memory:
|
||||
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
|
||||
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
|
||||
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
|
||||
- TensorFlow installed from (source or binary):
|
||||
- TensorFlow version (use command below):
|
||||
- Python version:
|
||||
- Bazel version (if compiling from source):
|
||||
- GCC/Compiler version (if compiling from source):
|
||||
- CUDA/cuDNN version:
|
||||
- GPU model and memory:
|
||||
|
||||
You can collect some of this information using our environment capture
|
||||
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
|
||||
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
|
||||
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
|
||||
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
|
||||
You can also obtain the TensorFlow version with:
|
||||
1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
|
||||
2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
|
||||
|
||||
**Describe the current behavior**
|
||||
|
||||
|
87
.github/bot_config.yml
vendored
Normal file
87
.github/bot_config.yml
vendored
Normal file
@ -0,0 +1,87 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
#
|
||||
# THIS IS A GENERATED DOCKERFILE.
|
||||
#
|
||||
# This file was assembled from multiple pieces, whose use is documented
|
||||
# throughout. Please refer to the TensorFlow dockerfiles documentation
|
||||
# for more information.
|
||||
|
||||
# A list of assignees
|
||||
assignees:
|
||||
- amahendrakar
|
||||
- ravikyram
|
||||
- Saduf2019
|
||||
# A list of assignees for compiler folder
|
||||
compiler_assignees:
|
||||
- joker-eph
|
||||
# Cuda Comment
|
||||
cuda_comment: >
|
||||
From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries:
|
||||
* For TF-GPU - See point 1
|
||||
* For TF-CPU - See point 2
|
||||
-----------------------------------------------------------------------------------------------
|
||||
|
||||
**1. Installing **TensorFlow-GPU** (TF) prebuilt binaries**
|
||||
|
||||
|
||||
Make sure you are using compatible TF and CUDA versions.
|
||||
Please refer following TF version and CUDA version compatibility table.
|
||||
|
||||
| TF | CUDA |
|
||||
|
||||
| :-------------: | :-------------: |
|
||||
|
||||
| 2.1.0 - 2.2.0 | 10.1 |
|
||||
|
||||
| 1.13.1 - 2.0 | 10.0 |
|
||||
|
||||
| 1.5.0 - 1.12.0 | 9.0 |
|
||||
|
||||
* If you have above configuration and using _**Windows**_ platform -
|
||||
* Try adding the CUDA, CUPTI, and cuDNN installation directories to the %PATH% environment variable.
|
||||
* Refer [windows setup guide](https://www.tensorflow.org/install/gpu#windows_setup).
|
||||
* If you have above configuration and using _**Ubuntu/Linux**_ platform -
|
||||
* Try adding the CUDA, CUPTI, and cuDNN installation directories to the $LD_LIBRARY_PATH environment variable.
|
||||
* Refer [linux setup guide](https://www.tensorflow.org/install/gpu#linux_setup).
|
||||
* If error still persists then, apparently your CPU model does not support AVX instruction sets.
|
||||
* Refer [hardware requirements](https://www.tensorflow.org/install/pip#hardware-requirements).
|
||||
|
||||
-----------------------------------------------------------------------------------------------
|
||||
|
||||
**2. Installing **TensorFlow** (TF) CPU prebuilt binaries**
|
||||
|
||||
|
||||
*TensorFlow release binaries version 1.6 and higher are prebuilt with AVX instruction sets.*
|
||||
|
||||
|
||||
Therefore on any CPU that does not have these instruction sets, either CPU or GPU version of TF will fail to load.
|
||||
|
||||
Apparently, your CPU model does not support AVX instruction sets. You can still use TensorFlow with the alternatives given below:
|
||||
|
||||
* Try Google Colab to use TensorFlow.
|
||||
* The easiest way to use TF will be to switch to [google colab](https://colab.sandbox.google.com/notebooks/welcome.ipynb#recent=true). You get pre-installed latest stable TF version. Also you can use ```pip install``` to install any other preferred TF version.
|
||||
* It has an added advantage since you can you easily switch to different hardware accelerators (cpu, gpu, tpu) as per the task.
|
||||
* All you need is a good internet connection and you are all set.
|
||||
* Try to build TF from sources by changing CPU optimization flags.
|
||||
|
||||
*Please let us know if this helps.*
|
||||
|
||||
windows_comment: >
|
||||
From the stack trace it looks like you are hitting windows path length limit.
|
||||
* Try to disable path length limit on Windows 10.
|
||||
* Refer [disable path length limit instructions guide.](https://mspoweruser.com/ntfs-260-character-windows-10/)
|
||||
|
||||
Please let us know if this helps.
|
39
.github/stale.yml
vendored
Normal file
39
.github/stale.yml
vendored
Normal file
@ -0,0 +1,39 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
#
|
||||
# THIS IS A GENERATED DOCKERFILE.
|
||||
#
|
||||
# This file was assembled from multiple pieces, whose use is documented
|
||||
# throughout. Please refer to the TensorFlow dockerfiles documentation
|
||||
# for more information.
|
||||
|
||||
# Number of days of inactivity before an Issue or Pull Request becomes stale
|
||||
daysUntilStale: 7
|
||||
# Number of days of inactivity before a stale Issue or Pull Request is closed
|
||||
daysUntilClose: 7
|
||||
# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
|
||||
onlyLabels:
|
||||
- stat:awaiting response
|
||||
# Comment to post when marking as stale. Set to `false` to disable
|
||||
markComment: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity. It will be closed if no further activity occurs. Thank you.
|
||||
# Comment to post when removing the stale label. Set to `false` to disable
|
||||
unmarkComment: false
|
||||
closeComment: >
|
||||
Closing as stale. Please reopen if you'd like to work on this further.
|
||||
limitPerRun: 30
|
||||
# Limit to only `issues` or `pulls`
|
||||
only: issues
|
@ -104,7 +104,7 @@ open-source software development:
|
||||
### Official Builds
|
||||
|
||||
Build Type | Status | Artifacts
|
||||
------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
|
||||
------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
|
||||
**Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Linux XLA** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
|
||||
@ -112,8 +112,8 @@ Build Type | Status
|
||||
**Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Android** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
|
||||
**Raspberry Pi 0 and 1** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.html) [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv6l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
|
||||
**Raspberry Pi 2 and 3** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.html) [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv7l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
|
||||
**Raspberry Pi 0 and 1** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
|
||||
**Raspberry Pi 2 and 3** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
|
||||
|
||||
### Community Supported Builds
|
||||
|
||||
|
169
RELEASE.md
169
RELEASE.md
@ -1,3 +1,172 @@
|
||||
# Release 2.1.1
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
|
||||
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
|
||||
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
|
||||
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
|
||||
* Fixes a versioning bug which causes Keras layers from TF 1.x to be used instead of those from TF 2.x
|
||||
|
||||
# Release 2.0.2
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
|
||||
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
|
||||
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
|
||||
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
|
||||
|
||||
# Release 1.15.3
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
|
||||
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
|
||||
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
|
||||
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
|
||||
|
||||
# Release 2.2.0
|
||||
|
||||
TensorFlow 2.2 discontinues support for Python 2, [previously announced](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ) as following [Python 2's EOL on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update).
|
||||
|
||||
Coinciding with this change, new releases of [TensorFlow's Docker images](https://hub.docker.com/r/tensorflow/tensorflow/) provide Python 3 exclusively. Because all images now use Python 3, Docker tags containing `-py3` will no longer be provided and existing `-py3` tags like `latest-py3` will not be updated.
|
||||
|
||||
## Major Features and Improvements
|
||||
|
||||
* Replaced the scalar type for string tensors from `std::string` to `tensorflow::tstring` which is now ABI stable.
|
||||
* A new Profiler for TF 2 for CPU/GPU/TPU. It offers both device and host performance analysis, including input pipeline and TF Ops. Optimization advisory is provided whenever possible. Please see [this tutorial](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras) and [guide](https://www.tensorflow.org/guide/profiler) for usage guidelines.
|
||||
* Export C++ functions to Python using `pybind11` as opposed to `SWIG` as a part of our [deprecation of swig efforts](https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md).
|
||||
* `tf.distribute`:
|
||||
* Support added for global sync `BatchNormalization` by using the newly added `tf.keras.layers.experimental.SyncBatchNormalization` layer. This layer will sync `BatchNormalization` statistics every step across all replicas taking part in sync training.
|
||||
* Performance improvements for GPU multi-worker distributed training using `tf.distribute.experimental.MultiWorkerMirroredStrategy`
|
||||
* Update NVIDIA `NCCL` to `2.5.7-1` for better performance and performance tuning. Please see [nccl developer guide](https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/env.html) for more information on this.
|
||||
* Support gradient `allreduce` in `float16`. See this [example](https://github.com/tensorflow/models/blob/master/official/staging/training/grad_utils.py) usage.
|
||||
* Experimental support of [all reduce gradient packing](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/CollectiveHints) to allow overlapping gradient aggregation with backward path computation.
|
||||
* Deprecated `experimental_run_v2` method for distribution strategies and renamed the method `run` as it is no longer experimental.
|
||||
* Add CompositeTensor support for DistributedIterators. This should help prevent unnecessary function retracing and memory leaks.
|
||||
* `tf.keras`:
|
||||
* `Model.fit` major improvements:
|
||||
* You can now use custom training logic with `Model.fit` by overriding `Model.train_step`.
|
||||
* Easily write state-of-the-art training loops without worrying about all of the features `Model.fit` handles for you (distribution strategies, callbacks, data formats, looping logic, etc)
|
||||
* See the default [`Model.train_step`](https://github.com/tensorflow/tensorflow/blob/1381fc8e15e22402417b98e3881dfd409998daea/tensorflow/python/keras/engine/training.py#L540) for an example of what this function should look like. Same applies for validation and inference via `Model.test_step` and `Model.predict_step`.
|
||||
* SavedModel uses its own `Model._saved_model_inputs_spec` attr now instead of
|
||||
relying on `Model.inputs` and `Model.input_names`, which are no longer set for subclass Models.
|
||||
This attr is set in eager, `tf.function`, and graph modes. This gets rid of the need for users to
|
||||
manually call `Model._set_inputs` when using Custom Training Loops(CTLs).
|
||||
* Dynamic shapes are supported for generators by calling the Model on the first batch we "peek" from the generator.
|
||||
This used to happen implicitly in `Model._standardize_user_data`. Long-term, a solution where the
|
||||
`DataAdapter` doesn't need to call the Model is probably preferable.
|
||||
* The SavedModel format now supports all Keras built-in layers (including metrics, preprocessing layers, and stateful RNN layers)
|
||||
* Update Keras batch normalization layer to use the running mean and average computation in the `fused_batch_norm`. You should see significant performance improvements when using `fused_batch_norm` in Eager mode.
|
||||
|
||||
* `tf.lite`:
|
||||
* Enable TFLite experimental new converter by default.
|
||||
* XLA
|
||||
* XLA now builds and works on windows. All prebuilt packages come with XLA available.
|
||||
* XLA can be [enabled for a `tf.function`](https://www.tensorflow.org/xla#explicit_compilation_with_tffunction
|
||||
) with “compile or throw exception” semantics on CPU and GPU.
|
||||
|
||||
## Breaking Changes
|
||||
* `tf.keras`:
|
||||
* In `tf.keras.applications` the name of the "top" layer has been standardized to "predictions". This is only a problem if your code relies on the exact name of the layer.
|
||||
* Huber loss function has been updated to be consistent with other Keras losses. It now computes mean over the last axis of per-sample losses before applying the reduction function.
|
||||
* AutoGraph no longer converts functions passed to `tf.py_function`, `tf.py_func` and `tf.numpy_function`.
|
||||
* Deprecating `XLA_CPU` and `XLA_GPU` devices with this release.
|
||||
* Increasing the minimum bazel version to build TF to 2.0.0 to use Bazel's `cc_experimental_shared_library`.
|
||||
* Keras compile/fit behavior for functional and subclassed models have been unified. Model properties such as `metrics`, `metrics_names` will now be available only after **training/evaluating the model on actual data** for functional models. `metrics` will **now include** model `loss` and output losses.`loss_functions` property has been removed from the model. This was an undocumented property that was accidentally public and has now been removed.
|
||||
|
||||
## Known Caveats
|
||||
* The current TensorFlow release now **requires** [gast](https://pypi.org/project/gast/) version 0.3.3.
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* `tf.data`:
|
||||
* Removed `autotune_algorithm` from experimental optimization options.
|
||||
* TF Core:
|
||||
* `tf.constant` always creates CPU tensors irrespective of the current device context.
|
||||
* Eager `TensorHandles` maintain a list of mirrors for any copies to local or remote devices. This avoids any redundant copies due to op execution.
|
||||
* For `tf.Tensor` & `tf.Variable`, `.experimental_ref()` is no longer experimental and is available as simply `.ref()`.
|
||||
* `pfor/vectorized_map`: Added support for vectorizing 56 more ops. Vectorizing `tf.cond` is also supported now.
|
||||
* Set as much partial shape as we can infer statically within the gradient impl of the gather op.
|
||||
* Gradient of `tf.while_loop` emits `StatelessWhile` op if `cond` and body functions are stateless. This allows multiple gradients while ops to run in parallel under distribution strategy.
|
||||
* Speed up `GradientTape` in eager mode by auto-generating list of op inputs/outputs which are unused and hence not cached for gradient functions.
|
||||
* Support `back_prop=False` in `while_v2` but mark it as deprecated.
|
||||
* Improve error message when attempting to use `None` in data-dependent control flow.
|
||||
* Add `RaggedTensor.numpy()`.
|
||||
* Update `RaggedTensor.__getitem__` to preserve uniform dimensions & allow indexing into uniform dimensions.
|
||||
* Update `tf.expand_dims` to always insert the new dimension as a non-ragged dimension.
|
||||
* Update `tf.embedding_lookup` to use `partition_strategy` and `max_norm` when `ids` is ragged.
|
||||
* Allow `batch_dims==rank(indices)` in `tf.gather`.
|
||||
* Add support for bfloat16 in `tf.print`.
|
||||
* `tf.distribute`:
|
||||
* Support `embedding_column` with variable-length input features for `MultiWorkerMirroredStrategy`.
|
||||
* `tf.keras`:
|
||||
* Added `experimental_aggregate_gradients` argument to `tf.keras.optimizer.Optimizer.apply_gradients`. This allows custom gradient aggregation and processing aggregated gradients in custom training loop.
|
||||
* Allow `pathlib.Path` paths for loading models via Keras API.
|
||||
* `tf.function`/AutoGraph:
|
||||
* AutoGraph is now available in `ReplicaContext.merge_call`, `Strategy.extended.update` and `Strategy.extended.update_non_slot`.
|
||||
* Experimental support for shape invariants has been enabled in `tf.function`. See the API docs for `tf.autograph.experimental.set_loop_options` for additonal info.
|
||||
* AutoGraph error messages now exclude frames corresponding to APIs internal to AutoGraph.
|
||||
* Improve shape inference for `tf.function` input arguments to unlock more Grappler optimizations in TensorFlow 2.x.
|
||||
* Improve automatic control dependency management of resources by allowing resource reads to occur in parallel and synchronizing only on writes.
|
||||
* Fix execution order of multiple stateful calls to `experimental_run_v2` in `tf.function`.
|
||||
* You can now iterate over `RaggedTensors` using a for loop inside `tf.function`.
|
||||
* `tf.lite`:
|
||||
* Migrated the `tf.lite` C inference API out of experimental into lite/c.
|
||||
* Add an option to disallow `NNAPI` CPU / partial acceleration on Android 10
|
||||
* TFLite Android AARs now include the C headers and APIs are required to use TFLite from native code.
|
||||
* Refactors the delegate and delegate kernel sources to allow usage in the linter.
|
||||
* Limit delegated ops to actually supported ones if a device name is specified or `NNAPI` CPU Fallback is disabled.
|
||||
* TFLite now supports `tf.math.reciprocal1` op by lowering to `tf.div op`.
|
||||
* TFLite's unpack op now supports boolean tensor inputs.
|
||||
* Microcontroller and embedded code moved from experimental to main TensorFlow Lite folder
|
||||
* Check for large TFLite tensors.
|
||||
* Fix GPU delegate crash with C++17.
|
||||
* Add 5D support to TFLite `strided_slice`.
|
||||
* Fix error in delegation of `DEPTH_TO_SPACE` to `NNAPI` causing op not to be accelerated.
|
||||
* Fix segmentation fault when running a model with LSTM nodes using `NNAPI` Delegate
|
||||
* Fix `NNAPI` delegate failure when an operand for Maximum/Minimum operation is a scalar.
|
||||
* Fix `NNAPI` delegate failure when Axis input for reduce operation is a scalar.
|
||||
* Expose option to limit the number of partitions that will be delegated to `NNAPI`.
|
||||
* If a target accelerator is specified, use its feature level to determine operations to delegate instead of SDK version.
|
||||
* `tf.random`:
|
||||
* Various random number generation improvements:
|
||||
* Add a fast path for default `random_uniform`
|
||||
* `random_seed` documentation improvement.
|
||||
* `RandomBinomial` broadcasts and appends the sample shape to the left rather than the right.
|
||||
* Added `tf.random.stateless_binomial`, `tf.random.stateless_gamma`, `tf.random.stateless_poisson`
|
||||
* `tf.random.stateless_uniform` now supports unbounded sampling of `int` types.
|
||||
* Math and Linear Algebra:
|
||||
* Add `tf.linalg.LinearOperatorTridiag`.
|
||||
* Add `LinearOperatorBlockLowerTriangular`
|
||||
* Add broadcasting support to tf.linalg.triangular_solve[#26204](https://github.com/tensorflow/tensorflow/issues/26204), tf.math.invert_permutation.
|
||||
* Add `tf.math.sobol_sample` op.
|
||||
* Add `tf.math.xlog1py`.
|
||||
* Add `tf.math.special.{dawsn,expi,fresnel_cos,fresnel_sin,spence}`.
|
||||
* Add a Modified Discrete Cosine Transform (MDCT) and its inverse to `tf.signal`.
|
||||
* TPU Enhancements:
|
||||
* Refactor `TpuClusterResolver` to move shared logic to a separate pip package.
|
||||
* Support configuring TPU software version from cloud tpu client.
|
||||
* Allowed TPU embedding weight decay factor to be multiplied by learning rate.
|
||||
* XLA Support:
|
||||
* Add standalone XLA AOT runtime target + relevant .cc sources to pip package.
|
||||
* Add check for memory alignment to MemoryAllocation::MemoryAllocation() on 32-bit ARM. This ensures a deterministic early exit instead of a hard to debug bus error later.
|
||||
* `saved_model_cli aot_compile_cpu` allows you to compile saved models to XLA header+object files and include them in your C++ programs.
|
||||
* Enable `Igamma`, `Igammac` for XLA.
|
||||
* Deterministic Op Functionality:
|
||||
* XLA reduction emitter is deterministic when the environment variable `TF_DETERMINISTIC_OPS` is set to "true" or "1". This extends deterministic `tf.nn.bias_add` back-prop functionality (and therefore also deterministic back-prop of bias-addition in Keras layers) to include when XLA JIT complilation is enabled.
|
||||
* Fix problem, when running on a CUDA GPU and when either environment variable `TF_DETERMINSTIC_OPS` or environment variable `TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer configurations led to an exception with the message "No algorithm worked!"
|
||||
* Tracing and Debugging:
|
||||
* Add source, destination name to `_send` traceme to allow easier debugging.
|
||||
* Add traceme event to `fastpathexecute`.
|
||||
* Other:
|
||||
* Fix an issue with AUC.reset_states for multi-label AUC [#35852](https://github.com/tensorflow/tensorflow/issues/35852)
|
||||
* Fix the TF upgrade script to not delete files when there is a parsing error and the output mode is `in-place`.
|
||||
* Move `tensorflow/core:framework/*_pyclif` rules to `tensorflow/core/framework:*_pyclif`.
|
||||
|
||||
## Thanks to our Contributors
|
||||
|
||||
This release contains contributions from many people at Google, as well as:
|
||||
|
||||
372046933, 8bitmp3, aaronhma, Abin Shahab, Aditya Patwardhan, Agoniii, Ahti Kitsik, Alan Yee, Albin Joy, Alex Hoffman, Alexander Grund, Alexandre E. Eichenberger, Amit Kumar Jaiswal, amoitra, Andrew Anderson, Angus-Luo, Anthony Barbier, Anton Kachatkou, Anuj Rawat, archis, Arpan-Dhatt, Arvind Sundararajan, Ashutosh Hathidara, autoih, Bairen Yi, Balint Cristian, Bas Aarts, BashirSbaiti, Basit Ayantunde, Ben Barsdell, Benjamin Gaillard, boron, Brett Koonce, Bryan Cutler, Christian Goll, Christian Sachs, Clayne Robison, comet, Daniel Falbel, Daria Zhuravleva, darsh8200, David Truby, Dayananda-V, deepakm, Denis Khalikov, Devansh Singh, Dheeraj R Reddy, Diederik Van Liere, Diego Caballero, Dominic Jack, dothinking, Douman, Drake Gens, Duncan Riach, Ehsan Toosi, ekuznetsov139, Elena Zhelezina, elzino, Ending2015a, Eric Schweitz, Erik Zettel, Ethan Saadia, Eugene Kuznetsov, Evgeniy Zheltonozhskiy, Ewout Ter Hoeven, exfalso, FAIJUL, Fangjun Kuang, Fei Hu, Frank Laub, Frederic Bastien, Fredrik Knutsson, frreiss, Frédéric Rechtenstein, fsx950223, Gaurav Singh, gbaned, George Grzegorz Pawelczak, George Sterpu, Gian Marco Iodice, Giorgio Arena, Hans Gaiser, Hans Pabst, Haoyu Wu, Harry Slatyer, hsahovic, Hugo, Hugo Sjöberg, IrinaM21, jacco, Jake Tae, Jean-Denis Lesage, Jean-Michel Gorius, Jeff Daily, Jens Elofsson, Jerry Shih, jerryyin, Jin Mingjian, Jinjing Zhou, JKIsaacLee, jojimonv, Jonathan Dekhtiar, Jose Ignacio Gomez, Joseph-Rance, Judd, Julian Gross, Kaixi Hou, Kaustubh Maske Patil, Keunwoo Choi, Kevin Hanselman, Khor Chean Wei, Kilaru Yasaswi Sri Chandra Gandhi, Koan-Sin Tan, Koki Ibukuro, Kristian Holsheimer, kurileo, Lakshay Tokas, Lee Netherton, leike666666, Leslie-Fang-Intel, Li, Guizi, LIUJIAN435, Lukas Geiger, Lyo Nguyen, madisetti, Maher Jendoubi, Mahmoud Abuzaina, Manuel Freiberger, Marcel Koester, Marco Jacopo Ferrarotti, Markus Franke, marload, Mbah-Javis, mbhuiyan, Meng Zhang, Michael Liao, MichaelKonobeev, Michal Tarnowski, Milan Straka, minoring, Mohamed Nour Abouelseoud, MoussaMM, Mrinal Jain, mrTsjolder, Måns Nilsson, Namrata Bhave, Nicholas Gao, Niels Ole Salscheider, nikochiko, Niranjan Hasabnis, Nishidha Panpaliya, nmostafa, Noah Trenaman, nuka137, Officium, Owen L - Sfe, Pallavi G, Paul Andrey, Peng Sun, Peng Wu, Phil Pearl, PhilipMay, pingsutw, Pooya Davoodi, PragmaTwice, pshiko, Qwerty71, R Gomathi, Rahul Huilgol, Richard Xiao, Rick Wierenga, Roberto Rosmaninho, ruchit2801, Rushabh Vasani, Sami, Sana Damani, Sarvesh Dubey, Sasan Jafarnejad, Sergii Khomenko, Shane Smiskol, Shaochen Shi, sharkdtu, Shawn Presser, ShengYang1, Shreyash Patodia, Shyam Sundar Dhanabalan, Siju Samuel, Somyajit Chakraborty Sam, Srihari Humbarwadi, srinivasan.narayanamoorthy, Srishti Yadav, Steph-En-M, Stephan Uphoff, Stephen Mugisha, SumanSudhir, Taehun Kim, Tamas Bela Feher, TengLu, Tetragramm, Thierry Herrmann, Tian Jin, tigertang, Tom Carchrae, Tom Forbes, Trent Lo, Victor Peng, vijayphoenix, Vincent Abriou, Vishal Bhola, Vishnuvardhan Janapati, vladbataev, VoVAllen, Wallyss Lima, Wen-Heng (Jack) Chung, wenxizhu, William D. Irons, William Zhang, Xiaoming (Jason) Cui, Xiaoquan Kong, Xinan Jiang, Yasir Modak, Yasuhiro Matsumoto, Yaxun (Sam) Liu, Yong Tang, Ytyt-Yt, yuan, Yuan Mingshuai, Yuan Tang, Yuki Ueda, Yusup, zhangshijin, zhuwenxi
|
||||
|
||||
# Release 2.0.1
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
@ -64,7 +64,7 @@ your model, and we recommend you run the TensorFlow process in a sandbox.
|
||||
|
||||
It is possible to write models that are secure in a sense that they can safely
|
||||
process untrusted inputs assuming there are no bugs. There are two main reasons
|
||||
to not rely on this: first, it is easy to write models which must not be exposed
|
||||
to not rely on this: First, it is easy to write models which must not be exposed
|
||||
to untrusted inputs, and second, there are bugs in any software system of
|
||||
sufficient complexity. Letting users control inputs could allow them to trigger
|
||||
bugs either in TensorFlow or in dependent libraries.
|
||||
@ -149,7 +149,7 @@ attack (or worse). Because TensorFlow behaves correctly, this is not a
|
||||
vulnerability in TensorFlow (although it would be a vulnerability of this
|
||||
hypothetical system).
|
||||
|
||||
As a general rule, it is incorrect behavior for Tensorflow to access memory it
|
||||
As a general rule, it is incorrect behavior for TensorFlow to access memory it
|
||||
does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to
|
||||
such behaviors constitute a vulnerability.
|
||||
|
||||
|
2
configure
vendored
2
configure
vendored
@ -4,7 +4,7 @@ set -e
|
||||
set -o pipefail
|
||||
|
||||
if [ -z "$PYTHON_BIN_PATH" ]; then
|
||||
PYTHON_BIN_PATH=$(which python || which python3 || true)
|
||||
PYTHON_BIN_PATH=$(which python3 || which python || true)
|
||||
fi
|
||||
|
||||
# Set all env variables
|
||||
|
49
configure.py
49
configure.py
@ -50,7 +50,7 @@ _TF_WORKSPACE_ROOT = ''
|
||||
_TF_BAZELRC = ''
|
||||
_TF_CURRENT_BAZEL_VERSION = None
|
||||
_TF_MIN_BAZEL_VERSION = '2.0.0'
|
||||
_TF_MAX_BAZEL_VERSION = '2.0.0'
|
||||
_TF_MAX_BAZEL_VERSION = '3.99.0'
|
||||
|
||||
NCCL_LIB_PATHS = [
|
||||
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
|
||||
@ -144,7 +144,7 @@ def write_to_bazelrc(line):
|
||||
|
||||
|
||||
def write_action_env_to_bazelrc(var_name, var):
|
||||
write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var)))
|
||||
write_to_bazelrc('build --action_env {}="{}"'.format(var_name, str(var)))
|
||||
|
||||
|
||||
def run_shell(cmd, allow_non_zero=False, stderr=None):
|
||||
@ -205,7 +205,7 @@ def setup_python(environ_cp):
|
||||
# Get PYTHON_BIN_PATH, default is the current running python.
|
||||
default_python_bin_path = sys.executable
|
||||
ask_python_bin_path = ('Please specify the location of python. [Default is '
|
||||
'%s]: ') % default_python_bin_path
|
||||
'{}]: ').format(default_python_bin_path)
|
||||
while True:
|
||||
python_bin_path = get_from_env_or_user_or_default(environ_cp,
|
||||
'PYTHON_BIN_PATH',
|
||||
@ -215,9 +215,10 @@ def setup_python(environ_cp):
|
||||
if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK):
|
||||
break
|
||||
elif not os.path.exists(python_bin_path):
|
||||
print('Invalid python path: %s cannot be found.' % python_bin_path)
|
||||
print('Invalid python path: {} cannot be found.'.format(python_bin_path))
|
||||
else:
|
||||
print('%s is not executable. Is it the python binary?' % python_bin_path)
|
||||
print('{} is not executable. Is it the python binary?'.format(
|
||||
python_bin_path))
|
||||
environ_cp['PYTHON_BIN_PATH'] = ''
|
||||
|
||||
# Convert python path to Windows style before checking lib and version
|
||||
@ -236,7 +237,7 @@ def setup_python(environ_cp):
|
||||
default_python_lib_path = python_lib_paths[0]
|
||||
python_lib_path = get_input(
|
||||
'Please input the desired Python library path to use. '
|
||||
'Default is [%s]\n' % python_lib_paths[0])
|
||||
'Default is [{}]\n'.format(python_lib_paths[0]))
|
||||
if not python_lib_path:
|
||||
python_lib_path = default_python_lib_path
|
||||
environ_cp['PYTHON_LIB_PATH'] = python_lib_path
|
||||
@ -252,7 +253,7 @@ def setup_python(environ_cp):
|
||||
# Set-up env variables used by python_configure.bzl
|
||||
write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path)
|
||||
write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path)
|
||||
write_to_bazelrc('build --python_path=\"%s"' % python_bin_path)
|
||||
write_to_bazelrc('build --python_path=\"{}"'.format(python_bin_path))
|
||||
environ_cp['PYTHON_BIN_PATH'] = python_bin_path
|
||||
|
||||
# If choosen python_lib_path is from a path specified in the PYTHONPATH
|
||||
@ -266,7 +267,7 @@ def setup_python(environ_cp):
|
||||
with open(
|
||||
os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'),
|
||||
'w') as f:
|
||||
f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path)
|
||||
f.write('export PYTHON_BIN_PATH="{}"'.format(python_bin_path))
|
||||
|
||||
|
||||
def reset_tf_configure_bazelrc():
|
||||
@ -320,11 +321,12 @@ def get_var(environ_cp,
|
||||
Raise the error to avoid infinitely looping.
|
||||
"""
|
||||
if not question:
|
||||
question = 'Do you wish to build TensorFlow with %s support?' % query_item
|
||||
question = 'Do you wish to build TensorFlow with {} support?'.format(
|
||||
query_item)
|
||||
if not yes_reply:
|
||||
yes_reply = '%s support will be enabled for TensorFlow.' % query_item
|
||||
yes_reply = '{} support will be enabled for TensorFlow.'.format(query_item)
|
||||
if not no_reply:
|
||||
no_reply = 'No %s' % yes_reply
|
||||
no_reply = 'No {}'.format(yes_reply)
|
||||
|
||||
yes_reply += '\n'
|
||||
no_reply += '\n'
|
||||
@ -368,7 +370,7 @@ def get_var(environ_cp,
|
||||
print(no_reply)
|
||||
var = False
|
||||
else:
|
||||
print('Invalid selection: %s' % user_input_origin)
|
||||
print('Invalid selection: {}'.format(user_input_origin))
|
||||
return var
|
||||
|
||||
|
||||
@ -479,13 +481,13 @@ def check_bazel_version(min_version, max_version):
|
||||
if which('bazel') is None:
|
||||
print('Cannot find bazel. Please install bazel.')
|
||||
sys.exit(0)
|
||||
curr_version = run_shell(
|
||||
['bazel', '--batch', '--bazelrc=/dev/null', 'version'])
|
||||
|
||||
for line in curr_version.split('\n'):
|
||||
if 'Build label: ' in line:
|
||||
curr_version = line.split('Build label: ')[1]
|
||||
break
|
||||
stderr = open(os.devnull, 'wb')
|
||||
curr_version = run_shell(['bazel', '--version'],
|
||||
allow_non_zero = True,
|
||||
stderr = stderr)
|
||||
if curr_version.startswith('bazel '):
|
||||
curr_version = curr_version.split('bazel ')[1]
|
||||
|
||||
min_version_int = convert_version_to_int(min_version)
|
||||
curr_version_int = convert_version_to_int(curr_version)
|
||||
@ -1171,14 +1173,16 @@ def system_specific_test_config(environ_cp):
|
||||
test_only_filters = ['-oss_serial']
|
||||
if is_windows():
|
||||
test_and_build_filters.append('-no_windows')
|
||||
if environ_cp.get('TF_NEED_CUDA', None) == '1':
|
||||
if ((environ_cp.get('TF_NEED_CUDA', None) == '1') or
|
||||
(environ_cp.get('TF_NEED_ROCM', None) == '1')):
|
||||
test_and_build_filters += ['-no_windows_gpu', '-no_gpu']
|
||||
else:
|
||||
test_and_build_filters.append('-gpu')
|
||||
elif is_macos():
|
||||
test_and_build_filters += ['-gpu', '-nomac', '-no_mac']
|
||||
elif is_linux():
|
||||
if environ_cp.get('TF_NEED_CUDA', None) == '1':
|
||||
if ((environ_cp.get('TF_NEED_CUDA', None) == '1') or
|
||||
(environ_cp.get('TF_NEED_ROCM', None) == '1')):
|
||||
test_and_build_filters.append('-no_gpu')
|
||||
write_to_bazelrc('test --test_env=LD_LIBRARY_PATH')
|
||||
else:
|
||||
@ -1383,7 +1387,6 @@ def main():
|
||||
# Windows.
|
||||
environ_cp['TF_DOWNLOAD_CLANG'] = '0'
|
||||
environ_cp['TF_NEED_MPI'] = '0'
|
||||
environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0'
|
||||
|
||||
if is_macos():
|
||||
environ_cp['TF_NEED_TENSORRT'] = '0'
|
||||
@ -1416,6 +1419,10 @@ def main():
|
||||
write_action_env_to_bazelrc('LD_LIBRARY_PATH',
|
||||
environ_cp.get('LD_LIBRARY_PATH'))
|
||||
|
||||
if (environ_cp.get('TF_NEED_ROCM') == '1' and environ_cp.get('ROCM_PATH')):
|
||||
write_action_env_to_bazelrc('ROCM_PATH', environ_cp.get('ROCM_PATH'))
|
||||
write_action_env_to_bazelrc('ROCM_ROOT', environ_cp.get('ROCM_PATH'))
|
||||
|
||||
environ_cp['TF_NEED_CUDA'] = str(
|
||||
int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)))
|
||||
if (environ_cp.get('TF_NEED_CUDA') == '1' and
|
||||
|
@ -517,12 +517,26 @@ package_group(
|
||||
"//perftools/accelerators/xprof/api/...",
|
||||
"//third_party/py/autograph/...",
|
||||
"//third_party/swift/tensorflow/x10/...",
|
||||
"//third_party/swift/tensorflow_apis/...",
|
||||
"//tensorflow/...",
|
||||
"//tensorflow_estimator/python/estimator/...",
|
||||
"//tensorflow_models/official/...",
|
||||
],
|
||||
)
|
||||
|
||||
package_group(name = "ndarray_tensor_allow_list")
|
||||
|
||||
# Packages that use composite tensors or dispatch.
|
||||
# TODO(b/154762408) Remove this package group once it's no longer needed.
|
||||
package_group(name = "composite_tensor_whitelist")
|
||||
|
||||
# Packages that use private types symbols, until they are exported.
|
||||
# TODO(b/154650521) Remove.
|
||||
package_group(
|
||||
name = "types_whitelist",
|
||||
packages = ["//learning/deepmind/tensorflow/replicator/..."],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "intel_binary_blob",
|
||||
data = if_mkl_ml(
|
||||
@ -709,8 +723,8 @@ tf_cc_shared_object(
|
||||
"//tensorflow/c:version_script.lds",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/core:distributed_tensorflow_dependencies",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -116,7 +116,7 @@ from tensorflow.python.lib.io import file_io as _fi
|
||||
|
||||
# Get sitepackages directories for the python installation.
|
||||
_site_packages_dirs = []
|
||||
_site_packages_dirs += [_site.USER_SITE]
|
||||
_site_packages_dirs += [] if _site.USER_SITE is None else [_site.USER_SITE]
|
||||
_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p]
|
||||
if 'getsitepackages' in dir(_site):
|
||||
_site_packages_dirs += _site.getsitepackages()
|
||||
|
@ -126,7 +126,7 @@ from tensorflow.python.lib.io import file_io as _fi
|
||||
|
||||
# Get sitepackages directories for the python installation.
|
||||
_site_packages_dirs = []
|
||||
_site_packages_dirs += [_site.USER_SITE]
|
||||
_site_packages_dirs += [] if _site.USER_SITE is None else [_site.USER_SITE]
|
||||
_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p]
|
||||
if 'getsitepackages' in dir(_site):
|
||||
_site_packages_dirs += _site.getsitepackages()
|
||||
|
@ -58,6 +58,7 @@ filegroup(
|
||||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
"c_api_internal.h",
|
||||
"conversion_macros.h",
|
||||
"python_api.h",
|
||||
"tensor_interface.h",
|
||||
"tf_status_helper.h",
|
||||
@ -84,7 +85,14 @@ tf_cuda_library(
|
||||
],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//tensorflow:chromiumos": [
|
||||
":tf_attrtype",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:platform",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tf_attrtype",
|
||||
@ -118,6 +126,13 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "c_api_macros",
|
||||
hdrs = ["c_api_macros.h"],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api",
|
||||
hdrs = [
|
||||
@ -167,7 +182,7 @@ tf_cuda_library(
|
||||
":tf_status_internal",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tf_status",
|
||||
@ -204,7 +219,7 @@ tf_cuda_library(
|
||||
],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:lib",
|
||||
@ -217,12 +232,13 @@ cc_library(
|
||||
srcs = ["tf_status.cc"],
|
||||
hdrs = ["tf_status.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
deps = [
|
||||
":tf_status_internal",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tf_status_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
}),
|
||||
@ -244,10 +260,15 @@ cc_library(
|
||||
name = "tensor_interface",
|
||||
hdrs = ["tensor_interface.h"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -257,7 +278,7 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
@ -271,16 +292,17 @@ cc_library(
|
||||
srcs = ["tf_tensor.cc"],
|
||||
hdrs = ["tf_tensor.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
deps = [
|
||||
":tensor_interface",
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
":tf_status_helper",
|
||||
":tf_tensor_internal",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -296,14 +318,15 @@ tf_cuda_library(
|
||||
"tf_tensor_internal.h",
|
||||
],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
deps = [
|
||||
":tensor_interface",
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:casts",
|
||||
@ -327,6 +350,9 @@ tf_cuda_library(
|
||||
":checkpoint_reader",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_internal",
|
||||
"//tensorflow/c/eager:tfe_context_internal",
|
||||
"//tensorflow/c/eager:tfe_op_internal",
|
||||
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
@ -368,8 +394,14 @@ tf_cuda_library(
|
||||
deps = [
|
||||
":tf_status",
|
||||
":tf_status_internal",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
@ -408,7 +440,7 @@ tf_cuda_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
@ -439,7 +471,7 @@ tf_cuda_library(
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
":c_api_internal",
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":c_api_internal",
|
||||
@ -466,7 +498,7 @@ tf_cuda_library(
|
||||
":tf_status_helper",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
@ -517,12 +549,12 @@ tf_cuda_cc_test(
|
||||
":test_op1.so",
|
||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
||||
],
|
||||
kernels = [":test_op_kernel"],
|
||||
linkopts = select({
|
||||
"//tensorflow:macos": ["-headerpad_max_install_names"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
tags = [
|
||||
"no_windows", # TODO(b/155444728)
|
||||
"noasan",
|
||||
],
|
||||
# We must ensure that the dependencies can be dynamically linked since
|
||||
@ -531,6 +563,7 @@ tf_cuda_cc_test(
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_test_util",
|
||||
":test_op_kernel",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:grad_ops",
|
||||
"//tensorflow/cc/saved_model:signature_constants",
|
||||
@ -597,6 +630,7 @@ tf_cc_test(
|
||||
":c_api",
|
||||
":c_api_internal",
|
||||
":c_test_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -721,3 +755,11 @@ tf_cuda_library(
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "conversion_macros",
|
||||
hdrs = [
|
||||
"conversion_macros.h",
|
||||
],
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
@ -39,6 +39,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eval_const_tensor.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
@ -53,7 +54,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/graph/validate.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
|
@ -21,6 +21,9 @@ limitations under the License.
|
||||
#include "tensorflow/c/checkpoint_reader.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
@ -322,205 +325,6 @@ TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
TFE_Context* TFE_CreateContextFromSession(TF_Session* session,
|
||||
TF_Status* status) {
|
||||
auto* opts = TFE_NewContextOptions();
|
||||
|
||||
// Reduce GPU memory allocation, and set appropriate config options for TFE
|
||||
// context.
|
||||
auto* config = TF_CreateConfig(
|
||||
/*xla*/ false, /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
|
||||
10);
|
||||
TFE_ContextOptionsSetConfig(opts, config->data, config->length, status);
|
||||
if (!status->status.ok()) {
|
||||
CHECK(!config);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto* ctx = TFE_NewContextFromSession(opts, session, status);
|
||||
TF_DeleteBuffer(config);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// TODO: retrieve the device string via TFE_ContextListDevices()
|
||||
static const char DEFAULT_CPU_DEVICE[] =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
|
||||
static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType,
|
||||
int tensor_id, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> queueOp(
|
||||
TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp);
|
||||
TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
// TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler.
|
||||
TFE_OpSetAttrInt(queueOp.get(), "capacity", 1);
|
||||
TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1);
|
||||
auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id);
|
||||
TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(),
|
||||
shared_name.size());
|
||||
TFE_OpSetAttrString(queueOp.get(), "container", "", 0);
|
||||
|
||||
// TODO: consider making this an unknown shape.
|
||||
const int64_t* dims_ptr = nullptr;
|
||||
int num_dims = 0;
|
||||
TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims,
|
||||
/*num_values*/ 0, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
int num_retvals = 1;
|
||||
TFE_TensorHandle* queue = nullptr;
|
||||
TFE_Execute(queueOp.get(), &queue, &num_retvals, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
CHECK_EQ(num_retvals, 1);
|
||||
|
||||
return queue;
|
||||
}
|
||||
|
||||
static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType,
|
||||
TFE_TensorHandle* queue, TFE_TensorHandle* tensor,
|
||||
TF_Status* status) {
|
||||
TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
|
||||
TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
|
||||
if (!status->status.ok()) return;
|
||||
TFE_OpAddInput(op, queue, status);
|
||||
if (!status->status.ok()) return;
|
||||
TFE_OpAddInput(op, tensor, status);
|
||||
if (!status->status.ok()) return;
|
||||
TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1);
|
||||
TFE_OpSetAttrInt(op, "timeout_ms", -1);
|
||||
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status);
|
||||
if (!status->status.ok()) return;
|
||||
CHECK_EQ(num_retvals, 0);
|
||||
}
|
||||
|
||||
static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx,
|
||||
TF_DataType inputType,
|
||||
TFE_TensorHandle* queue,
|
||||
TF_Status* status) {
|
||||
TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
|
||||
TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
TFE_OpAddInput(op, queue, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1);
|
||||
TFE_OpSetAttrInt(op, "timeout_ms", -1);
|
||||
TFE_TensorHandle* ret;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &ret, &num_retvals, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
CHECK_EQ(num_retvals, 1);
|
||||
return ret;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id,
|
||||
TF_DataType inputType,
|
||||
TF_Status* status) {
|
||||
assert(session);
|
||||
VLOG(1) << "Dequeuing data tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
auto* ret = createTFEDequeue(ctx, inputType, queue, status);
|
||||
return ret;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
|
||||
TF_DataType inputType,
|
||||
TF_Status* status) {
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
auto* ret = createTFEDequeue(ctx, inputType, queue, status);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id,
|
||||
TFE_TensorHandle* tensor, TF_Status* status) {
|
||||
assert(session);
|
||||
VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TF_DataType inputType = TFE_TensorHandleDataType(tensor);
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
createTFEEnqueue(ctx, inputType, queue, tensor, status);
|
||||
}
|
||||
|
||||
void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status) {
|
||||
VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
|
||||
|
||||
TF_DataType inputType = TFE_TensorHandleDataType(tensor);
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
createTFEEnqueue(ctx, inputType, queue, tensor, status);
|
||||
}
|
||||
|
||||
void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id,
|
||||
TFE_TensorHandle* tensor, TF_Status* status) {
|
||||
VLOG(1) << "Enqueuing variant tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
|
||||
TF_Status* status) {
|
||||
VLOG(1) << "Dequeuing variant tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
return createTFEDequeue(ctx, TF_VARIANT, queue, status);
|
||||
}
|
||||
|
||||
void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
|
||||
status->status = tensorflow::errors::Internal(errMsg);
|
||||
}
|
||||
@ -619,10 +423,9 @@ void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
|
||||
void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
|
||||
const TF_DataType* values, int num_values) {
|
||||
auto iter = builder->attr_names.insert(attr_name).first;
|
||||
builder->Set(
|
||||
(*iter).c_str(),
|
||||
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
|
||||
reinterpret_cast<const tensorflow::DataType*>(values), num_values));
|
||||
builder->Set(*iter, tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
|
||||
reinterpret_cast<const tensorflow::DataType*>(values),
|
||||
num_values));
|
||||
}
|
||||
|
||||
void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,
|
||||
@ -686,8 +489,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
|
||||
std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
|
||||
|
||||
status->status = tensorflow::Status::OK();
|
||||
return new TFE_TensorHandle{
|
||||
tensorflow::TensorHandle::CreateLocalHandle(tensor)};
|
||||
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -708,7 +510,7 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
|
||||
|
||||
// New server created for new server_def. Unused if updating server_def.
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
if (grpc_server == nullptr) {
|
||||
@ -822,14 +624,13 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
||||
|
||||
const int num_inputs = input_shapes->num_items;
|
||||
NodeDef node_def;
|
||||
node_def.set_name(tfe_op->operation->Name());
|
||||
node_def.set_op(tfe_op->operation->Name());
|
||||
tensorflow::AbstractOperationInterface* op = tensorflow::unwrap(tfe_op);
|
||||
node_def.set_name(op->Name());
|
||||
node_def.set_op(op->Name());
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
node_def.add_input("dummy_input");
|
||||
}
|
||||
OperationFromInterface(tfe_op->operation)
|
||||
->Attrs()
|
||||
.FillAttrValueMap(node_def.mutable_attr());
|
||||
OperationFromInterface(op)->Attrs().FillAttrValueMap(node_def.mutable_attr());
|
||||
|
||||
const tensorflow::OpRegistrationData* op_reg_data;
|
||||
status->status =
|
||||
|
@ -146,48 +146,6 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
|
||||
// Create a serialized tensorflow.ServerDef proto.
|
||||
TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status);
|
||||
|
||||
// TODO: remove this API in favor of the next one.
|
||||
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession(
|
||||
const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status);
|
||||
|
||||
// Creates from `session` a new eager context to run a graph function or
|
||||
// sends/recvs, so that these concurrent TFE executions can share (via
|
||||
// `session` and its associated device mgr) the same set of fifo queue resource
|
||||
// ops, used for host<->TF tensor transfers. This way the sends/recvs calls and
|
||||
// graph function execution can access the same fifo queue resource handles
|
||||
// (associated with devices managed by the device manager, which can be obtained
|
||||
// from `session`).
|
||||
//
|
||||
// TODO: Remove this function once we migrate away from using session.
|
||||
TF_CAPI_EXPORT extern TFE_Context* TFE_CreateContextFromSession(
|
||||
TF_Session* session, TF_Status* status);
|
||||
|
||||
// TODO: Retire this API in favor of the next one.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensor(
|
||||
TF_Session* session, int tensor_id, TF_DataType inputType,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(
|
||||
TFE_Context* ctx, int tensor_id, TF_DataType inputType, TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensor(TF_Session* session,
|
||||
int tensor_id,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensorFromCtx(
|
||||
TFE_Context* ctx, int tensor_id, TFE_TensorHandle* tensor,
|
||||
TF_Status* status);
|
||||
|
||||
// TODO: consider folding the 2 APIs below into the ones above.
|
||||
TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session,
|
||||
int tensor_id,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
|
||||
TF_Session* session, int tensor_id, TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
|
||||
const char* errMsg);
|
||||
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/c_test_util.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
|
@ -38,7 +38,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
33
tensorflow/c/c_api_macros.h
Normal file
33
tensorflow/c/c_api_macros.h
Normal file
@ -0,0 +1,33 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_C_API_MACROS_H_
|
||||
#define TENSORFLOW_C_C_API_MACROS_H_
|
||||
|
||||
#ifdef SWIG
|
||||
#define TF_CAPI_EXPORT
|
||||
#else
|
||||
#if defined(_WIN32)
|
||||
#ifdef TF_COMPILE_LIBRARY
|
||||
#define TF_CAPI_EXPORT __declspec(dllexport)
|
||||
#else
|
||||
#define TF_CAPI_EXPORT __declspec(dllimport)
|
||||
#endif // TF_COMPILE_LIBRARY
|
||||
#else
|
||||
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
|
||||
#endif // _WIN32
|
||||
#endif // SWIG
|
||||
|
||||
#endif // TENSORFLOW_C_C_API_MACROS_H_
|
33
tensorflow/c/conversion_macros.h
Normal file
33
tensorflow/c/conversion_macros.h
Normal file
@ -0,0 +1,33 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_CONVERSION_MACROS_H_
|
||||
#define TENSORFLOW_C_CONVERSION_MACROS_H_
|
||||
|
||||
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
|
||||
inline cpp_impl *unwrap(wrapper *w) { \
|
||||
return reinterpret_cast<cpp_impl *>(w); \
|
||||
} \
|
||||
\
|
||||
inline const cpp_impl *unwrap(const wrapper *w) { \
|
||||
return reinterpret_cast<const cpp_impl *>(w); \
|
||||
} \
|
||||
\
|
||||
inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast<wrapper *>(i); } \
|
||||
inline const wrapper *wrap(const cpp_impl *i) { \
|
||||
return reinterpret_cast<const wrapper *>(i); \
|
||||
}
|
||||
|
||||
#endif // TENSORFLOW_C_CONVERSION_MACROS_H_
|
@ -35,18 +35,26 @@ tf_cuda_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":context_interface",
|
||||
":operation_interface",
|
||||
":tensor_handle_interface",
|
||||
":tfe_context_internal",
|
||||
":tfe_cancellation_manager_internal",
|
||||
":tfe_executor_internal",
|
||||
":tfe_monitoring_internal",
|
||||
":tfe_op_attrs_internal",
|
||||
":tfe_op_internal",
|
||||
":tfe_tensor_debug_info_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:fixed_array",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/c:tf_status_internal",
|
||||
"//tensorflow/c:tf_tensor_internal",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
@ -100,6 +108,11 @@ filegroup(
|
||||
"dlpack.h",
|
||||
"operation_interface.h",
|
||||
"tensor_handle_interface.h",
|
||||
"tfe_cancellation_manager_internal.h",
|
||||
"tfe_executor_internal.h",
|
||||
"tfe_monitoring_internal.h",
|
||||
"tfe_op_attrs_internal.h",
|
||||
"tfe_tensor_debug_info_internal.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
@ -107,33 +120,27 @@ filegroup(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
cc_library(
|
||||
name = "c_api_internal",
|
||||
srcs = [
|
||||
hdrs = [
|
||||
"c_api_experimental.h",
|
||||
"c_api_unified_experimental.h",
|
||||
"c_api_internal.h",
|
||||
],
|
||||
hdrs = ["c_api_internal.h"],
|
||||
visibility = [
|
||||
"//learning/deepmind/courier:__subpackages__",
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":c_api",
|
||||
":context_interface",
|
||||
":operation_interface",
|
||||
":tensor_handle_interface",
|
||||
"//tensorflow/c:c_api",
|
||||
":tfe_cancellation_manager_internal",
|
||||
":tfe_context_internal",
|
||||
":tfe_executor_internal",
|
||||
":tfe_monitoring_internal",
|
||||
":tfe_op_attrs_internal",
|
||||
":tfe_op_internal",
|
||||
":tfe_tensor_debug_info_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
"//tensorflow/core/common_runtime/eager:eager_executor",
|
||||
],
|
||||
)
|
||||
|
||||
@ -177,13 +184,110 @@ cc_library(
|
||||
":operation_interface",
|
||||
":tensor_handle_interface",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_context_internal",
|
||||
hdrs = ["tfe_context_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":context_interface",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_cancellation_manager_internal",
|
||||
hdrs = ["tfe_cancellation_manager_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_executor_internal",
|
||||
hdrs = ["tfe_executor_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core/common_runtime/eager:eager_executor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_monitoring_internal",
|
||||
hdrs = ["tfe_monitoring_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_op_attrs_internal",
|
||||
hdrs = ["tfe_op_attrs_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_op_internal",
|
||||
hdrs = ["tfe_op_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":operation_interface",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_tensor_debug_info_internal",
|
||||
hdrs = ["tfe_tensor_debug_info_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_tensorhandle_internal",
|
||||
hdrs = ["tfe_tensorhandle_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":tensor_handle_interface",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_test_util",
|
||||
testonly = 1,
|
||||
@ -213,7 +317,9 @@ tf_cuda_cc_test(
|
||||
],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = [
|
||||
"guitar",
|
||||
"noguitar", # TODO(b/155445984): flaky
|
||||
#"guitar",
|
||||
"notap", # TODO(b/156981931): flaky
|
||||
"multi_gpu",
|
||||
],
|
||||
deps = [
|
||||
@ -221,6 +327,8 @@ tf_cuda_cc_test(
|
||||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
":c_api_test_util",
|
||||
":tfe_op_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
@ -239,17 +347,49 @@ tf_cuda_cc_test(
|
||||
srcs = [
|
||||
"c_api_remote_test.cc",
|
||||
],
|
||||
# TODO(b/136478427): Figure out how to correctly shut the server down
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = [
|
||||
"guitar",
|
||||
"multi_gpu",
|
||||
"no_oss",
|
||||
"noasan", # leaks gRPC server instances
|
||||
"notsan", # b/157098283
|
||||
],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
":c_api_test_util",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime:function_optimization_registry",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "c_api_cluster_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"c_api_cluster_test.cc",
|
||||
],
|
||||
# TODO(b/136478427): Figure out how to correctly shut the server down
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = ["noasan"], # leaks gRPC server instances
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
":c_api_test_util",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -257,6 +397,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"//tensorflow/core/platform:env",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
@ -266,6 +407,9 @@ tf_cuda_library(
|
||||
srcs = [
|
||||
"c_api_experimental.cc",
|
||||
"c_api_unified_experimental.cc",
|
||||
"c_api_unified_experimental_eager.cc",
|
||||
"c_api_unified_experimental_graph.cc",
|
||||
"c_api_unified_experimental_internal.h",
|
||||
],
|
||||
hdrs = [
|
||||
"c_api_experimental.h",
|
||||
@ -275,11 +419,14 @@ tf_cuda_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":c_api",
|
||||
":c_api_internal",
|
||||
":tfe_context_internal",
|
||||
":tfe_op_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/core:core_cpu",
|
||||
@ -308,6 +455,8 @@ tf_cuda_library(
|
||||
"//conditions:default": [],
|
||||
}) + [
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core/distributed_runtime/eager:eager_client",
|
||||
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
|
||||
@ -362,6 +511,7 @@ tf_cuda_cc_test(
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_test_util",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/core:lib",
|
||||
@ -443,8 +593,9 @@ cc_library(
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c:tf_status_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
@ -466,7 +617,6 @@ filegroup(
|
||||
],
|
||||
exclude = [
|
||||
"c_api_experimental.cc",
|
||||
"*c_api_tfrt*",
|
||||
"*test*",
|
||||
"*dlpack*",
|
||||
],
|
||||
|
@ -26,7 +26,6 @@ limitations under the License.
|
||||
// clang-format on
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/fixed_array.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
@ -34,9 +33,12 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
#include "tensorflow/c/eager/c_api_tfrt.h"
|
||||
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
|
||||
#endif
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
@ -298,7 +300,7 @@ tensorflow::Status CreateRemoteContexts(
|
||||
|
||||
std::vector<bool> filtered_device_mask;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->FilterDevicesForRemoteWorkers(
|
||||
remote_worker, base_request.cluster_device_attributes(),
|
||||
&filtered_device_mask);
|
||||
@ -383,7 +385,7 @@ tensorflow::Status UpdateRemoteContexts(
|
||||
|
||||
std::vector<bool> filtered_device_mask;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->FilterDevicesForRemoteWorkers(
|
||||
remote_worker, base_request.cluster_device_attributes(),
|
||||
&filtered_device_mask);
|
||||
@ -464,7 +466,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
// New server created for new server_def. Unused if updating server_def.
|
||||
std::unique_ptr<tensorflow::ServerInterface> new_server;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server;
|
||||
if (reset_context) {
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
|
||||
@ -498,6 +500,17 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
grpc_server->master_env()->worker_cache->GetEagerClientCache(
|
||||
&remote_eager_workers));
|
||||
|
||||
// For cluster update, use a status group to aggregate statuses from
|
||||
// * adding and removing remote devices
|
||||
// * creating remote contexts on newly added workers
|
||||
// * updating remote contexts on existing workers
|
||||
// * updating the master context
|
||||
// Note that we should not return immediately on errors in the middle of these
|
||||
// updates to prevent cluster from having inconsistent context views.
|
||||
//
|
||||
// Unused if `reset_context` is True.
|
||||
tensorflow::StatusGroup sg;
|
||||
|
||||
// When updating an existing context, populate the following lists with:
|
||||
// * added_workers: set(remote_workers) - set(curr_remote_workers)
|
||||
// * removed_workers: set(curr_remote_workers) - set(remote_workers)
|
||||
@ -533,7 +546,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
|
||||
&added_workers, &removed_workers,
|
||||
&existing_workers);
|
||||
LOG_AND_RETURN_IF_ERROR(GetReplacedFromExistingWorkers(
|
||||
sg.Update(GetReplacedFromExistingWorkers(
|
||||
&existing_workers, context_id, context->GetContextViewId(), server_def,
|
||||
remote_eager_workers.get(), &replaced_workers));
|
||||
if (VLOG_IS_ON(1)) {
|
||||
@ -557,10 +570,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
existing_workers.end());
|
||||
}
|
||||
}
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
|
||||
LOG_AND_RETURN_IF_ERROR(AddRemoteDevicesToMgr(
|
||||
added_workers, grpc_server->master_env()->worker_cache,
|
||||
sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
|
||||
sg.Update(AddRemoteDevicesToMgr(added_workers,
|
||||
grpc_server->master_env()->worker_cache,
|
||||
remote_device_mgr));
|
||||
}
|
||||
|
||||
@ -582,7 +594,6 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
}
|
||||
|
||||
// Initialize remote eager workers.
|
||||
// TODO(b/138847548) Create remote eager contexts in async mode by default.
|
||||
if (reset_context) {
|
||||
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
|
||||
ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
|
||||
@ -594,7 +605,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
// existing workers to also have the updated context_view_id, so
|
||||
// we must set their context_view_id to the existing master's
|
||||
// context_view_id + 1.
|
||||
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
|
||||
sg.Update(CreateRemoteContexts(
|
||||
ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
|
||||
server_def, remote_eager_workers.get(), context->Executor().Async(),
|
||||
context->LazyCopyFunctionRemoteInputs(), base_request));
|
||||
@ -604,20 +615,19 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
VLOG(1) << "Updating cluster with existing worker " << w;
|
||||
}
|
||||
}
|
||||
LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
|
||||
ctx, existing_workers, added_workers, removed_workers, context_id,
|
||||
context_view_id + 1, server_def, remote_eager_workers.get(),
|
||||
base_request));
|
||||
sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers,
|
||||
removed_workers, context_id,
|
||||
context_view_id + 1, server_def,
|
||||
remote_eager_workers.get(), base_request));
|
||||
}
|
||||
}
|
||||
|
||||
auto session_name = tensorflow::strings::StrCat("eager_", context_id);
|
||||
if (reset_context) {
|
||||
tensorflow::RemoteRendezvous* r =
|
||||
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
|
||||
auto session_name = tensorflow::strings::StrCat("eager_", context_id);
|
||||
auto* device_mgr = grpc_server->worker_env()->device_mgr;
|
||||
std::shared_ptr<tensorflow::WorkerSession> worker_session;
|
||||
|
||||
if (reset_context) {
|
||||
TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
|
||||
session_name, server_def, base_request.cluster_device_attributes(),
|
||||
true));
|
||||
@ -644,13 +654,13 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
// GrpcServer cannot be destroyed after it is started.
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
|
||||
} else {
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->worker_env()->session_mgr->UpdateSession(
|
||||
sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
|
||||
session_name, server_def, base_request.cluster_device_attributes(),
|
||||
true));
|
||||
LOG_AND_RETURN_IF_ERROR(context->UpdateRemoteMaster(
|
||||
grpc_server->worker_env(), std::move(remote_eager_workers),
|
||||
added_workers, removed_workers, context_id, r));
|
||||
/*isolate_session_state=*/true));
|
||||
sg.Update(context->UpdateRemoteMaster(context_id,
|
||||
std::move(remote_eager_workers),
|
||||
added_workers, removed_workers));
|
||||
LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
|
||||
}
|
||||
#undef LOG_AND_RETURN_IF_ERROR
|
||||
|
||||
@ -684,8 +694,13 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
|
||||
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
if (opts->use_tfrt) {
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
status->status = tensorflow::Status::OK();
|
||||
return new TFE_Context{new tfrt::ContextInterface()};
|
||||
tfrt::SmallVector<std::string, 4> op_handler_chains;
|
||||
tfrt::SmallVector<tensorflow::DeviceAttributes, 4> device_attributes;
|
||||
status->status = tfrt::ListOpHandlerChains(
|
||||
opts->session_options.options, &op_handler_chains, &device_attributes);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
return tensorflow::wrap(
|
||||
new tfrt::ContextInterface(op_handler_chains, device_attributes));
|
||||
#else
|
||||
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
|
||||
return nullptr;
|
||||
@ -702,32 +717,14 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||
|
||||
return new TFE_Context{new tensorflow::EagerContext(
|
||||
return tensorflow::wrap(new tensorflow::EagerContext(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
|
||||
/*device_mgr_owned*/ true, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator())};
|
||||
}
|
||||
|
||||
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
|
||||
TF_Session* sess, TF_Status* status) {
|
||||
const tensorflow::DeviceMgr* device_mgr = nullptr;
|
||||
status->status = sess->session->LocalDeviceManager(&device_mgr);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr);
|
||||
|
||||
return new TFE_Context{new tensorflow::EagerContext(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
|
||||
/*device_mgr_owned*/ false, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator())};
|
||||
tensorflow::GetDefaultCustomKernelCreator()));
|
||||
}
|
||||
|
||||
void TFE_DeleteContext(TFE_Context* ctx) {
|
||||
@ -735,23 +732,18 @@ void TFE_DeleteContext(TFE_Context* ctx) {
|
||||
return;
|
||||
}
|
||||
|
||||
// context->RefCountIsOne() should be true here.
|
||||
// TODO(iga): Remove EagerContext refcounting.
|
||||
ctx->context->Release();
|
||||
|
||||
delete ctx;
|
||||
// ctx->RefCountIsOne() should be true here.
|
||||
tensorflow::unwrap(ctx)->Release();
|
||||
}
|
||||
|
||||
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
|
||||
TF_DeviceList* l = new TF_DeviceList;
|
||||
ctx->context->ListDevices(&l->response);
|
||||
tensorflow::unwrap(ctx)->ListDevices(&l->response);
|
||||
return l;
|
||||
}
|
||||
|
||||
void TFE_ContextClearCaches(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
context->ClearCachesAndThreadExecutors();
|
||||
tensorflow::unwrap(ctx)->ClearCachesAndThreadExecutors();
|
||||
}
|
||||
|
||||
// Set server_def on the context, possibly updating it.
|
||||
@ -773,7 +765,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
if (server_def.has_cluster_device_filters()) {
|
||||
const auto& cdf = server_def.cluster_device_filters();
|
||||
for (const auto& jdf : cdf.jobs()) {
|
||||
const string& remote_prefix = "/job:" + jdf.name() + "/task:";
|
||||
const string remote_prefix = "/job:" + jdf.name() + "/task:";
|
||||
for (const auto& tdf : jdf.tasks()) {
|
||||
const int32_t task_index = tdf.first;
|
||||
std::vector<string> device_filters(tdf.second.device_filters_size());
|
||||
@ -782,7 +774,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
}
|
||||
const string remote_worker = remote_prefix + std::to_string(task_index);
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status =
|
||||
context->SetRemoteDeviceFilters(remote_worker, device_filters);
|
||||
}
|
||||
@ -804,7 +796,7 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::ServerDef server_def;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
if (!server_def.ParseFromArray(proto, proto_len)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Invalid tensorflow.ServerDef protocol buffer");
|
||||
@ -834,7 +826,7 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
||||
return false;
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
static_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
|
||||
@ -889,16 +881,14 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||
#if defined(IS_MOBILE_PLATFORM)
|
||||
status->status = tensorflow::Status::OK();
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
status->status = context->SyncExecutors();
|
||||
status->status = tensorflow::unwrap(ctx)->AsyncWait();
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetThreadLocalDevicePlacementPolicy(
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
|
||||
}
|
||||
@ -909,18 +899,17 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
||||
TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return static_cast<TFE_ContextDevicePlacementPolicy>(
|
||||
context->GetDevicePlacementPolicy());
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
|
||||
tensorflow::Tensor tensor;
|
||||
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
return new TFE_TensorHandle{
|
||||
tensorflow::TensorHandle::CreateLocalHandle(tensor)};
|
||||
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
|
||||
}
|
||||
|
||||
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
|
||||
@ -928,84 +917,84 @@ void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
|
||||
|
||||
tensorflow::profiler::TraceMe activity(
|
||||
"TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
|
||||
if (h->handle) {
|
||||
h->handle->Release();
|
||||
if (h) {
|
||||
tensorflow::unwrap(h)->Release();
|
||||
}
|
||||
delete h;
|
||||
}
|
||||
|
||||
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
|
||||
return static_cast<TF_DataType>(h->handle->DataType());
|
||||
return static_cast<TF_DataType>(tensorflow::unwrap(h)->DataType());
|
||||
}
|
||||
|
||||
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return -1;
|
||||
}
|
||||
|
||||
int num_dims = -1;
|
||||
status->status = h->handle->NumDims(&num_dims);
|
||||
status->status = tensorflow::unwrap(h)->NumDims(&num_dims);
|
||||
return num_dims;
|
||||
}
|
||||
|
||||
int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return -1;
|
||||
}
|
||||
|
||||
int64 num_elements = -1;
|
||||
status->status = h->handle->NumElements(&num_elements);
|
||||
status->status = tensorflow::unwrap(h)->NumElements(&num_elements);
|
||||
return num_elements;
|
||||
}
|
||||
|
||||
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
|
||||
TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return -1;
|
||||
}
|
||||
|
||||
int64 dim = -1;
|
||||
status->status = h->handle->Dim(dim_index, &dim);
|
||||
status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim);
|
||||
return dim;
|
||||
}
|
||||
|
||||
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
return h->handle->DeviceName(&status->status);
|
||||
return tensorflow::unwrap(h)->DeviceName(&status->status);
|
||||
}
|
||||
|
||||
const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
|
||||
TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
return h->handle->BackingDeviceName(&status->status);
|
||||
return tensorflow::unwrap(h)->BackingDeviceName(&status->status);
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
|
||||
TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return new TFE_TensorHandle{h->handle->Copy()};
|
||||
return tensorflow::wrap(tensorflow::unwrap(h)->Copy());
|
||||
}
|
||||
|
||||
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensorflow::AbstractTensorInterface* t = h->handle->Resolve(&status->status);
|
||||
tensorflow::AbstractTensorInterface* t =
|
||||
tensorflow::unwrap(h)->Resolve(&status->status);
|
||||
if (t == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -1014,22 +1003,22 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
||||
}
|
||||
|
||||
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::TensorHandleFromInterface(h->handle);
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
|
||||
if (VariantDeviceIsCustom(handle->device())) {
|
||||
const tensorflow::Tensor* t;
|
||||
status->status = handle->Tensor(&t);
|
||||
return t->data();
|
||||
}
|
||||
|
||||
if (handle->IsRemote()) {
|
||||
if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TFE_TensorHandleDevicePointer may not be called on a remote tensor "
|
||||
"handle.");
|
||||
"TFE_TensorHandleDevicePointer may not be called on a ",
|
||||
handle->TypeString(), " tensor handle.");
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::Device* device(absl::get<tensorflow::Device*>(handle->device()));
|
||||
@ -1055,7 +1044,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
void* deallocator_arg, TF_Status* status) {
|
||||
tensorflow::Device* device = nullptr;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->FindDeviceFromName(device_name, &device);
|
||||
tensorflow::CustomDevice* custom_device = nullptr;
|
||||
if (!status->status.ok()) {
|
||||
@ -1081,11 +1070,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
tensorflow::TensorShape(dimvec), buf);
|
||||
buf->Unref();
|
||||
if (custom_device == nullptr) {
|
||||
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(
|
||||
std::move(t), device, device, context)};
|
||||
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
|
||||
std::move(t), device, device, context));
|
||||
} else {
|
||||
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(
|
||||
std::move(t), custom_device, context)};
|
||||
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
|
||||
std::move(t), custom_device, context));
|
||||
}
|
||||
}
|
||||
|
||||
@ -1094,16 +1083,16 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
// bytes of the memory pointed to by the device pointer returned above.
|
||||
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
||||
TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return 0;
|
||||
}
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::TensorHandleFromInterface(h->handle);
|
||||
if (handle->IsRemote()) {
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
|
||||
if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TFE_TensorHandleDeviceMemorySize may not be called on a remote tensor "
|
||||
"handle.");
|
||||
"TFE_TensorHandleDeviceMemorySize may not be called on a ",
|
||||
handle->TypeString(), " tensor handle.");
|
||||
return 0;
|
||||
}
|
||||
const tensorflow::Tensor* tensor;
|
||||
@ -1116,12 +1105,14 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
||||
|
||||
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op> new_op(new TFE_Op{ctx->context->CreateOperation()});
|
||||
status->status = new_op->operation->Reset(op_or_function_name, nullptr);
|
||||
tensorflow::AbstractOperationInterface* new_op =
|
||||
tensorflow::unwrap(ctx)->CreateOperation();
|
||||
status->status = new_op->Reset(op_or_function_name, nullptr);
|
||||
if (!status->status.ok()) {
|
||||
new_op.reset();
|
||||
new_op->Release();
|
||||
new_op = nullptr;
|
||||
}
|
||||
return new_op.release();
|
||||
return tensorflow::wrap(new_op);
|
||||
}
|
||||
|
||||
void TFE_DeleteOp(TFE_Op* op) {
|
||||
@ -1129,24 +1120,20 @@ void TFE_DeleteOp(TFE_Op* op) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (op->operation) {
|
||||
op->operation->Release();
|
||||
}
|
||||
|
||||
delete op;
|
||||
tensorflow::unwrap(op)->Release();
|
||||
}
|
||||
|
||||
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
|
||||
status->status = op->operation->SetDeviceName(device_name);
|
||||
status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
|
||||
}
|
||||
|
||||
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
|
||||
return op->operation->DeviceName().c_str();
|
||||
return tensorflow::unwrap(op)->DeviceName().c_str();
|
||||
}
|
||||
|
||||
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
tensorflow::Status s = op->operation->SetUseXla(enable);
|
||||
tensorflow::Status s = tensorflow::unwrap(op)->SetUseXla(enable);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Could not enable XLA compilation for op: " << s;
|
||||
}
|
||||
@ -1157,18 +1144,13 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
|
||||
}
|
||||
|
||||
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
|
||||
status->status = op->operation->AddInput(input->handle);
|
||||
status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
|
||||
}
|
||||
|
||||
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
|
||||
TF_Status* status) {
|
||||
absl::FixedArray<tensorflow::AbstractTensorHandleInterface*> handles(
|
||||
num_inputs);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
handles[i] = inputs[i]->handle;
|
||||
}
|
||||
status->status =
|
||||
op->operation->AddInputList({handles.data(), handles.size()});
|
||||
status->status = tensorflow::unwrap(op)->AddInputList(
|
||||
{tensorflow::unwrap(inputs), static_cast<size_t>(num_inputs)});
|
||||
}
|
||||
|
||||
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
|
||||
@ -1176,8 +1158,8 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
|
||||
TF_AttrType ret = TF_ATTR_INT;
|
||||
const tensorflow::AttrTypeMap* attr_types_;
|
||||
bool is_function;
|
||||
status->status = tensorflow::AttrTypeMapForOp(op->operation->Name().c_str(),
|
||||
&attr_types_, &is_function);
|
||||
status->status = tensorflow::AttrTypeMapForOp(
|
||||
tensorflow::unwrap(op)->Name().c_str(), &attr_types_, &is_function);
|
||||
if (!status->status.ok()) {
|
||||
return ret;
|
||||
}
|
||||
@ -1203,7 +1185,7 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
|
||||
|
||||
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
|
||||
size_t length) {
|
||||
auto s = op->operation->SetAttrString(
|
||||
auto s = tensorflow::unwrap(op)->SetAttrString(
|
||||
attr_name, static_cast<const char*>(value), length);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
@ -1211,29 +1193,30 @@ void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
|
||||
auto s = op->operation->SetAttrInt(attr_name, value);
|
||||
auto s = tensorflow::unwrap(op)->SetAttrInt(attr_name, value);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
|
||||
auto s = op->operation->SetAttrFloat(attr_name, value);
|
||||
auto s = tensorflow::unwrap(op)->SetAttrFloat(attr_name, value);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
|
||||
auto s = op->operation->SetAttrBool(attr_name, (value == 0) ? false : true);
|
||||
auto s = tensorflow::unwrap(op)->SetAttrBool(attr_name,
|
||||
(value == 0) ? false : true);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
|
||||
auto s = op->operation->SetAttrType(attr_name,
|
||||
static_cast<tensorflow::DataType>(value));
|
||||
auto s = tensorflow::unwrap(op)->SetAttrType(
|
||||
attr_name, static_cast<tensorflow::DataType>(value));
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
@ -1241,12 +1224,14 @@ void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
|
||||
|
||||
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
|
||||
const int num_dims, TF_Status* out_status) {
|
||||
out_status->status = op->operation->SetAttrShape(attr_name, dims, num_dims);
|
||||
out_status->status =
|
||||
tensorflow::unwrap(op)->SetAttrShape(attr_name, dims, num_dims);
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
|
||||
const TFE_Op* value) {
|
||||
auto s = op->operation->SetAttrFunction(attr_name, value->operation);
|
||||
auto s = tensorflow::unwrap(op)->SetAttrFunction(
|
||||
attr_name, tensorflow::unwrap(const_cast<TFE_Op*>(value)));
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
@ -1254,7 +1239,7 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
|
||||
|
||||
void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
|
||||
const char* data, size_t length) {
|
||||
auto s = op->operation->SetAttrFunctionName(attr_name, data, length);
|
||||
auto s = tensorflow::unwrap(op)->SetAttrFunctionName(attr_name, data, length);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
@ -1265,14 +1250,14 @@ void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
|
||||
tensorflow::Tensor t;
|
||||
status->status = TF_TensorToTensor(tensor, &t);
|
||||
tensorflow::TensorInterface interface(t);
|
||||
status->status = op->operation->SetAttrTensor(attr_name, &interface);
|
||||
status->status = tensorflow::unwrap(op)->SetAttrTensor(attr_name, &interface);
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
|
||||
const void* const* values, const size_t* lengths,
|
||||
int num_values) {
|
||||
auto s =
|
||||
op->operation->SetAttrStringList(attr_name, values, lengths, num_values);
|
||||
auto s = tensorflow::unwrap(op)->SetAttrStringList(attr_name, values, lengths,
|
||||
num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
@ -1280,7 +1265,8 @@ void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
|
||||
|
||||
void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
|
||||
const float* values, int num_values) {
|
||||
auto s = op->operation->SetAttrFloatList(attr_name, values, num_values);
|
||||
auto s =
|
||||
tensorflow::unwrap(op)->SetAttrFloatList(attr_name, values, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
@ -1288,7 +1274,8 @@ void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
|
||||
|
||||
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
|
||||
const int64_t* values, int num_values) {
|
||||
auto s = op->operation->SetAttrIntList(attr_name, values, num_values);
|
||||
auto s =
|
||||
tensorflow::unwrap(op)->SetAttrIntList(attr_name, values, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
@ -1296,7 +1283,7 @@ void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
|
||||
|
||||
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
|
||||
const TF_DataType* values, int num_values) {
|
||||
auto s = op->operation->SetAttrTypeList(
|
||||
auto s = tensorflow::unwrap(op)->SetAttrTypeList(
|
||||
attr_name, reinterpret_cast<const tensorflow::DataType*>(values),
|
||||
num_values);
|
||||
if (!s.ok()) {
|
||||
@ -1306,7 +1293,8 @@ void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
|
||||
|
||||
void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
|
||||
const unsigned char* values, int num_values) {
|
||||
auto s = op->operation->SetAttrBoolList(attr_name, values, num_values);
|
||||
auto s =
|
||||
tensorflow::unwrap(op)->SetAttrBoolList(attr_name, values, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
@ -1315,19 +1303,14 @@ void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
|
||||
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
|
||||
const int64_t** dims, const int* num_dims,
|
||||
int num_values, TF_Status* out_status) {
|
||||
out_status->status =
|
||||
op->operation->SetAttrShapeList(attr_name, dims, num_dims, num_values);
|
||||
out_status->status = tensorflow::unwrap(op)->SetAttrShapeList(
|
||||
attr_name, dims, num_dims, num_values);
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
|
||||
const TFE_Op** value, int num_values) {
|
||||
absl::FixedArray<const tensorflow::AbstractOperationInterface*> values(
|
||||
num_values);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
values[i] = value[i]->operation;
|
||||
}
|
||||
auto s = op->operation->SetAttrFunctionList(attr_name,
|
||||
{values.data(), values.size()});
|
||||
auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
|
||||
attr_name, {tensorflow::unwrap(value), static_cast<size_t>(num_values)});
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
@ -1342,12 +1325,13 @@ void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
|
||||
tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
|
||||
return;
|
||||
}
|
||||
if (op == nullptr || op->operation == nullptr) {
|
||||
if (op == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Got a null or uninitialized `op` argument");
|
||||
return;
|
||||
}
|
||||
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
|
||||
tensorflow::EagerOperation* operation =
|
||||
OperationFromInterface(tensorflow::unwrap(const_cast<TFE_Op*>(op)));
|
||||
operation->MutableAttrs()->Set(attr_name, attr_value);
|
||||
}
|
||||
|
||||
@ -1355,7 +1339,7 @@ TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
|
||||
const char* input_name,
|
||||
TF_Status* status) {
|
||||
int ret = -1;
|
||||
status->status = op->operation->InputLength(input_name, &ret);
|
||||
status->status = tensorflow::unwrap(op)->InputLength(input_name, &ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@ -1363,71 +1347,29 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
|
||||
const char* output_name,
|
||||
TF_Status* status) {
|
||||
int ret = -1;
|
||||
status->status = op->operation->OutputLength(output_name, &ret);
|
||||
status->status = tensorflow::unwrap(op)->OutputLength(output_name, &ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
|
||||
TF_Status* status) {
|
||||
absl::FixedArray<tensorflow::AbstractTensorHandleInterface*> handles(
|
||||
*num_retvals);
|
||||
status->status = op->operation->Execute(absl::MakeSpan(handles), num_retvals);
|
||||
if (!status->status.ok()) {
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals[i] = new TFE_TensorHandle{handles[i]};
|
||||
}
|
||||
status->status = tensorflow::unwrap(op)->Execute(
|
||||
absl::MakeSpan(tensorflow::unwrap(retvals), *num_retvals), num_retvals);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
TFE_Context* ctx,
|
||||
const char* device_name,
|
||||
TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensorflow::TensorHandle* handle = nullptr;
|
||||
tensorflow::Device* device;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
status->status = context->FindDeviceFromName(device_name, &device);
|
||||
if (!status->status.ok()) {
|
||||
tensorflow::CustomDevice* dev;
|
||||
status->status = context->FindCustomDeviceFromName(device_name, &dev);
|
||||
auto* result = tensorflow::unwrap(ctx)->CopyTensorHandleToDevice(
|
||||
tensorflow::unwrap(h), device_name, &status->status);
|
||||
if (status->status.ok()) {
|
||||
status->status = dev->CopyTensorToDevice(
|
||||
tensorflow::TensorHandleFromInterface(h->handle), &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle{handle};
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
// Handle tensor handles currently in custom devices
|
||||
const char* handle_device_name = h->handle->DeviceName(&status->status);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::CustomDevice* dev;
|
||||
status->status = context->FindCustomDeviceFromName(handle_device_name, &dev);
|
||||
if (status->status.ok()) {
|
||||
status->status = dev->CopyTensorFromDevice(
|
||||
tensorflow::TensorHandleFromInterface(h->handle), device_name, &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle{handle};
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Handle regular case.
|
||||
status->status = tensorflow::EagerCopyToDevice(
|
||||
tensorflow::TensorHandleFromInterface(h->handle), context,
|
||||
&context->Executor(), device, false, &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle{handle};
|
||||
return tensorflow::wrap(result);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
@ -1442,39 +1384,39 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
|
||||
return;
|
||||
}
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->AddFunctionDef(function_def);
|
||||
}
|
||||
|
||||
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->AddFunctionDef(function->fdef);
|
||||
}
|
||||
|
||||
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->RemoveFunction(name);
|
||||
}
|
||||
|
||||
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return context->FindFunctionDef(name) != nullptr;
|
||||
}
|
||||
|
||||
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(true);
|
||||
}
|
||||
|
||||
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(false);
|
||||
}
|
||||
|
||||
@ -1482,13 +1424,13 @@ void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
|
||||
TF_Status* status) {
|
||||
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(t)};
|
||||
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(t));
|
||||
}
|
||||
|
||||
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->Executor().WaitForAllPendingNodes();
|
||||
if (!status->status.ok()) return;
|
||||
tensorflow::mutex_lock ml(*context->MetadataMu());
|
||||
@ -1510,26 +1452,23 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
|
||||
} // namespace
|
||||
|
||||
void TFE_ContextStartStep(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
context->StartStep();
|
||||
tensorflow::unwrap(ctx)->StartStep();
|
||||
}
|
||||
|
||||
void TFE_ContextEndStep(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
context->EndStep();
|
||||
tensorflow::unwrap(ctx)->EndStep();
|
||||
}
|
||||
|
||||
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
|
||||
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
|
||||
*attrs = TFE_OpAttrs(&operation->Attrs(), operation->Name().c_str());
|
||||
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
|
||||
return tensorflow::wrap(
|
||||
&OperationFromInterface(tensorflow::unwrap(op))->Attrs());
|
||||
}
|
||||
|
||||
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
||||
tensorflow::AttrValueMap m;
|
||||
attrs->attributes->FillAttrValueMap(&m);
|
||||
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
|
||||
tensorflow::unwrap(attrs)->FillAttrValueMap(&m);
|
||||
tensorflow::EagerOperation* operation =
|
||||
OperationFromInterface(tensorflow::unwrap(op));
|
||||
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
|
||||
for (const auto& attribute : m) {
|
||||
destination->Set(attribute.first, attribute.second);
|
||||
@ -1539,8 +1478,8 @@ void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
||||
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
|
||||
TF_Status* status) {
|
||||
tensorflow::NameAttrList name_and_attrs;
|
||||
attrs->attributes->FillAttrValueMap(name_and_attrs.mutable_attr());
|
||||
name_and_attrs.set_name(attrs->name);
|
||||
tensorflow::unwrap(attrs)->FillAttrValueMap(name_and_attrs.mutable_attr());
|
||||
name_and_attrs.set_name(tensorflow::unwrap(attrs)->op_name());
|
||||
status->status = MessageToBuffer(name_and_attrs, buf);
|
||||
}
|
||||
|
||||
@ -1587,6 +1526,7 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||
// require TFE_Op* and just convert it internally a NameAttrValue, so
|
||||
// consider adding an overload to the C API to make this case easier.
|
||||
TFE_OpSetAttrFunction(op, attr_name, func_op);
|
||||
TFE_DeleteOp(func_op);
|
||||
} break;
|
||||
case tensorflow::AttrValue::kList:
|
||||
TF_FALLTHROUGH_INTENDED;
|
||||
@ -1616,33 +1556,34 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
const string& name() override { return name_; }
|
||||
|
||||
tensorflow::Status CopyTensorToDevice(
|
||||
tensorflow::TensorHandle* tensor,
|
||||
tensorflow::TensorHandle* handle,
|
||||
tensorflow::TensorHandle** result) override {
|
||||
tensor->Ref();
|
||||
TFE_TensorHandle tensor_handle{tensor};
|
||||
handle->Ref();
|
||||
TF_Status status;
|
||||
TFE_TensorHandle* result_handle =
|
||||
device_.copy_tensor_to_device(context_, &tensor_handle, &status, info_);
|
||||
tensor_handle.handle->Release();
|
||||
TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(
|
||||
context_, tensorflow::wrap(handle), &status, info_);
|
||||
handle->Release();
|
||||
if (!status.status.ok()) return status.status;
|
||||
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
|
||||
*result = tensorflow::TensorHandleFromInterface(
|
||||
tensorflow::unwrap(result_handle));
|
||||
(*result)->Ref();
|
||||
TFE_DeleteTensorHandle(result_handle);
|
||||
return status.status;
|
||||
}
|
||||
|
||||
tensorflow::Status CopyTensorFromDevice(
|
||||
tensorflow::TensorHandle* tensor,
|
||||
tensorflow::TensorHandle* handle,
|
||||
const tensorflow::string& target_device_name,
|
||||
tensorflow::TensorHandle** result) override {
|
||||
TF_Status status;
|
||||
tensor->Ref();
|
||||
TFE_TensorHandle tensor_handle{tensor};
|
||||
handle->Ref();
|
||||
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
|
||||
context_, &tensor_handle, target_device_name.c_str(), &status, info_);
|
||||
tensor_handle.handle->Release();
|
||||
context_, tensorflow::wrap(handle), target_device_name.c_str(), &status,
|
||||
info_);
|
||||
handle->Release();
|
||||
if (!status.status.ok()) return status.status;
|
||||
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
|
||||
*result = tensorflow::TensorHandleFromInterface(
|
||||
tensorflow::unwrap(result_handle));
|
||||
(*result)->Ref();
|
||||
TFE_DeleteTensorHandle(result_handle);
|
||||
return status.status;
|
||||
@ -1655,16 +1596,17 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
inputs.reserve(op->Inputs().size());
|
||||
for (int i = 0; i < op->Inputs().size(); ++i) {
|
||||
op->Inputs()[i]->Ref();
|
||||
inputs.push_back(new TFE_TensorHandle{op->Inputs()[i]});
|
||||
inputs.push_back(tensorflow::wrap(op->Inputs()[i]));
|
||||
}
|
||||
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
|
||||
TF_Status status;
|
||||
TFE_OpAttrs attributes(&op->Attrs(), op->Name().c_str());
|
||||
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
|
||||
&attributes, num_retvals, outputs.data(), &status, info_);
|
||||
wrap(&op->Attrs()), num_retvals, outputs.data(), &status,
|
||||
info_);
|
||||
if (status.status.ok()) {
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals[i] = tensorflow::TensorHandleFromInterface(outputs[i]->handle);
|
||||
retvals[i] = tensorflow::TensorHandleFromInterface(
|
||||
tensorflow::unwrap(outputs[i]));
|
||||
retvals[i]->Ref();
|
||||
TFE_DeleteTensorHandle(outputs[i]);
|
||||
}
|
||||
@ -1692,7 +1634,7 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
|
||||
auto custom_device =
|
||||
std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name);
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status =
|
||||
context->RegisterCustomDevice(device_name, std::move(custom_device));
|
||||
}
|
||||
|
@ -137,7 +137,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
// placed in memory of different devices or remote address spaces.
|
||||
typedef struct TFE_TensorHandle TFE_TensorHandle;
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t,
|
||||
TF_Status* status);
|
||||
// Indicates that the caller will not be using `h` any more.
|
||||
TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
|
||||
|
433
tensorflow/c/eager/c_api_cluster_test.cc
Normal file
433
tensorflow/c/eager/c_api_cluster_test.cc
Normal file
@ -0,0 +1,433 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using ::tensorflow::string;
|
||||
|
||||
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
|
||||
tensorflow::ServerDef server_def;
|
||||
server_def.set_protocol("grpc");
|
||||
server_def.set_job_name(job_name);
|
||||
server_def.set_task_index(0);
|
||||
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
||||
tensorflow::JobDef* job_def = cluster_def->add_job();
|
||||
job_def->set_name(job_name);
|
||||
for (int i = 0; i < num_tasks; i++) {
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{i, tensorflow::strings::StrCat("localhost:", port)});
|
||||
}
|
||||
return server_def;
|
||||
}
|
||||
|
||||
tensorflow::ServerDef GetServerDef(int num_tasks) {
|
||||
return GetServerDef("localhost", num_tasks);
|
||||
}
|
||||
|
||||
void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
|
||||
tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->at(task_index) =
|
||||
tensorflow::strings::StrCat("localhost:", port);
|
||||
}
|
||||
|
||||
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
|
||||
const std::vector<float>& expected_values) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
std::unique_ptr<float[]> actual_values(new float[expected_values.size()]);
|
||||
EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t));
|
||||
memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
|
||||
for (int i = 0; i < expected_values.size(); i++) {
|
||||
EXPECT_EQ(expected_values[i], actual_values[i])
|
||||
<< "Mismatch in expected values at (zero-based) index " << i;
|
||||
}
|
||||
}
|
||||
|
||||
void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
|
||||
const char* remote_device_name,
|
||||
const char* local_device_name) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
|
||||
|
||||
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
|
||||
TFE_OpSetDevice(matmul, remote_device_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
auto* retval_task0 =
|
||||
TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22});
|
||||
|
||||
TFE_DeleteTensorHandle(retval_task0);
|
||||
TFE_DeleteTensorHandle(h0_task0);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
// Read the value of variable `var` and save it into `out_value`.
|
||||
void ReadVariable(TFE_Context* ctx, TFE_TensorHandle* var,
|
||||
TFE_TensorHandle** out_value) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op, var, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, out_value, &num_retvals, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(op);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteChangeServerDef(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
const char local_device_name[] =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
|
||||
// Update the server def with a new set of names (worker instead of
|
||||
// localhost).
|
||||
tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2);
|
||||
serialized = updated_server_def.SerializeAsString();
|
||||
|
||||
updated_server_def.set_task_index(1);
|
||||
tensorflow::Status s = tensorflow::GrpcServer::Create(
|
||||
updated_server_def, tensorflow::Env::Default(), &worker_server);
|
||||
ASSERT_TRUE(s.ok()) << s.error_message();
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Create a new tensor_handle.
|
||||
TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx);
|
||||
|
||||
// Check that copying it to the old remote device (named localhost) fails.
|
||||
TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
|
||||
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Copying and executing on the new remote device works.
|
||||
const char new_remote_device_name[] =
|
||||
"/job:worker/replica:0/task:1/device:CPU:0";
|
||||
const char new_local_device_name[] =
|
||||
"/job:worker/replica:0/task:0/device:CPU:0";
|
||||
|
||||
auto* h0_task1_new = TFE_TensorHandleCopyToDevice(
|
||||
h0_task0_new, ctx, new_remote_device_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_DeleteTensorHandle(h0_task0_new);
|
||||
TFE_DeleteTensorHandle(h0_task1_new);
|
||||
|
||||
CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
|
||||
new_local_device_name);
|
||||
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteChangeServerDef) {
|
||||
TestRemoteExecuteChangeServerDef(false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
|
||||
TestRemoteExecuteChangeServerDef(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteUpdateServerDef(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
const char local_device_name[] =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
|
||||
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
|
||||
status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteUpdateServerDef) {
|
||||
TestRemoteExecuteUpdateServerDef(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
|
||||
TestRemoteExecuteUpdateServerDef(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteUpdateServerDefResourceAccess(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
const char dev0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char dev1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
|
||||
TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
|
||||
EXPECT_NE(var_handle0, nullptr);
|
||||
TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
|
||||
EXPECT_NE(var_handle1, nullptr);
|
||||
|
||||
TFE_TensorHandle* value_handle = nullptr;
|
||||
ReadVariable(ctx, var_handle1, &value_handle);
|
||||
CheckTFE_TensorHandleHasFloats(value_handle, {2});
|
||||
TFE_DeleteTensorHandle(value_handle);
|
||||
|
||||
// Start a new worker to replace task:1
|
||||
ReplaceTaskInServerDef(&server_def, 1);
|
||||
server_def.set_task_index(1);
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
// Update server def to replace the remote device with the device info on the
|
||||
// new worker (different incarnation ID).
|
||||
server_def.set_task_index(0);
|
||||
string serialized_update = server_def.SerializeAsString();
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
|
||||
serialized_update.size(), status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// The device of var_handle0 is local device which is the same before and
|
||||
// after cluster update. Remove resource with valid device should succeed.
|
||||
TFE_Op* op = TFE_NewOp(ctx, "DestroyResourceOp", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(op, var_handle0, status);
|
||||
TFE_OpSetDevice(op, dev0_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(op);
|
||||
|
||||
// The device of var_handle1 is remote device, which was replaced during
|
||||
// cluster update. Removing resource with invalid device should fail
|
||||
// gracefully (i.e., with error status) instead of crashing with segfaults.
|
||||
op = TFE_NewOp(ctx, "DestroyResourceOp", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(op, var_handle1, status);
|
||||
TFE_OpSetDevice(op, dev1_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(op);
|
||||
|
||||
TFE_DeleteTensorHandle(var_handle0);
|
||||
TFE_DeleteTensorHandle(var_handle1);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccess) {
|
||||
TestRemoteExecuteUpdateServerDefResourceAccess(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccessAsync) {
|
||||
TestRemoteExecuteUpdateServerDefResourceAccess(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
|
||||
// Fail fast on GetStatus requests so we can get errors instead of timeout
|
||||
// when updating cluster with non-exsitent worker
|
||||
tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
|
||||
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
const char local_device_name[] =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
|
||||
|
||||
// Adding a non-existent remote worker to cluster def. This should cause the
|
||||
// UpdateServerDef call to fail.
|
||||
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
||||
tensorflow::JobDef* job_def = cluster_def->mutable_job(0);
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{2, tensorflow::strings::StrCat("localhost:", port)});
|
||||
server_def.set_task_index(0);
|
||||
string serialized_update = server_def.SerializeAsString();
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
|
||||
serialized_update.size(), status);
|
||||
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Even after the prevoiusly failed cluster update, another update and op
|
||||
// execution should work fine as long as the provided server_def is valid.
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
|
||||
status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
tensorflow::unsetenv("GRPC_FAIL_FAST");
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteUpdateServerDefWithFailures) {
|
||||
TestRemoteExecuteUpdateServerDefWithFailures(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
|
||||
TestRemoteExecuteUpdateServerDefWithFailures(true);
|
||||
}
|
||||
|
||||
} // namespace
|
@ -17,8 +17,11 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
#include "tensorflow/compiler/jit/xla_device.h"
|
||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||
@ -54,7 +57,8 @@ extern "C" {
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
||||
TFE_TensorHandle* h, TF_Status* status) {
|
||||
tensorflow::TensorHandle* handle = TensorHandleFromInterface(h->handle);
|
||||
tensorflow::TensorHandle* handle =
|
||||
TensorHandleFromInterface(tensorflow::unwrap(h));
|
||||
const tensorflow::Tensor* tensor;
|
||||
status->status = handle->Tensor(&tensor);
|
||||
if (!status->status.ok()) {
|
||||
|
@ -19,7 +19,11 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
@ -34,9 +38,10 @@ using tensorflow::string;
|
||||
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
||||
const char* raw_device_name, TF_Status* status) {
|
||||
if (op_to_reset) {
|
||||
op_to_reset->operation->Clear();
|
||||
status->status =
|
||||
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
|
||||
tensorflow::AbstractOperationInterface* op =
|
||||
tensorflow::unwrap(op_to_reset);
|
||||
op->Clear();
|
||||
status->status = op->Reset(op_or_function_name, raw_device_name);
|
||||
} else {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"op_to_reset should not be nullptr");
|
||||
@ -45,13 +50,13 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
||||
|
||||
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(true);
|
||||
}
|
||||
|
||||
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(false);
|
||||
}
|
||||
|
||||
@ -483,7 +488,7 @@ void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
|
||||
void TFE_ContextSetThreadLocalMirroringPolicy(
|
||||
TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetThreadLocalMirroringPolicy(
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(policy));
|
||||
}
|
||||
@ -494,7 +499,7 @@ void TFE_ContextSetThreadLocalMirroringPolicy(
|
||||
extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
|
||||
TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return static_cast<TFE_ContextMirroringPolicy>(context->GetMirroringPolicy());
|
||||
}
|
||||
|
||||
@ -530,7 +535,7 @@ void TFE_OpSetCancellationManager(TFE_Op* op,
|
||||
TFE_CancellationManager* cancellation_manager,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerOperation* operation =
|
||||
tensorflow::OperationFromInterface(op->operation);
|
||||
tensorflow::OperationFromInterface(tensorflow::unwrap(op));
|
||||
operation->SetCancellationManager(
|
||||
&cancellation_manager->cancellation_manager);
|
||||
status->status = tensorflow::Status::OK();
|
||||
@ -557,19 +562,19 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
|
||||
|
||||
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetExecutorForThread(executor->executor());
|
||||
}
|
||||
|
||||
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return new TFE_Executor(&context->Executor());
|
||||
}
|
||||
|
||||
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
|
||||
context->HostCPU()->parsed_name());
|
||||
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
|
||||
@ -585,7 +590,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
||||
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
||||
TF_Buffer* buf, TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
auto* function_def = context->FindFunctionDef(function_name);
|
||||
if (function_def == nullptr) {
|
||||
status->status = tensorflow::errors::NotFound(
|
||||
@ -611,12 +616,13 @@ TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
|
||||
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
||||
}
|
||||
|
||||
if (ctx == nullptr || ctx->context == nullptr) {
|
||||
if (ctx == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid Context");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensorflow::AbstractTensorInterface* t = ctx->context->CreateTensor(
|
||||
tensorflow::AbstractTensorInterface* t =
|
||||
tensorflow::unwrap(ctx)->CreateTensor(
|
||||
static_cast<tensorflow::DataType>(dtype), dimvec);
|
||||
|
||||
if (t == nullptr) {
|
||||
@ -630,5 +636,38 @@ TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
|
||||
TF_Status* status) {
|
||||
return new TFE_TensorHandle{ctx->context->CreateLocalHandle(t->tensor)};
|
||||
return tensorflow::wrap(
|
||||
tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
|
||||
TFE_TensorHandle** handles,
|
||||
int* num_handles,
|
||||
TF_Status* status) {
|
||||
std::vector<tensorflow::TensorHandle*> tensor_handles;
|
||||
tensor_handles.reserve(*num_handles);
|
||||
for (int i = 0; i < *num_handles; ++i) {
|
||||
tensor_handles.push_back(
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(handles[i])));
|
||||
}
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::TensorHandle* handle = nullptr;
|
||||
status->status = tensorflow::TensorHandle::CreatePackedHandle(
|
||||
std::move(tensor_handles), context, &handle);
|
||||
return tensorflow::wrap(handle);
|
||||
}
|
||||
|
||||
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetAllowSoftPlacement(enable);
|
||||
}
|
||||
|
||||
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetLogDevicePlacement(enable);
|
||||
}
|
||||
|
@ -431,11 +431,9 @@ TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
|
||||
// A reference to an op's name -> attribute mapping
|
||||
typedef struct TFE_OpAttrs TFE_OpAttrs;
|
||||
|
||||
// Fetch a struct with a reference to information about attributes of `op`.
|
||||
//
|
||||
// The `attrs` struct does not own any memory, and `op` must outlive it.
|
||||
TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
|
||||
|
||||
// Fetch a reference to `op`'s attributes. The returned reference is only valid
|
||||
// while `op` is alive.
|
||||
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op);
|
||||
// Add attributes in `attrs` to `op`.
|
||||
//
|
||||
// Does not overwrite or update existing attributes, but adds new ones.
|
||||
@ -543,6 +541,26 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx,
|
||||
TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor(
|
||||
TFE_Context* ctx, TF_Tensor* t, TF_Status* status);
|
||||
|
||||
// Create a packed TensorHandle with the given list of TensorHandles.
|
||||
// If `handles` are on the same device, assign the same device to the packed
|
||||
// handle; if `handles` are on different deivces, assign a CompositeDevice to
|
||||
// it.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
|
||||
TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
|
||||
TF_Status* status);
|
||||
|
||||
// Configure soft device placement policy for the eager executor. Note this
|
||||
// policy is applied to any subsequent op executions.
|
||||
TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx,
|
||||
unsigned char enable,
|
||||
TF_Status* status);
|
||||
|
||||
// Configure device placement policy logging for the eager executor. Note this
|
||||
// policy is applied to any subsequent op executions.
|
||||
TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
|
||||
unsigned char enable,
|
||||
TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -15,39 +15,17 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h" // IWYU pragma: export
|
||||
#include "tensorflow/c/eager/tfe_executor_internal.h" // IWYU pragma: export
|
||||
#include "tensorflow/c/eager/tfe_monitoring_internal.h" // IWYU pragma: export
|
||||
#include "tensorflow/c/eager/tfe_op_attrs_internal.h" // IWYU pragma: export
|
||||
#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h" // IWYU pragma: export
|
||||
|
||||
// TODO(b/154564140): Move this to its own header. This requires splitting
|
||||
// c_api_experimental.h
|
||||
struct TFE_ContextOptions {
|
||||
TF_SessionOptions session_options;
|
||||
// true if async execution is enabled.
|
||||
@ -61,199 +39,4 @@ struct TFE_ContextOptions {
|
||||
bool use_tfrt = false;
|
||||
};
|
||||
|
||||
// Wraps a pointer to a context implementation.
|
||||
//
|
||||
// WARNING: Since the underlying object could be ref-counted a user of this
|
||||
// interface cannot destruct the underlying context object. Instead, call
|
||||
// TFE_DeleteContext who calls Release() on the context pointer and deletes
|
||||
// the TFE_Context structure.
|
||||
struct TFE_Context {
|
||||
tensorflow::AbstractContextInterface* context;
|
||||
};
|
||||
|
||||
// Wraps a pointer to a tensor handle implementation.
|
||||
//
|
||||
// WARNING: Since the underlying object could be ref-counted a user of this
|
||||
// interface cannot destruct the underlying handle object. Instead, call
|
||||
// TFE_DeleteTensorHandle who calls Release() on the handle pointer and deletes
|
||||
// the TFE_TensorHandle structure.
|
||||
struct TFE_TensorHandle {
|
||||
tensorflow::AbstractTensorHandleInterface* handle;
|
||||
};
|
||||
|
||||
struct TFE_TensorDebugInfo {
|
||||
explicit TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
|
||||
: dev_dims(dims) {}
|
||||
|
||||
// Fully-padded, minor-to-major.
|
||||
std::vector<tensorflow::int64> dev_dims;
|
||||
};
|
||||
|
||||
// Wraps a pointer to an operation implementation.
|
||||
//
|
||||
// WARNING: Since the underlying object could be ref-counted a user of this
|
||||
// interface cannot destruct the underlying operation object. Instead, call
|
||||
// TFE_DeleteOp who calls Release() on the operation pointer and deletes
|
||||
// the TFE_Op structure.
|
||||
struct TFE_Op {
|
||||
tensorflow::AbstractOperationInterface* operation;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringCounterCell {
|
||||
tensorflow::monitoring::CounterCell cell;
|
||||
};
|
||||
|
||||
template <int NumLabels>
|
||||
struct TFE_MonitoringCounter {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringCounter(const char* name, const char* description,
|
||||
LabelDesc&&... label) {
|
||||
counter = absl::WrapUnique(tensorflow::monitoring::Counter<NumLabels>::New(
|
||||
name, description, label...));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Counter<NumLabels>> counter;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringIntGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<tensorflow::int64> cell;
|
||||
};
|
||||
struct TFE_MonitoringStringGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<tensorflow::string> cell;
|
||||
};
|
||||
struct TFE_MonitoringBoolGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<bool> cell;
|
||||
};
|
||||
|
||||
template <typename ValueType, int NumLabels>
|
||||
struct TFE_MonitoringGauge {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringGauge(const char* name, const char* description,
|
||||
LabelDesc&&... label) {
|
||||
gauge = absl::WrapUnique(
|
||||
tensorflow::monitoring::Gauge<ValueType, NumLabels>::New(
|
||||
name, description, label...));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Gauge<ValueType, NumLabels>> gauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge<tensorflow::int64, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge<tensorflow::int64, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge<tensorflow::int64, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge<tensorflow::string, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge<tensorflow::string, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge<tensorflow::string, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge<bool, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge<bool, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringBuckets {
|
||||
explicit TFE_MonitoringBuckets(
|
||||
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
|
||||
fn) {
|
||||
create_buckets = fn;
|
||||
}
|
||||
|
||||
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
|
||||
create_buckets;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringSamplerCell {
|
||||
tensorflow::monitoring::SamplerCell cell;
|
||||
};
|
||||
|
||||
template <int NumLabels>
|
||||
struct TFE_MonitoringSampler {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringSampler(
|
||||
const char* name,
|
||||
std::unique_ptr<tensorflow::monitoring::Buckets> buckets,
|
||||
const char* description, LabelDesc&&... label) {
|
||||
sampler = absl::WrapUnique(tensorflow::monitoring::Sampler<NumLabels>::New(
|
||||
{name, description, label...}, std::move(buckets)));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Sampler<NumLabels>> sampler;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
// Set an AttrValue on the op. Doesn't handle the list types.
|
||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||
const tensorflow::AttrValue& default_value,
|
||||
const char* attr_name, TF_Status* status);
|
||||
} // namespace tensorflow
|
||||
|
||||
struct TFE_CancellationManager {
|
||||
tensorflow::CancellationManager cancellation_manager;
|
||||
};
|
||||
|
||||
struct TFE_Executor {
|
||||
explicit TFE_Executor(bool async)
|
||||
: owned_executor(new tensorflow::EagerExecutor(async)) {}
|
||||
|
||||
explicit TFE_Executor(tensorflow::EagerExecutor* executor)
|
||||
: owned_executor(nullptr), unowned_executor(executor) {}
|
||||
|
||||
tensorflow::EagerExecutor* executor() {
|
||||
return owned_executor == nullptr ? unowned_executor : owned_executor.get();
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
|
||||
tensorflow::EagerExecutor* unowned_executor;
|
||||
};
|
||||
|
||||
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
|
||||
// that sometimes do not require serialization.
|
||||
struct TFE_OpAttrs {
|
||||
explicit TFE_OpAttrs() : name(nullptr), attributes(nullptr) {}
|
||||
|
||||
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value,
|
||||
const char* op_name)
|
||||
: name(op_name), attributes(value) {}
|
||||
|
||||
const char* name;
|
||||
const tensorflow::AttrBuilder* attributes;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||
|
@ -17,12 +17,18 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
namespace {
|
||||
@ -129,7 +135,49 @@ void TestRemoteExecute(bool async) {
|
||||
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
|
||||
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
|
||||
|
||||
void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||
string MatMulFunction() {
|
||||
tensorflow::FunctionDef def;
|
||||
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
" signature {"
|
||||
" name: 'MatMulFunction'"
|
||||
" input_arg {"
|
||||
" name: 'a'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" input_arg {"
|
||||
" name: 'b'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" output_arg {"
|
||||
" name: 'm'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'matmul'"
|
||||
" op: 'MatMul'"
|
||||
" input: 'a'"
|
||||
" input: 'b'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" ret {"
|
||||
" key: 'm'"
|
||||
" value: 'matmul:product'"
|
||||
" }",
|
||||
&def));
|
||||
return def.SerializeAsString();
|
||||
}
|
||||
|
||||
// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one
|
||||
// which creates a remote remote input, to simulate a scenario that the remote
|
||||
// input is not ready when we start running an op or a function.
|
||||
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
|
||||
bool heavy_load_on_streaming_rpc) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
@ -154,48 +202,87 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
|
||||
std::vector<TFE_TensorHandle*> handles_task0;
|
||||
if (heavy_load_on_streaming_rpc) {
|
||||
// Send 50 tensor copy requests to simulate that there have been some RPC
|
||||
// requests been enqueued.
|
||||
for (int i = 0; i < 50; ++i) {
|
||||
handles_task0.push_back(TestMatrixTensorHandle(ctx));
|
||||
}
|
||||
}
|
||||
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||
|
||||
std::vector<TFE_TensorHandle*> handles_task2;
|
||||
for (auto* h_task0 : handles_task0) {
|
||||
handles_task2.push_back(
|
||||
TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status));
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
}
|
||||
|
||||
auto* h1_task2 =
|
||||
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_Op* matmul = nullptr;
|
||||
if (func) {
|
||||
string function_def = MatMulFunction();
|
||||
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
||||
status);
|
||||
CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, h0_task0, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, h1_task2, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
} else {
|
||||
// Handles are on task0 (local), and task2, but op is on task1.
|
||||
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
|
||||
matmul = MatMulOp(ctx, h0_task0, h1_task2);
|
||||
}
|
||||
if (remote) {
|
||||
TFE_OpSetDevice(matmul, task1_name, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
} else if (!async) {
|
||||
// Set the local device to CPU to easily validate mirroring
|
||||
string cpu_device_name;
|
||||
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
|
||||
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
auto remote_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
|
||||
}
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// TODO(gjn): Add support for waiting on async local mirrors
|
||||
if (!async) {
|
||||
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
|
||||
tensorflow::EagerOperation* op =
|
||||
tensorflow::OperationFromInterface(matmul->operation);
|
||||
if (!remote && !async) {
|
||||
auto remote_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_EQ(op->Inputs()[1], remote_arg);
|
||||
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
|
||||
}
|
||||
|
||||
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
|
||||
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(retval_task0);
|
||||
float product[4] = {0};
|
||||
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
|
||||
@ -210,13 +297,22 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||
TFE_DeleteTensorHandle(h1_task0);
|
||||
TFE_DeleteTensorHandle(h1_task2);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
for (auto* h : handles_task0) {
|
||||
TFE_DeleteTensorHandle(h);
|
||||
}
|
||||
for (auto* h : handles_task2) {
|
||||
TFE_DeleteTensorHandle(h);
|
||||
}
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
if (func) {
|
||||
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
|
||||
}
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
@ -227,16 +323,435 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteSilentCopies) {
|
||||
TestRemoteExecuteSilentCopies(false, true);
|
||||
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/true,
|
||||
/*func=*/false,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
|
||||
TestRemoteExecuteSilentCopies(true, true);
|
||||
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/false,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
|
||||
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/true,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
|
||||
TestRemoteExecuteSilentCopies(false, false);
|
||||
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/false,
|
||||
/*func=*/false,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
|
||||
TestRemoteExecuteSilentCopies(true, false);
|
||||
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false,
|
||||
/*func=*/false,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
|
||||
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
|
||||
// A remote input may be not ready when we start running a function. Test that
|
||||
// the function execution should wait until the remote input is ready.
|
||||
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
|
||||
/*heavy_load_on_streaming_rpc=*/true);
|
||||
}
|
||||
|
||||
// Add the values of three variables on three different tasks.
|
||||
string AddVariablesFunction() {
|
||||
tensorflow::FunctionDef def;
|
||||
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
" signature {"
|
||||
" name: 'AddVariablesFunction'"
|
||||
" input_arg {"
|
||||
" name: 'var'"
|
||||
" type: DT_RESOURCE"
|
||||
" }"
|
||||
" output_arg {"
|
||||
" name: 'sum'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read0'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var'"
|
||||
" device: '/job:localhost/replica:0/task:0/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read1'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var'"
|
||||
" device: '/job:localhost/replica:0/task:1/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read2'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var'"
|
||||
" device: '/job:localhost/replica:0/task:2/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'add1'"
|
||||
" op: 'Add'"
|
||||
" input: 'read0:value:0'"
|
||||
" input: 'read1:value:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'add2'"
|
||||
" op: 'Add'"
|
||||
" input: 'add1:z:0'"
|
||||
" input: 'read2:value:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" ret {"
|
||||
" key: 'sum'"
|
||||
" value: 'add2:z:0'"
|
||||
" }",
|
||||
&def));
|
||||
return def.SerializeAsString();
|
||||
}
|
||||
|
||||
void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(op, var_handle, status);
|
||||
TFE_TensorHandle* is_initialized[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &is_initialized[0], &num_retvals, status);
|
||||
CHECK_EQ(1, num_retvals);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
|
||||
bool initialized = false;
|
||||
memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t));
|
||||
EXPECT_EQ(initialized, true);
|
||||
TF_DeleteTensor(t);
|
||||
TFE_DeleteTensorHandle(is_initialized[0]);
|
||||
TFE_DeleteOp(op);
|
||||
delete status;
|
||||
}
|
||||
|
||||
void TestFunctionWithPackedInput(const bool remote) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server1->Start().ok());
|
||||
|
||||
server_def.set_task_index(2);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server2->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||
|
||||
// Create one variable per task.
|
||||
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
|
||||
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
|
||||
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
|
||||
|
||||
// Add a sync point in order to make sure that variables have been initialized
|
||||
// before the function execution starts.
|
||||
// TODO(b/155789951): Remove once b/155789951 is fixed.
|
||||
VarIsInitialized(ctx, h1);
|
||||
VarIsInitialized(ctx, h2);
|
||||
|
||||
// Pack 3 variable handles into one TFE_TensorHandle.
|
||||
int num_replicas = 3;
|
||||
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
|
||||
TFE_TensorHandle* packed_handle =
|
||||
TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
|
||||
EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
|
||||
EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);
|
||||
|
||||
const string composite_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
|
||||
EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
|
||||
composite_device_name);
|
||||
EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
|
||||
composite_device_name);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// Register and run a function which returns the sum of 3 variables.
|
||||
const string function_def = AddVariablesFunction();
|
||||
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
||||
status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(func, packed_handle, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
if (remote) {
|
||||
TFE_OpSetDevice(func, task1_name, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(func, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
ASSERT_EQ(1, num_retvals);
|
||||
TFE_DeleteOp(func);
|
||||
TFE_DeleteTensorHandle(packed_handle);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
float sum = 0;
|
||||
EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
|
||||
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
EXPECT_EQ(sum, 6.0);
|
||||
|
||||
TFE_DeleteTensorHandle(h0);
|
||||
TFE_DeleteTensorHandle(h1);
|
||||
TFE_DeleteTensorHandle(h2);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server1.release();
|
||||
worker_server2.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, TestLocalFunctionWithPackedInput) {
|
||||
TestFunctionWithPackedInput(/*remote=*/false);
|
||||
}
|
||||
|
||||
TEST(CAPI, TestRemoteFunctionWithPackedInput) {
|
||||
TestFunctionWithPackedInput(/*remote=*/true);
|
||||
}
|
||||
|
||||
string VariableAddFunction() {
|
||||
tensorflow::FunctionDef def;
|
||||
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
" signature {"
|
||||
" name: 'VariableAddFunction'"
|
||||
" input_arg {"
|
||||
" name: 'var0'"
|
||||
" type: DT_RESOURCE"
|
||||
" }"
|
||||
" output_arg {"
|
||||
" name: 'var0_value'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read0'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'add'"
|
||||
" op: 'Add'"
|
||||
" input: 'read0:value:0'"
|
||||
" input: 'read0:value:0'"
|
||||
" device: '/job:localhost/task:1/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'identity'"
|
||||
" op: 'Identity'"
|
||||
" input: 'add:z:0'"
|
||||
" device: '/job:localhost/task:0/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" ret {"
|
||||
" key: 'var0_value'"
|
||||
" value: 'identity:output:0'"
|
||||
" }",
|
||||
&def));
|
||||
return def.SerializeAsString();
|
||||
}
|
||||
|
||||
class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
|
||||
public:
|
||||
FunctionErrorInjectionPass(string error_node, string error_device)
|
||||
: error_node_(error_node), error_device_(error_device) {}
|
||||
tensorflow::Status Run(const tensorflow::DeviceSet& device_set,
|
||||
const tensorflow::ConfigProto& config_proto,
|
||||
std::unique_ptr<tensorflow::Graph>* graph,
|
||||
tensorflow::FunctionLibraryDefinition* flib_def,
|
||||
std::vector<std::string>* control_ret_node_names,
|
||||
bool* control_rets_updated) override {
|
||||
// Inject failure to function instantiation if finding a node that contains
|
||||
// the given node name (error_node_) and requested device (error_device_).
|
||||
for (const auto node : graph->get()->nodes()) {
|
||||
if (node->name().find(error_node_) != string::npos &&
|
||||
node->requested_device() == error_device_) {
|
||||
return tensorflow::errors::Internal("Injected graph pass error.");
|
||||
}
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
const string error_node_;
|
||||
const string error_device_;
|
||||
};
|
||||
|
||||
void TestDistributedFunctionCancellation(bool inject_error) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server1->Start().ok());
|
||||
server_def.set_task_index(2);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server2->Start().ok());
|
||||
const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||
|
||||
if (inject_error) {
|
||||
// Inject a function optimization pass failure when it sees the 'read0' op
|
||||
// having a requested device `dev2_name`. During execution:
|
||||
// * task:0 processes the main function `VariableAddFunction` and places
|
||||
// the read0 op on task:2
|
||||
// * task:0 partitions the main function with a subgraph containing read0
|
||||
// sent to task:2
|
||||
// * task:2 graph pass reports an error when it sees read0 with dev2_name
|
||||
tensorflow::function_optimization_registration::
|
||||
FunctionOptimizationPassRegistration register_test_pass(
|
||||
std::make_unique<FunctionErrorInjectionPass>("read0", dev2_name));
|
||||
}
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
|
||||
EXPECT_NE(var_handle, nullptr);
|
||||
|
||||
const string function_def = VariableAddFunction();
|
||||
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
||||
status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(func, var_handle, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(func, &retvals[0], &num_retvals, status);
|
||||
|
||||
if (inject_error) {
|
||||
ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
|
||||
} else {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(1, num_retvals);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
float sum = 0;
|
||||
ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
|
||||
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
ASSERT_EQ(sum, 4.0);
|
||||
}
|
||||
|
||||
TFE_DeleteOp(func);
|
||||
TFE_DeleteTensorHandle(var_handle);
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server1.release();
|
||||
worker_server2.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, DistributedFunctionNoError) {
|
||||
TestDistributedFunctionCancellation(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, DistributedFunctionCancelledOnError) {
|
||||
TestDistributedFunctionCancellation(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||
@ -309,150 +824,4 @@ TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
|
||||
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
|
||||
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
|
||||
}
|
||||
|
||||
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
|
||||
const std::vector<float>& expected_values) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
std::unique_ptr<float[]> actual_values(new float[expected_values.size()]);
|
||||
EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t));
|
||||
memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
|
||||
for (int i = 0; i < expected_values.size(); i++) {
|
||||
EXPECT_EQ(expected_values[i], actual_values[i])
|
||||
<< "Mismatch in expected values at (zero-based) index " << i;
|
||||
}
|
||||
}
|
||||
|
||||
void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
|
||||
const char* remote_device_name,
|
||||
const char* local_device_name) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
|
||||
|
||||
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
|
||||
TFE_OpSetDevice(matmul, remote_device_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
auto* retval_task0 =
|
||||
TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22});
|
||||
|
||||
TFE_DeleteTensorHandle(retval_task0);
|
||||
TFE_DeleteTensorHandle(h0_task0);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteChangeServerDef(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
const char local_device_name[] =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
|
||||
// Update the server def with a new set of names (worker instead of
|
||||
// localhost).
|
||||
tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2);
|
||||
serialized = updated_server_def.SerializeAsString();
|
||||
|
||||
updated_server_def.set_task_index(1);
|
||||
tensorflow::Status s = tensorflow::GrpcServer::Create(
|
||||
updated_server_def, tensorflow::Env::Default(), &worker_server);
|
||||
ASSERT_TRUE(s.ok()) << s.error_message();
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Create a new tensor_handle.
|
||||
TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx);
|
||||
|
||||
// Check that copying it to the old remote device (named localhost) fails.
|
||||
TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
|
||||
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Copying and executing on the new remote device works.
|
||||
const char new_remote_device_name[] =
|
||||
"/job:worker/replica:0/task:1/device:CPU:0";
|
||||
const char new_local_device_name[] =
|
||||
"/job:worker/replica:0/task:0/device:CPU:0";
|
||||
|
||||
auto* h0_task1_new = TFE_TensorHandleCopyToDevice(
|
||||
h0_task0_new, ctx, new_remote_device_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_DeleteTensorHandle(h0_task0_new);
|
||||
TFE_DeleteTensorHandle(h0_task1_new);
|
||||
|
||||
CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
|
||||
new_local_device_name);
|
||||
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteChangeServerDef) {
|
||||
TestRemoteExecuteChangeServerDef(false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
|
||||
TestRemoteExecuteChangeServerDef(true);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -27,6 +27,8 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
@ -78,11 +80,18 @@ void BM_Execute(int iters, int async) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
TFE_OpReset(matmul, "MatMul", nullptr, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, m, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, m, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
}
|
||||
@ -113,11 +122,15 @@ void BM_Execute_Identity(int iters, int async) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* identity = IdentityOp(ctx, m);
|
||||
TFE_Op* identity = TFE_NewOp(ctx, "Identity", status);
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
TFE_OpReset(identity, "Identity", nullptr, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(identity, m, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_Execute(identity, &retvals[0], &num_retvals, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
}
|
||||
@ -405,6 +418,13 @@ void TensorHandleSilentCopy(bool async,
|
||||
hcpu, ctx, gpu_device_name.c_str(), status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
|
||||
auto cpu_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hcpu));
|
||||
auto gpu_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hgpu));
|
||||
auto gpu_device = absl::get<tensorflow::Device*>(gpu_arg->device());
|
||||
ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device));
|
||||
|
||||
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
|
||||
if (cpu_op) {
|
||||
string cpu_device_name;
|
||||
@ -420,15 +440,8 @@ void TensorHandleSilentCopy(bool async,
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Validate if the input was replaced with a different TensorHandle
|
||||
auto arg0 = tensorflow::TensorHandleFromInterface(hcpu->handle);
|
||||
auto arg1 = tensorflow::TensorHandleFromInterface(hgpu->handle);
|
||||
tensorflow::EagerOperation* op =
|
||||
tensorflow::OperationFromInterface(matmul->operation);
|
||||
|
||||
// The input handles should never change since they have been mirrored.
|
||||
EXPECT_EQ(op->Inputs()[0], arg0);
|
||||
EXPECT_EQ(op->Inputs()[1], arg1);
|
||||
// The CPU handle should have been copied and have a mirror on the GPU
|
||||
ASSERT_TRUE(cpu_arg->HasLocalMirror(gpu_device));
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
@ -626,17 +639,6 @@ void ExecuteAdd(bool async, bool forward_input, bool tfrt) {
|
||||
}
|
||||
|
||||
int num_retvals = 1;
|
||||
|
||||
if (async) {
|
||||
// Enqueue dummy ops so we backlog async execution & actually test async.
|
||||
for (int i = 0; i < 10000; ++i) {
|
||||
TFE_TensorHandle* dummy = nullptr;
|
||||
TFE_Execute(add_op, &dummy, &num_retvals, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(dummy);
|
||||
}
|
||||
}
|
||||
|
||||
TFE_TensorHandle* retval = nullptr;
|
||||
TFE_Execute(add_op, &retval, &num_retvals, status);
|
||||
EXPECT_EQ(1, num_retvals);
|
||||
@ -1130,51 +1132,6 @@ void BM_ExecuteFunction(int iters, int async) {
|
||||
}
|
||||
BENCHMARK(BM_ExecuteFunction)->Arg(0)->Arg(1);
|
||||
|
||||
TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
|
||||
TF_Status* status) {
|
||||
// Create the variable handle.
|
||||
TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpSetAttrShape(op, "shape", {}, 0, status);
|
||||
TFE_OpSetAttrString(op, "container", "", 0);
|
||||
TFE_OpSetAttrString(op, "shared_name", "", 0);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &var_handle, &num_retvals, status);
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
CHECK_EQ(1, num_retvals);
|
||||
|
||||
// Assign 'value' to it.
|
||||
op = TFE_NewOp(ctx, "AssignVariableOp", status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op, var_handle, status);
|
||||
|
||||
// Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
|
||||
TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
|
||||
memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
|
||||
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
value_handle(TFE_NewTensorHandle(t.get(), status),
|
||||
TFE_DeleteTensorHandle);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
TFE_OpAddInput(op, value_handle.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
CHECK_EQ(0, num_retvals);
|
||||
|
||||
return var_handle;
|
||||
}
|
||||
|
||||
TEST(CAPI, Variables) {
|
||||
// Variables use resource handles, so this is really a test for resource
|
||||
// tensor handling.
|
||||
@ -1184,7 +1141,7 @@ TEST(CAPI, Variables) {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status);
|
||||
TFE_TensorHandle* var_handle = TestVariable(ctx, 12.0);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
|
||||
@ -1225,7 +1182,7 @@ void BM_ReadVariable(int iters) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status);
|
||||
TFE_TensorHandle* var_handle = TestVariable(ctx, 5.0);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
|
||||
@ -1246,6 +1203,8 @@ void BM_ReadVariable(int iters) {
|
||||
CHECK_EQ(0, TFE_TensorHandleNumDims(h, status));
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
h = nullptr;
|
||||
TFE_OpAddInput(op, var_handle, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
}
|
||||
tensorflow::testing::StopTiming();
|
||||
TFE_DeleteOp(op);
|
||||
@ -1348,7 +1307,7 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
|
||||
tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
tensorflow::EagerOperation* operation =
|
||||
tensorflow::OperationFromInterface(op->operation);
|
||||
tensorflow::OperationFromInterface(tensorflow::unwrap(op));
|
||||
operation->Attrs().FillAttrValueMap(&attr_values);
|
||||
return attr_values;
|
||||
}
|
||||
@ -1484,10 +1443,10 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
|
||||
TFE_TensorHandle* inputs[] = {input1, input2};
|
||||
TFE_OpAddInput(concatOp, dim, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
CHECK(concatOp->operation->OpDef());
|
||||
CHECK(tensorflow::unwrap(concatOp)->OpDef());
|
||||
TFE_OpAddInput(concatOp, inputs[0], status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_FALSE(concatOp->operation->OpDef())
|
||||
EXPECT_FALSE(tensorflow::unwrap(concatOp)->OpDef())
|
||||
<< "Inference context is still present";
|
||||
TFE_OpAddInput(concatOp, inputs[1], status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
@ -1579,7 +1538,7 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST(CAPI, TestTFE_OpGetAttrs) {
|
||||
TEST(CAPI, TestTFE_OpAddAttrs) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
@ -1589,12 +1548,11 @@ TEST(CAPI, TestTFE_OpGetAttrs) {
|
||||
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
||||
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
||||
TFE_OpAttrs attributes;
|
||||
TFE_OpGetAttrs(var_op, &attributes);
|
||||
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);
|
||||
|
||||
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
|
||||
TFE_OpAddAttrs(copy_op, &attributes);
|
||||
TFE_OpAddAttrs(copy_op, attributes);
|
||||
unsigned char is_list = 0;
|
||||
ASSERT_EQ(TF_ATTR_TYPE,
|
||||
TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
|
||||
@ -1605,7 +1563,7 @@ TEST(CAPI, TestTFE_OpGetAttrs) {
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
tensorflow::EagerOperation* op =
|
||||
tensorflow::OperationFromInterface(copy_op->operation);
|
||||
tensorflow::OperationFromInterface(tensorflow::unwrap(copy_op));
|
||||
op->Attrs().FillAttrValueMap(&attr_values);
|
||||
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
|
||||
|
||||
@ -1626,11 +1584,10 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
||||
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
||||
TFE_OpAttrs attributes;
|
||||
TFE_OpGetAttrs(var_op, &attributes);
|
||||
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);
|
||||
|
||||
TF_Buffer* serialized_attr_values = TF_NewBuffer();
|
||||
TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status);
|
||||
TFE_OpAttrsSerialize(attributes, serialized_attr_values, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
tensorflow::NameAttrList name_and_attrs;
|
||||
ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,
|
||||
@ -1653,7 +1610,7 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
tensorflow::EagerOperation* op =
|
||||
tensorflow::OperationFromInterface(var_op_2->operation);
|
||||
tensorflow::OperationFromInterface(tensorflow::unwrap(var_op_2));
|
||||
op->Attrs().FillAttrValueMap(&attr_values);
|
||||
EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type());
|
||||
|
||||
|
@ -133,6 +133,58 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx) {
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
|
||||
const tensorflow::string& device_name) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
// Create the variable handle.
|
||||
TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpSetAttrShape(op, "shape", {}, 0, status);
|
||||
TFE_OpSetAttrString(op, "container", "", 0);
|
||||
TFE_OpSetAttrString(op, "shared_name", "", 0);
|
||||
if (!device_name.empty()) {
|
||||
TFE_OpSetDevice(op, device_name.c_str(), status);
|
||||
}
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &var_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
CHECK_EQ(1, num_retvals);
|
||||
|
||||
// Assign 'value' to it.
|
||||
op = TFE_NewOp(ctx, "AssignVariableOp", status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op, var_handle, status);
|
||||
|
||||
// Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
|
||||
TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
|
||||
memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
|
||||
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
value_handle(TFE_NewTensorHandle(t.get(), status),
|
||||
TFE_DeleteTensorHandle);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
TFE_OpAddInput(op, value_handle.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
CHECK_EQ(0, num_retvals);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
return var_handle;
|
||||
}
|
||||
|
||||
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
||||
|
@ -42,6 +42,11 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx);
|
||||
// Return a tensor handle containing a 3x2 matrix of floats
|
||||
TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx);
|
||||
|
||||
// Return a variable handle referring to a variable with the given initial value
|
||||
// on the given device.
|
||||
TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
|
||||
const tensorflow::string& device_name = "");
|
||||
|
||||
// Return an add op multiplying `a` by `b`.
|
||||
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
|
||||
|
||||
|
@ -15,247 +15,151 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
using tensorflow::string;
|
||||
using tensorflow::internal::OutputList;
|
||||
using tensorflow::internal::unwrap;
|
||||
|
||||
// =============================================================================
|
||||
// Unified Execution APIs for Eager and tracing backends.
|
||||
// =============================================================================
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
typedef absl::flat_hash_map<std::string, FactoryFunction> FactoriesMap;
|
||||
|
||||
typedef void (*ExecuteOperation)(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs,
|
||||
TF_OutputList* o, TF_ExecutionContext* ctx,
|
||||
TF_Status* s);
|
||||
struct TF_ExecutionContext {
|
||||
explicit TF_ExecutionContext() {}
|
||||
absl::variant<TFE_Context*, TF_GraphContext*> ctx;
|
||||
ExecuteOperation execution_callback;
|
||||
};
|
||||
|
||||
struct TF_AbstractTensor {
|
||||
absl::variant<TFE_TensorHandle*, TF_GraphTensor*> t;
|
||||
};
|
||||
|
||||
struct TF_AbstractOp {
|
||||
string op_type;
|
||||
string op_name;
|
||||
};
|
||||
|
||||
TF_ExecutionContext* TF_NewExecutionContext() {
|
||||
return new TF_ExecutionContext();
|
||||
static FactoriesMap& GetFactories() {
|
||||
static FactoriesMap* factories = new FactoriesMap;
|
||||
return *factories;
|
||||
}
|
||||
|
||||
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete c; }
|
||||
static const char* default_factory = "<unset>";
|
||||
|
||||
TF_AbstractOp* TF_NewAbstractOp() {
|
||||
TF_AbstractOp* op = new TF_AbstractOp;
|
||||
return op;
|
||||
void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
|
||||
assert((!GetFactories().count(name)) ||
|
||||
(GetFactories()[name] == factory) &&
|
||||
"Duplicate tracing factory registration");
|
||||
GetFactories()[name] = factory;
|
||||
}
|
||||
|
||||
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; }
|
||||
void SetDefaultTracingEngine(const char* name) { default_factory = name; }
|
||||
|
||||
TF_AbstractTensor* TF_NewAbstractTensor() {
|
||||
TF_AbstractTensor* t = new TF_AbstractTensor;
|
||||
return t;
|
||||
}
|
||||
|
||||
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; }
|
||||
|
||||
struct TF_GraphContext {
|
||||
TF_Graph* graph;
|
||||
// TODO(srbs): Handle captures.
|
||||
};
|
||||
|
||||
TF_GraphContext* TF_NewGraphContext(TF_Graph* g) {
|
||||
auto ctx = new TF_GraphContext;
|
||||
ctx->graph = g;
|
||||
return ctx;
|
||||
}
|
||||
|
||||
void TF_DeleteGraphContext(TF_GraphContext* ctx) { delete ctx; }
|
||||
|
||||
struct TF_GraphTensor {
|
||||
TF_Output output;
|
||||
TF_GraphContext* ctx;
|
||||
};
|
||||
TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* ctx, TF_Output output,
|
||||
static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
|
||||
TF_Status* s) {
|
||||
TF_GraphTensor* t = new TF_GraphTensor;
|
||||
t->output = output;
|
||||
t->ctx = ctx;
|
||||
return t;
|
||||
auto entry = GetFactories().find(default_factory);
|
||||
if (entry != GetFactories().end()) return entry->second(fn_name, s);
|
||||
string msg = absl::StrCat(
|
||||
"No tracing engine factory has been registered with the key '",
|
||||
default_factory, "' (available: ");
|
||||
// Ensure deterministic (sorted) order in the error message
|
||||
std::set<string> factories_sorted;
|
||||
for (const auto& factory : GetFactories())
|
||||
factories_sorted.insert(factory.first);
|
||||
const char* comma = "";
|
||||
for (const string& factory : factories_sorted) {
|
||||
msg += comma + factory;
|
||||
comma = ", ";
|
||||
}
|
||||
TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s) {
|
||||
return t->output;
|
||||
}
|
||||
void TF_DeleteGraphTensor(TF_GraphTensor* t) { delete t; }
|
||||
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
|
||||
TF_Status* s) {
|
||||
at->t = t;
|
||||
}
|
||||
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
|
||||
TF_Status* s) {
|
||||
if (!absl::holds_alternative<TFE_TensorHandle*>(at->t)) {
|
||||
string msg = absl::StrCat("Not an eager tensor handle.",
|
||||
reinterpret_cast<uintptr_t>(at));
|
||||
msg += ")";
|
||||
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
return absl::get<TFE_TensorHandle*>(at->t);
|
||||
}
|
||||
void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
|
||||
TF_Status* s) {
|
||||
at->t = t;
|
||||
}
|
||||
TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
|
||||
TF_Status* s) {
|
||||
if (!absl::holds_alternative<TF_GraphTensor*>(at->t)) {
|
||||
string msg = absl::StrCat("Not an graph tensor handle.");
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
return absl::get<TF_GraphTensor*>(at->t);
|
||||
|
||||
} // end namespace internal
|
||||
} // end namespace tensorflow
|
||||
|
||||
// =============================================================================
|
||||
// Public C API entry points
|
||||
//
|
||||
// These are only the generic entry points for the C API. This file does not
|
||||
// have any visibility into the graph/eager implementation and is only providing
|
||||
// C bindings to the abstract classes defined in the
|
||||
// c_api_unified_experimental_internal.h header.
|
||||
//
|
||||
// =============================================================================
|
||||
|
||||
void TF_SetTracingImplementation(const char* name) {
|
||||
tensorflow::internal::SetDefaultTracingEngine(name);
|
||||
}
|
||||
|
||||
bool IsEagerTensor(const TF_AbstractTensor* const t) {
|
||||
return absl::holds_alternative<TFE_TensorHandle*>(t->t);
|
||||
// Creates a new TensorFlow function, it is an execution context attached to a
|
||||
// given tracing context.
|
||||
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) {
|
||||
return wrap(tensorflow::internal::CreateTracingExecutionContext(fn_name, s));
|
||||
}
|
||||
|
||||
struct TF_OutputList {
|
||||
std::vector<TF_AbstractTensor*> outputs;
|
||||
int expected_num_outputs = -1;
|
||||
};
|
||||
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
|
||||
TF_OutputList* outputs, TF_Status* s) {
|
||||
auto* func = wrap(unwrap(ctx)->Finalize(unwrap(outputs), s));
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
return func;
|
||||
}
|
||||
|
||||
TF_OutputList* TF_NewOutputList() { return new TF_OutputList; }
|
||||
void TF_DeleteOutputList(TF_OutputList* o) { delete o; }
|
||||
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
|
||||
TF_DataType dtype, TF_Status* s) {
|
||||
return wrap(unwrap(func)->AddParameter(dtype, s));
|
||||
}
|
||||
|
||||
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
|
||||
|
||||
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
|
||||
return wrap(unwrap(c)->CreateOperation());
|
||||
}
|
||||
|
||||
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete unwrap(op); }
|
||||
|
||||
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete unwrap(t); }
|
||||
|
||||
TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
|
||||
void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }
|
||||
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
|
||||
TF_Status* s) {
|
||||
o->expected_num_outputs = num_outputs;
|
||||
unwrap(o)->expected_num_outputs = num_outputs;
|
||||
}
|
||||
int TF_OutputListNumOutputs(TF_OutputList* o) {
|
||||
return unwrap(o)->outputs.size();
|
||||
}
|
||||
int TF_OutputListNumOutputs(TF_OutputList* o) { return o->outputs.size(); }
|
||||
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
|
||||
return o->outputs[i];
|
||||
return wrap(unwrap(o)->outputs[i]);
|
||||
}
|
||||
|
||||
void ExecuteOperationEager(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
TF_ExecutionContext* ctx, TF_Status* s) {
|
||||
auto* tfe_op =
|
||||
TFE_NewOp(absl::get<TFE_Context*>(ctx->ctx), op->op_type.c_str(), s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
if (!IsEagerTensor(inputs[i])) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
|
||||
return;
|
||||
}
|
||||
TFE_OpAddInput(tfe_op, absl::get<TFE_TensorHandle*>(inputs[i]->t), s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
}
|
||||
if (o->expected_num_outputs == -1) {
|
||||
string msg =
|
||||
"The number of outputs must be provided in eager mode. Use "
|
||||
"TF_OutputListSetNumOutputs.";
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return;
|
||||
}
|
||||
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
|
||||
int num_retvals = o->expected_num_outputs;
|
||||
retvals.resize(num_retvals);
|
||||
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
|
||||
TFE_DeleteOp(tfe_op);
|
||||
if (TF_GetCode(s) != TF_OK) {
|
||||
return;
|
||||
}
|
||||
o->outputs.clear();
|
||||
o->outputs.reserve(num_retvals);
|
||||
for (int i = 0; i < num_retvals; ++i) {
|
||||
auto* t = TF_NewAbstractTensor();
|
||||
t->t = retvals[i];
|
||||
o->outputs.push_back(t);
|
||||
}
|
||||
}
|
||||
|
||||
TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) {
|
||||
return absl::get<TF_GraphTensor*>(t->t)->ctx;
|
||||
}
|
||||
|
||||
void ExecuteOperationGraph(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
TF_ExecutionContext* ctx, TF_Status* s) {
|
||||
TF_GraphContext* graph_ctx = absl::get<TF_GraphContext*>(ctx->ctx);
|
||||
TF_Graph* g = graph_ctx->graph;
|
||||
auto* tf_opdesc =
|
||||
TF_NewOperation(g, op->op_type.c_str(), op->op_name.c_str());
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
auto* input = inputs[i];
|
||||
if (IsEagerTensor(input)) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Capturing eager tensors is not supported yet.");
|
||||
return;
|
||||
} else {
|
||||
if (GetGraphContext(input) != graph_ctx) {
|
||||
TF_SetStatus(
|
||||
s, TF_INVALID_ARGUMENT,
|
||||
"Capturing tensors from other graphs is not supported yet.");
|
||||
return;
|
||||
}
|
||||
TF_AddInput(tf_opdesc, absl::get<TF_GraphTensor*>(input->t)->output);
|
||||
}
|
||||
}
|
||||
auto* operation = TF_FinishOperation(tf_opdesc, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
int num_outputs = TF_OperationNumOutputs(operation);
|
||||
o->outputs.clear();
|
||||
o->outputs.reserve(num_outputs);
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
auto* t = TF_NewAbstractTensor();
|
||||
TF_GraphTensor* output_t = TF_NewGraphTensor(graph_ctx, {operation, i}, s);
|
||||
if (TF_GetCode(s) != TF_OK) {
|
||||
return;
|
||||
}
|
||||
t->t = output_t;
|
||||
o->outputs.push_back(t);
|
||||
}
|
||||
}
|
||||
|
||||
void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
|
||||
TFE_Context* eager_context,
|
||||
void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
|
||||
TF_Status* s) {
|
||||
context->ctx = eager_context;
|
||||
context->execution_callback = &ExecuteOperationEager;
|
||||
}
|
||||
|
||||
void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
|
||||
TF_GraphContext* graph_context,
|
||||
TF_Status* s) {
|
||||
context->ctx = graph_context;
|
||||
context->execution_callback = &ExecuteOperationGraph;
|
||||
unwrap(o)->outputs.push_back(unwrap(tensor));
|
||||
}
|
||||
|
||||
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
|
||||
TF_Status* s) {
|
||||
op->op_type = op_type;
|
||||
unwrap(op)->SetOpType(op_type, s);
|
||||
}
|
||||
|
||||
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
|
||||
TF_Status* s) {
|
||||
op->op_name = op_name;
|
||||
unwrap(op)->SetOpName(op_name, s);
|
||||
}
|
||||
|
||||
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
|
||||
TF_DataType value, TF_Status* s) {
|
||||
unwrap(op)->SetAttrType(attr_name, value, s);
|
||||
}
|
||||
|
||||
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
TF_ExecutionContext* ctx, TF_Status* s) {
|
||||
ctx->execution_callback(op, num_inputs, inputs, o, ctx, s);
|
||||
unwrap(ctx)->ExecuteOperation(unwrap(op), num_inputs, &unwrap(*inputs),
|
||||
unwrap(o), s);
|
||||
}
|
||||
|
||||
void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {
|
||||
delete unwrap(func);
|
||||
}
|
||||
|
||||
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
|
||||
TF_AbstractFunction* func,
|
||||
TF_Status* s) {
|
||||
unwrap(ctx)->RegisterFunction(unwrap(func), s);
|
||||
}
|
||||
|
@ -15,8 +15,9 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
|
||||
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
@ -34,39 +35,45 @@ extern "C" {
|
||||
// E.g. it could know whether we're in eager mode or in graph mode, keeps track
|
||||
// of gradient tapes, etc.
|
||||
typedef struct TF_ExecutionContext TF_ExecutionContext;
|
||||
|
||||
// A TF_AbstractTensor is an input to an operation. E.g. it could be a union
|
||||
// type of eager and graph tensors.
|
||||
// type of eager and graph tensors. It is also the result of executing an
|
||||
// operation.
|
||||
typedef struct TF_AbstractTensor TF_AbstractTensor;
|
||||
|
||||
// A TF_AbstractOp is the metadata we need to execute an operation. E.g. this
|
||||
// could contain the op type and other attributes.
|
||||
typedef struct TF_AbstractOp TF_AbstractOp;
|
||||
|
||||
TF_ExecutionContext* TF_NewExecutionContext();
|
||||
// Stores a function representation that can be used for execution or for
|
||||
// setting functional attributes of other composite ops e.g. control flow.
|
||||
typedef struct TF_AbstractFunction TF_AbstractFunction;
|
||||
|
||||
// This allows the client to swap the implementation of the tracing engine.
|
||||
// Any future call to TF_CreateFunction will use the implementation defined
|
||||
// here.
|
||||
void TF_SetTracingImplementation(const char* name);
|
||||
|
||||
// Creates a new TensorFlow function. A Function is an execution context, and as
|
||||
// such it can trace operations through TF_ExecuteOperation. After completing
|
||||
// tracing, a function can be obtained by TF_FinalizeFunction.
|
||||
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* status);
|
||||
|
||||
// Creates a context for eager execution of operations.
|
||||
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
|
||||
TF_Status* s);
|
||||
void TF_DeleteExecutionContext(TF_ExecutionContext*);
|
||||
|
||||
TF_AbstractOp* TF_NewAbstractOp();
|
||||
// Add a new parameter to a TensorFlow Function.
|
||||
// TODO(aminim): what about shape?
|
||||
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
|
||||
TF_DataType dtype, TF_Status* s);
|
||||
|
||||
// Create an operation suitable to use with the provided context. The operation
|
||||
// requires its type (e.g. "AddV2") to be set independently.
|
||||
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
|
||||
void TF_DeleteAbstractOp(TF_AbstractOp*);
|
||||
|
||||
TF_AbstractTensor* TF_NewAbstractTensor();
|
||||
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// APIs for Eager and graph modes
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// Keeps track of the current graph and other state e.g. captures etc.
|
||||
typedef struct TF_GraphContext TF_GraphContext;
|
||||
TF_GraphContext* TF_NewGraphContext(TF_Graph*);
|
||||
void TF_DeleteGraphContext(TF_GraphContext*);
|
||||
|
||||
// `eager_context` must outlive `context`.
|
||||
void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
|
||||
TFE_Context* eager_context, TF_Status*);
|
||||
// `graph_context` must outlive `context`.
|
||||
void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
|
||||
TF_GraphContext* graph_context,
|
||||
TF_Status*);
|
||||
|
||||
// TODO(srbs): Add APIs for specifying attrs etc.
|
||||
// `op_type` must outlive `op`.
|
||||
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
|
||||
@ -74,44 +81,64 @@ void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
|
||||
// `op_name` must outlive `op`.
|
||||
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
|
||||
TF_Status* s);
|
||||
// `attr_name` must outlive `op`.
|
||||
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
|
||||
TF_DataType value, TF_Status* s);
|
||||
|
||||
// Wrapper for TF_Output but contains a pointer to TF_GraphContext as well.
|
||||
typedef struct TF_GraphTensor TF_GraphTensor;
|
||||
TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* c, TF_Output t,
|
||||
TF_Status* s);
|
||||
TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s);
|
||||
void TF_DeleteGraphTensor(TF_GraphTensor* t);
|
||||
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
|
||||
|
||||
// `t` must outlive `at`.
|
||||
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
|
||||
TF_Status* s);
|
||||
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
|
||||
TF_Status* s);
|
||||
|
||||
// `t` must outlive `at`.
|
||||
void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
|
||||
TF_Status* s);
|
||||
TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
|
||||
TF_Status* s);
|
||||
|
||||
// TF_OutputList just lets us not specify the number of outputs of an operation
|
||||
// beforehand. This forces a memory allocation in the runtime, which is bad, but
|
||||
// it allows for generic code.
|
||||
// TF_OutputList holds the list of TF_AbstractTensor that results from executing
|
||||
// an operation, or provided to create a function.
|
||||
// When executing an operation in an eager context, the expected number of
|
||||
// outputs must be set beforehand with `TF_OutputListSetNumOutputs`.
|
||||
typedef struct TF_OutputList TF_OutputList;
|
||||
TF_OutputList* TF_NewOutputList();
|
||||
void TF_DeleteOutputList(TF_OutputList* o);
|
||||
void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
|
||||
// Prepare tracing to the expected number of output for an operation.
|
||||
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, TF_Status*);
|
||||
// Return the number of outputs in the list.
|
||||
int TF_OutputListNumOutputs(TF_OutputList* o);
|
||||
// Return the `i`th output in the list.
|
||||
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
|
||||
// Append a tensor at the end of the output list, growing its size by one.
|
||||
void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
|
||||
TF_Status*);
|
||||
|
||||
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
|
||||
// capture some inputs and then add a node in the graph, and after
|
||||
// execution/node creation it'll go and record things that happened in any tape
|
||||
// which happens to be active.
|
||||
// capture some inputs and then add a node in the graph. The output tensors are
|
||||
// returned through the provided TF_OutputList.
|
||||
// Any active tape will observe the effects of this execution.
|
||||
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
TF_ExecutionContext* ctx, TF_Status* s);
|
||||
|
||||
// Creates a new TF_AbstractFunction from the current tracing states in the
|
||||
// context. The provided `ctx` is consumed by this API call and deleted.
|
||||
// The returned TF_AbstractFunction must be deleted by the client,
|
||||
// TODO(aminim): clarify the contract on the state of the context after this
|
||||
// call.
|
||||
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
|
||||
TF_OutputList*, TF_Status*);
|
||||
|
||||
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
|
||||
|
||||
// Register the function with the given context. This is particularly useful for
|
||||
// making a function available to an eager context.
|
||||
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext*,
|
||||
TF_AbstractFunction*, TF_Status*);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// APIs specific to Eager modes
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// Temporary APIs till we figure out how to create scalar valued Eager
|
||||
// tensors and how to get value out of eager abstract tensors.
|
||||
TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
|
||||
TF_Status* s);
|
||||
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
|
||||
TF_Status* s);
|
||||
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
194
tensorflow/c/eager/c_api_unified_experimental_eager.cc
Normal file
194
tensorflow/c/eager/c_api_unified_experimental_eager.cc
Normal file
@ -0,0 +1,194 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
// Simple wrapper over a TFE_TensorHandle
|
||||
struct EagerTensor : public AbstractTensor {
|
||||
TFE_TensorHandle* t = nullptr;
|
||||
EagerTensor() : AbstractTensor(kKind) {}
|
||||
explicit EagerTensor(TFE_TensorHandle* t) : AbstractTensor(kKind), t(t) {}
|
||||
~EagerTensor() override { TFE_DeleteTensorHandle(t); }
|
||||
static constexpr AbstractTensorKind kKind = kEagerTensor;
|
||||
};
|
||||
|
||||
// Simple wrapper over a TFE_Op
|
||||
class EagerOp : public AbstractOp {
|
||||
public:
|
||||
explicit EagerOp(TFE_Context* ctx) : AbstractOp(kKind), ctx_(ctx) {}
|
||||
void SetOpType(const char* const op_type, TF_Status* s) override {
|
||||
op_ = TFE_NewOp(ctx_, op_type, s);
|
||||
}
|
||||
void SetOpName(const char* const op_name, TF_Status* s) override {
|
||||
// Name is ignored in eager mode.
|
||||
}
|
||||
void SetAttrType(const char* const attr_name, TF_DataType value,
|
||||
TF_Status* s) override {
|
||||
if (op_ == nullptr) {
|
||||
TF_SetStatus(s, TF_FAILED_PRECONDITION,
|
||||
"op_type must be specified before specifying attrs.");
|
||||
return;
|
||||
}
|
||||
TFE_OpSetAttrType(op_, attr_name, value);
|
||||
}
|
||||
|
||||
~EagerOp() override { TFE_DeleteOp(op_); }
|
||||
static constexpr AbstractOpKind kKind = kEagerOp;
|
||||
|
||||
private:
|
||||
friend class EagerContext; // For access to op_.
|
||||
TFE_Op* op_ = nullptr;
|
||||
TFE_Context* ctx_;
|
||||
};
|
||||
|
||||
// Wraps a TFE_Context and dispatch EagerOp with EagerTensor inputs.
|
||||
class EagerContext : public ExecutionContext {
|
||||
public:
|
||||
EagerContext() : ExecutionContext(kKind) {}
|
||||
|
||||
void Build(TFE_ContextOptions* options, TF_Status* status) {
|
||||
eager_ctx_ = TFE_NewContext(options, status);
|
||||
}
|
||||
|
||||
AbstractOp* CreateOperation() override {
|
||||
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
||||
return new EagerOp(eager_ctx_);
|
||||
}
|
||||
|
||||
void ExecuteOperation(AbstractOp* op, int num_inputs,
|
||||
AbstractTensor* const* inputs, OutputList* o,
|
||||
TF_Status* s) override {
|
||||
auto* eager_op = dyncast<EagerOp>(op);
|
||||
if (eager_op == nullptr) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Unable to cast AbstractOp to TF_EagerOp.");
|
||||
return;
|
||||
}
|
||||
auto* tfe_op = eager_op->op_;
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
auto* eager_tensor = dyncast<const EagerTensor>(inputs[i]);
|
||||
if (!eager_tensor) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
|
||||
return;
|
||||
}
|
||||
TFE_OpAddInput(tfe_op, eager_tensor->t, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
}
|
||||
if (o->expected_num_outputs == -1) {
|
||||
string msg =
|
||||
"The number of outputs must be provided in eager mode. Use "
|
||||
"TF_OutputListSetNumOutputs.";
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return;
|
||||
}
|
||||
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
|
||||
int num_retvals = o->expected_num_outputs;
|
||||
retvals.resize(num_retvals);
|
||||
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
|
||||
if (TF_GetCode(s) != TF_OK) {
|
||||
return;
|
||||
}
|
||||
o->outputs.clear();
|
||||
o->outputs.reserve(num_retvals);
|
||||
for (int i = 0; i < num_retvals; ++i) {
|
||||
o->outputs.push_back(new EagerTensor(retvals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Can't add function parameter on an eager context.");
|
||||
return nullptr;
|
||||
}
|
||||
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Can't use finalize function on an eager context.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
|
||||
auto* func = afunc->GetTfFunction(s);
|
||||
if (!func) {
|
||||
return;
|
||||
}
|
||||
TFE_ContextAddFunction(eager_ctx_, func, s);
|
||||
}
|
||||
|
||||
~EagerContext() override { TFE_DeleteContext(eager_ctx_); }
|
||||
|
||||
static constexpr ExecutionContextKind kKind = kEagerContext;
|
||||
|
||||
private:
|
||||
friend TFE_Context* ::TF_ExecutionContextGetTFEContext(
|
||||
TF_ExecutionContext* ctx);
|
||||
TFE_Context* eager_ctx_;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
// =============================================================================
|
||||
// Public C API entry points
|
||||
// These are only the entry points specific to the Eager API.
|
||||
// =============================================================================
|
||||
|
||||
using tensorflow::internal::dyncast;
|
||||
using tensorflow::internal::unwrap;
|
||||
|
||||
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions* options,
|
||||
TF_Status* s) {
|
||||
auto* ctx = new tensorflow::internal::EagerContext();
|
||||
ctx->Build(options, s);
|
||||
return wrap(ctx);
|
||||
}
|
||||
|
||||
TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
|
||||
TF_Status* s) {
|
||||
return wrap(new tensorflow::internal::EagerTensor(t));
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
|
||||
TF_Status* s) {
|
||||
auto* eager_tensor = dyncast<tensorflow::internal::EagerTensor>(unwrap(at));
|
||||
if (!eager_tensor) {
|
||||
string msg = tensorflow::strings::StrCat("Not an eager tensor handle.",
|
||||
reinterpret_cast<uintptr_t>(at));
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
return eager_tensor->t;
|
||||
}
|
||||
|
||||
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
|
||||
auto* eager_ctx = dyncast<tensorflow::internal::EagerContext>(unwrap(ctx));
|
||||
if (!eager_ctx) return nullptr;
|
||||
return eager_ctx->eager_ctx_;
|
||||
}
|
235
tensorflow/c/eager/c_api_unified_experimental_graph.cc
Normal file
235
tensorflow/c/eager/c_api_unified_experimental_graph.cc
Normal file
@ -0,0 +1,235 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
class GraphContext;
|
||||
|
||||
// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
|
||||
// into the list of outputs for the operation.
|
||||
struct GraphTensor : public AbstractTensor {
|
||||
TF_Output output{};
|
||||
GraphContext* ctx = nullptr;
|
||||
GraphTensor() : AbstractTensor(kKind) {}
|
||||
GraphTensor(TF_Output output, GraphContext* ctx)
|
||||
: AbstractTensor(kKind), output(output), ctx(ctx) {}
|
||||
static constexpr AbstractTensorKind kKind = kGraphTensor;
|
||||
};
|
||||
|
||||
// GraphOp wraps and populate a TF_OperationDescription.
|
||||
class GraphOp : public AbstractOp {
|
||||
public:
|
||||
explicit GraphOp(TF_Graph* g) : AbstractOp(kKind), g_(g) {}
|
||||
void SetOpType(const char* const op_type, TF_Status* s) override {
|
||||
if (op_) {
|
||||
TF_SetStatus(
|
||||
s, TF_FAILED_PRECONDITION,
|
||||
strings::StrCat("SetOpType called on already built op.").c_str());
|
||||
return;
|
||||
}
|
||||
if (op_name_ != nullptr) {
|
||||
op_.reset(TF_NewOperation(g_, op_type, op_name_));
|
||||
op_name_ = nullptr;
|
||||
} else {
|
||||
op_type_ = op_type;
|
||||
}
|
||||
}
|
||||
void SetOpName(const char* const op_name, TF_Status* s) override {
|
||||
if (op_) {
|
||||
TF_SetStatus(
|
||||
s, TF_FAILED_PRECONDITION,
|
||||
strings::StrCat("SetOpName called on already built op.").c_str());
|
||||
return;
|
||||
}
|
||||
if (op_type_ != nullptr) {
|
||||
op_.reset(TF_NewOperation(g_, op_type_, op_name));
|
||||
op_type_ = nullptr;
|
||||
} else {
|
||||
op_name_ = op_name;
|
||||
}
|
||||
}
|
||||
void SetAttrType(const char* const attr_name, TF_DataType value,
|
||||
TF_Status* s) override {
|
||||
if (!op_) {
|
||||
TF_SetStatus(
|
||||
s, TF_FAILED_PRECONDITION,
|
||||
"op_type and op_name must be specified before specifying attrs.");
|
||||
return;
|
||||
}
|
||||
TF_SetAttrType(op_.get(), attr_name, value);
|
||||
}
|
||||
~GraphOp() override {}
|
||||
|
||||
static constexpr AbstractOpKind kKind = kGraphOp;
|
||||
|
||||
private:
|
||||
friend class GraphContext; // For access to op_.
|
||||
TF_Graph* g_;
|
||||
std::unique_ptr<TF_OperationDescription> op_;
|
||||
// Hold `op_type` and `op_name` till both are available since we need both
|
||||
// to build a graph operation.
|
||||
const char* op_type_ = nullptr;
|
||||
const char* op_name_ = nullptr;
|
||||
};
|
||||
|
||||
// GraphFunction is a thin wrapper over a TF_Function.
|
||||
struct GraphFunction : public AbstractFunction {
|
||||
TF_Function* func = nullptr;
|
||||
GraphFunction() : AbstractFunction(kKind) {}
|
||||
explicit GraphFunction(TF_Function* func)
|
||||
: AbstractFunction(kKind), func(func) {}
|
||||
~GraphFunction() override {
|
||||
if (func) TF_DeleteFunction(func);
|
||||
}
|
||||
|
||||
TF_Function* GetTfFunction(TF_Status* s) override { return func; }
|
||||
|
||||
static constexpr AbstractFunctionKind kKind = kGraphFunc;
|
||||
};
|
||||
|
||||
// GraphContext wraps a TF_Graph modeling a single function and manages the
|
||||
// "execution" of operation, i.e. adding them to the function.
|
||||
class GraphContext : public ExecutionContext {
|
||||
public:
|
||||
explicit GraphContext(const char* name)
|
||||
: ExecutionContext(kKind),
|
||||
graph_(new TF_Graph(), TF_DeleteGraph),
|
||||
name_(name) {}
|
||||
|
||||
AbstractOp* CreateOperation() override {
|
||||
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
||||
return new GraphOp(graph_.get());
|
||||
}
|
||||
|
||||
void ExecuteOperation(AbstractOp* op, int num_inputs,
|
||||
AbstractTensor* const* inputs, OutputList* o,
|
||||
TF_Status* s) override {
|
||||
auto* graph_op = dyncast<GraphOp>(op);
|
||||
if (graph_op == nullptr) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Unable to cast AbstractOp to TF_GraphOp.");
|
||||
return;
|
||||
}
|
||||
auto* tf_opdesc = graph_op->op_.release();
|
||||
if (tf_opdesc == nullptr) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, "AbstractOp is incomplete.");
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
auto* graph_tensor = dyncast<GraphTensor>(inputs[i]);
|
||||
if (!graph_tensor) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Capturing eager tensors is not supported yet.");
|
||||
return;
|
||||
} else {
|
||||
if (graph_tensor->ctx != this) {
|
||||
TF_SetStatus(
|
||||
s, TF_INVALID_ARGUMENT,
|
||||
"Capturing tensors from other graphs is not supported yet.");
|
||||
return;
|
||||
}
|
||||
TF_AddInput(tf_opdesc, graph_tensor->output);
|
||||
}
|
||||
}
|
||||
auto* operation = TF_FinishOperation(tf_opdesc, s);
|
||||
// TF_FinishOperation deletes `tf_opdesc` so clear its reference.
|
||||
graph_op->op_ = nullptr;
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
int num_outputs = TF_OperationNumOutputs(operation);
|
||||
o->outputs.clear();
|
||||
o->outputs.reserve(num_outputs);
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
o->outputs.push_back(new GraphTensor({operation, i}, this));
|
||||
}
|
||||
}
|
||||
|
||||
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
|
||||
TF_OperationDescription* opdesc =
|
||||
TF_NewOperation(graph_.get(), "Placeholder",
|
||||
absl::StrCat("_input_", inputs_.size()).c_str());
|
||||
TF_SetAttrType(opdesc, "dtype", dtype);
|
||||
auto* operation = TF_FinishOperation(opdesc, s);
|
||||
if (!s->status.ok()) return nullptr;
|
||||
|
||||
inputs_.push_back(TF_Output{operation, 0});
|
||||
return new GraphTensor(inputs_.back(), this);
|
||||
}
|
||||
|
||||
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
|
||||
std::unique_ptr<GraphFunction> func(new GraphFunction);
|
||||
std::vector<TF_Output> graph_outputs;
|
||||
graph_outputs.reserve(outputs->outputs.size());
|
||||
for (AbstractTensor* abstract_output : outputs->outputs) {
|
||||
GraphTensor* output = dyncast<GraphTensor>(abstract_output);
|
||||
if (!output) {
|
||||
TF_SetStatus(s, TF_UNIMPLEMENTED,
|
||||
"Returning a non-graph tensor from a function has not "
|
||||
"been implemented yet.");
|
||||
return nullptr;
|
||||
}
|
||||
graph_outputs.push_back(output->output);
|
||||
}
|
||||
|
||||
func->func = TF_GraphToFunction(
|
||||
graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
|
||||
graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
|
||||
if (TF_GetCode(s) != TF_OK) return nullptr;
|
||||
return func.release();
|
||||
}
|
||||
|
||||
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
|
||||
TF_SetStatus(s, TF_UNIMPLEMENTED,
|
||||
"Registering graph functions has not been implemented yet.");
|
||||
}
|
||||
|
||||
~GraphContext() override {}
|
||||
|
||||
static constexpr ExecutionContextKind kKind = kGraphContext;
|
||||
|
||||
private:
|
||||
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
|
||||
std::vector<TF_Output> inputs_;
|
||||
const char* name_;
|
||||
};
|
||||
|
||||
static ExecutionContext* GraphTracingFactory(const char* name, TF_Status* s) {
|
||||
return new GraphContext(name);
|
||||
}
|
||||
|
||||
// Register the tracing implemented in this file as the default tracing engine.
|
||||
static bool register_tracing = [] {
|
||||
RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
|
||||
SetDefaultTracingEngine("graphdef");
|
||||
return true;
|
||||
}();
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
201
tensorflow/c/eager/c_api_unified_experimental_internal.h
Normal file
201
tensorflow/c/eager/c_api_unified_experimental_internal.h
Normal file
@ -0,0 +1,201 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
// =============================================================================
|
||||
// Implementation detail for the unified execution APIs for Eager and tracing
|
||||
// backends (graph/MLIR).
|
||||
//
|
||||
// This defines a set of abstract classes that are intended to provide the
|
||||
// functionality of the opaque C types exposed in the public APIs defined in the
|
||||
// `c_api_unified_experimental.h` header.
|
||||
// =============================================================================
|
||||
|
||||
// We can't depend on C++ rtti, but we still want to be able to have a safe
|
||||
// dynamic_cast to provide diagnostics to the user when the API is misused.
|
||||
// Instead we model RTTI by listing all the possible subclasses for each
|
||||
// abstract base. Each subclass initializes the base class with the right
|
||||
// `kind`, which allows an equivalent to `std::dynamic_cast` provided by this
|
||||
// utility.
|
||||
template <typename T, typename S>
|
||||
T* dyncast(S source) {
|
||||
if (source->getKind() != T::kKind) {
|
||||
return nullptr;
|
||||
}
|
||||
return tensorflow::down_cast<T*>(source);
|
||||
}
|
||||
|
||||
// Represents either an EagerTensor or a GraphTensor.
|
||||
// This base class does not expose any public methods other than to distinguish
|
||||
// which subclass it actually is. The user is responsible to use the right
|
||||
// type of AbstractTensor in their context (do not pass an EagerTensor to a
|
||||
// GraphContext and vice-versa).
|
||||
class AbstractTensor {
|
||||
protected:
|
||||
enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor };
|
||||
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
|
||||
|
||||
public:
|
||||
// Returns which subclass is this instance of.
|
||||
AbstractTensorKind getKind() const { return kind_; }
|
||||
virtual ~AbstractTensor() = default;
|
||||
|
||||
private:
|
||||
const AbstractTensorKind kind_;
|
||||
};
|
||||
|
||||
// Represents the results of the execution of an operation.
|
||||
struct OutputList {
|
||||
std::vector<AbstractTensor*> outputs;
|
||||
int expected_num_outputs = -1;
|
||||
};
|
||||
|
||||
// Holds the result of tracing a function.
|
||||
class AbstractFunction {
|
||||
protected:
|
||||
enum AbstractFunctionKind { kGraphFunc };
|
||||
explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}
|
||||
|
||||
public:
|
||||
// Returns which subclass is this instance of.
|
||||
AbstractFunctionKind getKind() const { return kind_; }
|
||||
virtual ~AbstractFunction() = default;
|
||||
|
||||
// Temporary API till we figure the right abstraction for AbstractFunction.
|
||||
// At the moment both Eager and Graph needs access to a "TF_Function" object.
|
||||
virtual TF_Function* GetTfFunction(TF_Status* s) = 0;
|
||||
|
||||
private:
|
||||
const AbstractFunctionKind kind_;
|
||||
};
|
||||
|
||||
// An abstract operation describes an operation by its type, name, and
|
||||
// attributes. It can be "executed" by the context with some input tensors.
|
||||
// It is allowed to reusing the same abstract operation for multiple execution
|
||||
// on a given context, with the same or different input tensors.
|
||||
class AbstractOp {
|
||||
protected:
|
||||
enum AbstractOpKind { kGraphOp, kEagerOp };
|
||||
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
|
||||
|
||||
public:
|
||||
// Returns which subclass is this instance of.
|
||||
AbstractOpKind getKind() const { return kind_; }
|
||||
virtual ~AbstractOp() = default;
|
||||
|
||||
// Sets the type of the operation (for example `AddV2`).
|
||||
virtual void SetOpType(const char* op_type, TF_Status* s) = 0;
|
||||
|
||||
// Sets the name of the operation: this is an optional identifier that is
|
||||
// not intended to carry semantics and preserved/propagated without
|
||||
// guarantees.
|
||||
virtual void SetOpName(const char* op_name, TF_Status* s) = 0;
|
||||
|
||||
// Add a `TypeAttribute` on the operation.
|
||||
virtual void SetAttrType(const char* attr_name, TF_DataType value,
|
||||
TF_Status* s) = 0;
|
||||
|
||||
private:
|
||||
const AbstractOpKind kind_;
|
||||
};
|
||||
|
||||
// This holds the context for the execution: dispatching operations either to an
|
||||
// eager implementation or to a graph implementation.
|
||||
struct ExecutionContext {
|
||||
protected:
|
||||
enum ExecutionContextKind { kGraphContext, kEagerContext };
|
||||
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
|
||||
|
||||
public:
|
||||
// Returns which subclass is this instance of.
|
||||
ExecutionContextKind getKind() const { return k; }
|
||||
virtual ~ExecutionContext() = default;
|
||||
|
||||
// Executes the operation on the provided inputs and populate the OutputList
|
||||
// with the results. The input tensors must match the current context.
|
||||
// The effect of "executing" an operation depends on the context: in an Eager
|
||||
// context it will dispatch it to the runtime for execution, while in a
|
||||
// tracing context it will add the operation to the current function.
|
||||
virtual void ExecuteOperation(AbstractOp* op, int num_inputs,
|
||||
AbstractTensor* const* inputs, OutputList* o,
|
||||
TF_Status* s) = 0;
|
||||
|
||||
// Creates an empty AbstractOperation suitable to use with this context.
|
||||
virtual AbstractOp* CreateOperation() = 0;
|
||||
|
||||
// Add a function parameter and return the corresponding tensor.
|
||||
// This is only valid with an ExecutionContext obtained from a TracingContext,
|
||||
// it'll always error out with an eager context.
|
||||
virtual AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) = 0;
|
||||
|
||||
// Finalize this context and make a function out of it. The context is in a
|
||||
// invalid state after this call and must be destroyed.
|
||||
// This is only valid with an ExecutionContext obtained from a TracingContext,
|
||||
// it'll always error out with an eager context.
|
||||
virtual AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) = 0;
|
||||
|
||||
// Registers a functions with this context, after this the function is
|
||||
// available to be called/referenced by its name in this context.
|
||||
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
|
||||
|
||||
private:
|
||||
const ExecutionContextKind k;
|
||||
};
|
||||
|
||||
typedef ExecutionContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
|
||||
void SetDefaultTracingEngine(const char* name);
|
||||
void RegisterTracingEngineFactory(const ::tensorflow::string& name,
|
||||
FactoryFunction factory);
|
||||
|
||||
// Create utilities to wrap/unwrap: this convert from the C opaque types to the
|
||||
// C++ implementation, and back.
|
||||
#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \
|
||||
static inline CPP_CLASS* const& unwrap(C_TYPEDEF* const& o) { \
|
||||
return reinterpret_cast<CPP_CLASS* const&>(o); \
|
||||
} \
|
||||
static inline const CPP_CLASS* const& unwrap(const C_TYPEDEF* const& o) { \
|
||||
return reinterpret_cast<const CPP_CLASS* const&>(o); \
|
||||
} \
|
||||
static inline C_TYPEDEF* const& wrap(CPP_CLASS* const& o) { \
|
||||
return reinterpret_cast<C_TYPEDEF* const&>(o); \
|
||||
} \
|
||||
static inline const C_TYPEDEF* const& wrap(const CPP_CLASS* const& o) { \
|
||||
return reinterpret_cast<const C_TYPEDEF* const&>(o); \
|
||||
}
|
||||
|
||||
MAKE_WRAP_UNWRAP(TF_ExecutionContext, ExecutionContext)
|
||||
MAKE_WRAP_UNWRAP(TF_AbstractFunction, AbstractFunction)
|
||||
MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor)
|
||||
MAKE_WRAP_UNWRAP(TF_AbstractOp, AbstractOp)
|
||||
MAKE_WRAP_UNWRAP(TF_OutputList, OutputList)
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
|
@ -15,44 +15,44 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
|
||||
#include <string.h>
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/cc/profiler/profiler.h"
|
||||
#include "tensorflow/core/lib/monitoring/collection_registry.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/str_util.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(UnifedCAPI, TestBasicEager) {
|
||||
TF_ExecutionContext* ctx = TF_NewExecutionContext();
|
||||
class UnifiedCAPI : public ::testing::TestWithParam<const char*> {
|
||||
protected:
|
||||
void SetUp() override { TF_SetTracingImplementation(GetParam()); }
|
||||
};
|
||||
|
||||
TEST_P(UnifiedCAPI, TestBasicEager) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
|
||||
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
// Enter the eager context.
|
||||
TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract input tensor.
|
||||
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
|
||||
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
|
||||
TF_AbstractTensor* at = TF_NewAbstractTensor();
|
||||
TF_AbstractTensorSetEagerTensor(at, t, status.get());
|
||||
TF_AbstractTensor* at =
|
||||
TF_CreateAbstractTensorFromEagerTensor(t, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract operation.
|
||||
auto* op = TF_NewAbstractOp();
|
||||
auto* op = TF_NewAbstractOp(ctx);
|
||||
TF_AbstractOpSetOpType(op, "Add", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
@ -69,7 +69,6 @@ TEST(UnifedCAPI, TestBasicEager) {
|
||||
// Clean up operation and inputs.
|
||||
TF_DeleteAbstractOp(op);
|
||||
TF_DeleteAbstractTensor(at);
|
||||
TFE_DeleteTensorHandle(t);
|
||||
|
||||
// Verify the results.
|
||||
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
|
||||
@ -83,100 +82,75 @@ TEST(UnifedCAPI, TestBasicEager) {
|
||||
|
||||
TF_DeleteTensor(result_tensor);
|
||||
TF_DeleteAbstractTensor(result);
|
||||
TFE_DeleteTensorHandle(result_t);
|
||||
TF_DeleteOutputList(o);
|
||||
TFE_DeleteContext(eager_ctx);
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TestBasicGraph) {
|
||||
TF_ExecutionContext* ctx = TF_NewExecutionContext();
|
||||
TEST_P(UnifiedCAPI, TestBasicGraph) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
|
||||
// Enter a graph context.
|
||||
TF_Graph* g = TF_NewGraph();
|
||||
TF_GraphContext* graph_context = TF_NewGraphContext(g);
|
||||
TF_ExecutionContextSetGraphContext(ctx, graph_context, status.get());
|
||||
// Start a new function / execution context.
|
||||
string fn_name = "double";
|
||||
TF_ExecutionContext* graph_ctx =
|
||||
TF_CreateFunction(fn_name.c_str(), status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
auto* placeholder_op = TF_NewOperation(g, "Placeholder", "Placeholder");
|
||||
TF_SetAttrType(placeholder_op, "dtype", TF_FLOAT);
|
||||
auto* operation = TF_FinishOperation(placeholder_op, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_Output placeholder_t = {operation, 0};
|
||||
TF_GraphTensor* graph_t =
|
||||
TF_NewGraphTensor(graph_context, placeholder_t, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractTensor* t = TF_NewAbstractTensor();
|
||||
TF_AbstractTensorSetGraphTensor(t, graph_t, status.get());
|
||||
auto* placeholder_t =
|
||||
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract operation.
|
||||
auto* op = TF_NewAbstractOp();
|
||||
TF_AbstractOpSetOpType(op, "Add", status.get());
|
||||
auto* add_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(add_op, "Add", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetOpName(op, "my_add", status.get());
|
||||
TF_AbstractOpSetOpName(add_op, "my_add", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build inputs and outputs.
|
||||
TF_AbstractTensor* inputs[2] = {t, t};
|
||||
TF_OutputList* o = TF_NewOutputList();
|
||||
TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
|
||||
TF_OutputList* add_outputs = TF_NewOutputList();
|
||||
|
||||
// Execute.
|
||||
TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
|
||||
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Clean up operation and inputs.
|
||||
TF_DeleteAbstractOp(op);
|
||||
TF_DeleteAbstractTensor(t);
|
||||
TF_DeleteGraphTensor(graph_t);
|
||||
TF_DeleteAbstractOp(add_op);
|
||||
|
||||
TF_AbstractTensor* result = TF_OutputListGet(o, 0);
|
||||
TF_GraphTensor* result_graph_tensor =
|
||||
TF_AbstractTensorGetGraphTensor(result, status.get());
|
||||
TF_DeleteAbstractTensor(result);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_Output result_output =
|
||||
TF_GraphTensorToOutput(result_graph_tensor, status.get());
|
||||
TF_DeleteGraphTensor(result_graph_tensor);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
string fn_name = "double";
|
||||
TF_Function* f = TF_GraphToFunction(
|
||||
g, fn_name.c_str(), 0, -1, nullptr, 1, &placeholder_t, 1, &result_output,
|
||||
nullptr, nullptr, fn_name.c_str(), status.get());
|
||||
TF_AbstractFunction* func =
|
||||
TF_FinalizeFunction(graph_ctx, add_outputs, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an eager context to run the function.
|
||||
// Build eager context.
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
|
||||
TF_ExecutionContext* eager_execution_ctx =
|
||||
TF_NewEagerExecutionContext(opts, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
// Build the abstract op to run the function.
|
||||
TFE_ContextAddFunction(eager_ctx, f, status.get());
|
||||
TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOp* fn_op = TF_NewAbstractOp();
|
||||
// Build the abstract op to run the function.
|
||||
TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
|
||||
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract input tensor.
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(eager_execution_ctx);
|
||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
|
||||
TF_AbstractTensor* input_t = TF_NewAbstractTensor();
|
||||
TF_AbstractTensorSetEagerTensor(input_t, input_eager, status.get());
|
||||
TF_AbstractTensor* input_t =
|
||||
TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Enter the eager context.
|
||||
TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
|
||||
TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_OutputListSetNumOutputs(o, 1, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_ExecuteOperation(fn_op, 1, &input_t, o, ctx, status.get());
|
||||
TF_ExecuteOperation(fn_op, 1, &input_t, add_outputs, eager_execution_ctx,
|
||||
status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
|
||||
TF_AbstractTensor* final_result = TF_OutputListGet(o, 0);
|
||||
ASSERT_EQ(1, TF_OutputListNumOutputs(add_outputs));
|
||||
TF_AbstractTensor* final_result = TF_OutputListGet(add_outputs, 0);
|
||||
TFE_TensorHandle* final =
|
||||
TF_AbstractTensorGetEagerTensor(final_result, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
@ -185,20 +159,325 @@ TEST(UnifedCAPI, TestBasicGraph) {
|
||||
float* f_value = static_cast<float*>(TF_TensorData(f_t));
|
||||
ASSERT_EQ(*f_value, 4.0);
|
||||
|
||||
TF_DeleteOutputList(o);
|
||||
TF_DeleteOutputList(add_outputs);
|
||||
TF_DeleteAbstractOp(fn_op);
|
||||
TF_DeleteAbstractTensor(input_t);
|
||||
TFE_DeleteTensorHandle(input_eager);
|
||||
TF_DeleteAbstractTensor(final_result);
|
||||
TFE_DeleteTensorHandle(final);
|
||||
TF_DeleteTensor(f_t);
|
||||
TF_DeleteFunction(f);
|
||||
TF_DeleteAbstractFunction(func);
|
||||
|
||||
TF_DeleteGraphContext(graph_context);
|
||||
TF_DeleteGraph(g);
|
||||
TFE_DeleteContext(eager_ctx);
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||
}
|
||||
|
||||
TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_Status* s = status.get();
|
||||
|
||||
// Start a new function / execution context.
|
||||
string fn_name = "two_adds";
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Create a first "Add" computing `arg0 + arg1`.
|
||||
TF_AbstractTensor* add_output1;
|
||||
{
|
||||
// Build an abstract operation, inputs and output.
|
||||
auto* add_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(add_op, "Add", s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_AbstractOpSetOpName(add_op, "my_add1", s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_AbstractTensor* inputs[2] = {arg0, arg1};
|
||||
TF_OutputList* add_outputs = TF_NewOutputList();
|
||||
// Trace the operation now (create a node in the graph).
|
||||
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_DeleteAbstractOp(add_op);
|
||||
// Extract the resulting tensor.
|
||||
add_output1 = TF_OutputListGet(add_outputs, 0);
|
||||
TF_DeleteOutputList(add_outputs);
|
||||
}
|
||||
|
||||
// Same with a second "Add" computing `arg1 + arg1`.
|
||||
TF_AbstractTensor* add_output2;
|
||||
{
|
||||
// Build an abstract operation, inputs and output.
|
||||
auto* add_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(add_op, "Add", s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_AbstractOpSetOpName(add_op, "my_add2", s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_AbstractTensor* inputs[2] = {arg1, arg1};
|
||||
TF_OutputList* add_outputs = TF_NewOutputList();
|
||||
// Trace the operation now (create a node in the graph).
|
||||
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_DeleteAbstractOp(add_op);
|
||||
// Extract the resulting tensor.
|
||||
add_output2 = TF_OutputListGet(add_outputs, 0);
|
||||
TF_DeleteOutputList(add_outputs);
|
||||
}
|
||||
|
||||
// Finalize the function by providing the returned values.
|
||||
TF_AbstractFunction* func;
|
||||
{
|
||||
// We want to return the output of both add operations, create a new list
|
||||
// and populate it.
|
||||
TF_OutputList* func_outputs = TF_NewOutputList();
|
||||
TF_OutputListPushBack(func_outputs, add_output1, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_OutputListPushBack(func_outputs, add_output2, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
func = TF_FinalizeFunction(graph_ctx, func_outputs, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_DeleteOutputList(func_outputs);
|
||||
}
|
||||
|
||||
/**
|
||||
* We traced so far this function:
|
||||
*
|
||||
* def two_adds(a, b):
|
||||
* my_add1 = a + b
|
||||
* my_add2 = b + b
|
||||
* return my_add1, my_add2
|
||||
*
|
||||
* Now we will execute this function with an eager context:
|
||||
*
|
||||
* output1, output2 = two_adds(2.0, 3.0)
|
||||
*
|
||||
* and check that we got 5.0 and 6.0 as results.
|
||||
*/
|
||||
|
||||
// Build eager context.
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TF_ExecutionContext* eager_execution_ctx =
|
||||
TF_NewEagerExecutionContext(opts, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Build the abstract op to run the function.
|
||||
TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
|
||||
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Build two abstract input tensors as function arguments.
|
||||
std::vector<TF_AbstractTensor*> func_args;
|
||||
{
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(eager_execution_ctx);
|
||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
|
||||
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
input_eager = TestScalarTensorHandle(eager_ctx, 3.0f);
|
||||
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
}
|
||||
|
||||
TF_OutputList* func_outputs = TF_NewOutputList();
|
||||
TF_OutputListSetNumOutputs(func_outputs, 2, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs,
|
||||
eager_execution_ctx, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_DeleteAbstractOp(fn_op);
|
||||
for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t);
|
||||
|
||||
ASSERT_EQ(2, TF_OutputListNumOutputs(func_outputs));
|
||||
float results[2];
|
||||
for (int idx = 0; idx < 2; ++idx) {
|
||||
TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
|
||||
TFE_TensorHandle* handle = TF_AbstractTensorGetEagerTensor(result, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_Tensor* f_t = TFE_TensorHandleResolve(handle, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
results[idx] = *static_cast<float*>(TF_TensorData(f_t));
|
||||
TF_DeleteTensor(f_t);
|
||||
}
|
||||
ASSERT_EQ(results[0], 5.0);
|
||||
ASSERT_EQ(results[1], 6.0);
|
||||
|
||||
for (int idx = 0; idx < 2; ++idx) {
|
||||
TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
|
||||
TF_DeleteAbstractTensor(result);
|
||||
}
|
||||
TF_DeleteOutputList(func_outputs);
|
||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||
TF_DeleteAbstractFunction(func);
|
||||
}
|
||||
|
||||
TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TF_AbstractFunction* func = TF_FinalizeFunction(ctx, nullptr, status.get());
|
||||
ASSERT_EQ(nullptr, func);
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
}
|
||||
|
||||
TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// This should fail.
|
||||
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
|
||||
ASSERT_EQ(TF_FAILED_PRECONDITION, TF_GetCode(status.get()));
|
||||
|
||||
TF_DeleteAbstractOp(placeholder_op);
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
}
|
||||
|
||||
TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// This should fail.
|
||||
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
|
||||
ASSERT_EQ(TF_FAILED_PRECONDITION, TF_GetCode(status.get()));
|
||||
|
||||
TF_DeleteAbstractOp(placeholder_op);
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
}
|
||||
|
||||
TEST_P(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
// Build an Eager context.
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an Eager operation.
|
||||
auto* op = TF_NewAbstractOp(ctx);
|
||||
TF_AbstractOpSetOpType(op, "Add", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract input tensor.
|
||||
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
|
||||
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
|
||||
TF_AbstractTensor* at =
|
||||
TF_CreateAbstractTensorFromEagerTensor(t, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build inputs and outputs.
|
||||
TF_AbstractTensor* inputs[2] = {at, at};
|
||||
TF_OutputList* o = TF_NewOutputList();
|
||||
TF_OutputListSetNumOutputs(o, 1, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build a Graph context.
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Execute eager op using graph context.
|
||||
TF_ExecuteOperation(op, 2, inputs, o, graph_ctx, status.get());
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
|
||||
// Clean up operation and inputs.
|
||||
TF_DeleteAbstractOp(op);
|
||||
TF_DeleteAbstractTensor(at);
|
||||
|
||||
TF_DeleteOutputList(o);
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
}
|
||||
|
||||
TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build inputs and outputs.
|
||||
TF_OutputList* placeholder_outputs = TF_NewOutputList();
|
||||
|
||||
// Execute.
|
||||
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
|
||||
graph_ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
|
||||
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
|
||||
|
||||
// Delete placeholder op.
|
||||
TF_DeleteAbstractOp(placeholder_op);
|
||||
|
||||
// Build an abstract operation.
|
||||
auto* add_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(add_op, "Add", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetOpName(add_op, "my_add", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build inputs and outputs.
|
||||
TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
|
||||
TF_OutputList* add_outputs = TF_NewOutputList();
|
||||
|
||||
// Build eager context.
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TF_ExecutionContext* eager_execution_ctx =
|
||||
TF_NewEagerExecutionContext(opts, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
// Execute.
|
||||
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, eager_execution_ctx,
|
||||
status.get());
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
|
||||
// Clean up operation and inputs.
|
||||
TF_DeleteAbstractTensor(placeholder_t);
|
||||
TF_DeleteAbstractOp(add_op);
|
||||
TF_DeleteOutputList(add_outputs);
|
||||
TF_DeleteOutputList(placeholder_outputs);
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Values("graphdef"));
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -17,9 +17,11 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
@ -57,16 +59,51 @@ class AbstractContextInterface {
|
||||
virtual AbstractTensorInterface* CreateTensor(
|
||||
DataType dtype, absl::Span<const int64> dim_sizes) = 0;
|
||||
|
||||
typedef void (*MemoryReleaser)(void* data, size_t len, void* arg);
|
||||
|
||||
// Create a tensor instance from the given data buffer and description.
|
||||
// `memory_releaser` will be called on destruction, and it's responsible for
|
||||
// cleaning up the underlying buffer. `convert_string` indicates whether it
|
||||
// has to handle tstring conversion. Expected to be removed once tstring
|
||||
// migration is done.
|
||||
virtual AbstractTensorInterface* CreateTensor(DataType dtype,
|
||||
const int64_t* dims,
|
||||
int num_dims, void* data,
|
||||
size_t len, bool convert_string,
|
||||
MemoryReleaser memory_releaser,
|
||||
void* memory_releaser_arg) = 0;
|
||||
|
||||
// Create a handle to wrap and manage a Tensor
|
||||
virtual AbstractTensorHandleInterface* CreateLocalHandle(
|
||||
AbstractTensorInterface* t) = 0;
|
||||
// Copy the handle to another device.
|
||||
virtual AbstractTensorHandleInterface* CopyTensorHandleToDevice(
|
||||
AbstractTensorHandleInterface* handle, const char* device_name,
|
||||
Status* status) = 0;
|
||||
|
||||
// Create an operation to perform op execution
|
||||
virtual AbstractOperationInterface* CreateOperation() = 0;
|
||||
|
||||
// Load a SavedModelAPI object from the given directory and tags
|
||||
virtual std::unique_ptr<SavedModelAPI> LoadSavedModelAPI(
|
||||
const std::string& directory,
|
||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||
tensorflow::Status* status) = 0;
|
||||
|
||||
// List attributes of available devices
|
||||
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
|
||||
|
||||
virtual void ClearCachesAndThreadExecutors() = 0;
|
||||
|
||||
// Initialize the step resource container for a training step. This is used
|
||||
// in current TF runtime. For tfrt, it is used by fallback op handler.
|
||||
virtual void StartStep() = 0;
|
||||
// Destroy the step resource container for a training step.
|
||||
virtual void EndStep() = 0;
|
||||
|
||||
// Block until all pending nodes are finished.
|
||||
virtual Status AsyncWait() = 0;
|
||||
|
||||
protected:
|
||||
virtual ~AbstractContextInterface() {}
|
||||
};
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
// A simple logging device to test custom device registration.
|
||||
#include <memory>
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
@ -25,7 +26,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
|
||||
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
@ -176,7 +176,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
|
||||
TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
@ -226,16 +226,21 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
|
||||
|
||||
// Read the variable's value.
|
||||
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
executed = false;
|
||||
num_retvals = 1;
|
||||
TFE_TensorHandle* var_value = nullptr;
|
||||
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
|
||||
EXPECT_FALSE(TF_GetCode(status.get()) == TF_OK)
|
||||
<< "Execution should fail because the variable is being used on the "
|
||||
"wrong device.";
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
ASSERT_EQ(
|
||||
tensorflow::string(name),
|
||||
tensorflow::string(TFE_TensorHandleDeviceName(var_value, status.get())));
|
||||
TFE_DeleteTensorHandle(var_value);
|
||||
|
||||
// Free the backing buffer for the variable.
|
||||
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
@ -246,6 +251,79 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, InputBasedPlacement) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* custom0 = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
RegisterLoggingDevice(context.get(), custom0, &arrived, &executed,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
RegisterLoggingDevice(context.get(), custom1, &arrived, &executed,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcpu(
|
||||
TestMatrixTensorHandle(context.get()), TFE_DeleteTensorHandle);
|
||||
ASSERT_FALSE(arrived);
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom0(
|
||||
TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom0,
|
||||
status.get()),
|
||||
TFE_DeleteTensorHandle);
|
||||
ASSERT_TRUE(arrived);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
arrived = false;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom1(
|
||||
TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom1,
|
||||
status.get()),
|
||||
TFE_DeleteTensorHandle);
|
||||
ASSERT_TRUE(arrived);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Base case: two CPU inputs executes fine.
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
|
||||
MatMulOp(context.get(), hcpu.get(), hcpu.get()), TFE_DeleteOp);
|
||||
TFE_TensorHandle* retval;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_DeleteTensorHandle(retval);
|
||||
|
||||
// Custom device: inputs in same custom device works.
|
||||
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom0.get()));
|
||||
num_retvals = 1;
|
||||
executed = false;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
TFE_DeleteTensorHandle(retval);
|
||||
|
||||
// Custom device: inputs in different custom devices fails.
|
||||
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom1.get()));
|
||||
num_retvals = 1;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
|
||||
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
|
||||
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));
|
||||
|
||||
// Custom device: mix of custom/physical fails.
|
||||
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
|
||||
num_retvals = 1;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
|
||||
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
|
||||
ASSERT_TRUE(
|
||||
absl::StrContains(TF_Message(status.get()), "[]")); // kVariantDeviceNull
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
|
@ -16,8 +16,10 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/dlpack.h"
|
||||
|
||||
#include "include/dlpack/dlpack.h" // from @dlpack
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_reference.h"
|
||||
@ -41,15 +43,15 @@ struct TfDlManagedTensorCtx {
|
||||
|
||||
// Gets tensor from eager tensor handle.
|
||||
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::TensorHandleFromInterface(h->handle);
|
||||
if (handle->IsRemote()) {
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
|
||||
if (handle->Type() != TensorHandle::LOCAL) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"DLPack doesn't support remote tensor");
|
||||
"DLPack doesn't support ", handle->TypeString(), " tensor");
|
||||
return nullptr;
|
||||
}
|
||||
const tensorflow::Tensor* tensor;
|
||||
@ -107,7 +109,7 @@ DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
|
||||
// Gets DLPack's DLContext from eager tensor handle.
|
||||
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
|
||||
DLContext ctx;
|
||||
const char* device_name = h->handle->DeviceName(&status->status);
|
||||
const char* device_name = tensorflow::unwrap(h)->DeviceName(&status->status);
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
|
||||
std::string device_type = parsed_name.type;
|
||||
|
@ -42,7 +42,28 @@ class AbstractOperationInterface {
|
||||
virtual Status Reset(const char* op, const char* raw_device_name) = 0;
|
||||
|
||||
virtual const string& Name() const = 0;
|
||||
|
||||
// Returns the operation's device name.
|
||||
//
|
||||
// The value returned may be different from the one set by SetDeviceName, but
|
||||
// it will be compatible with it: the name will be updated by device placement
|
||||
// logic to refer to the specific device chosen.
|
||||
//
|
||||
// Example: If one calls `op->SetDeviceName("/device:GPU")`, the value
|
||||
// returned by DeviceName should be "/device:GPU:*" until a particular GPU is
|
||||
// chosen for the operation by the device placement logic in the
|
||||
// executor. After that, the value returned by DeviceName will be a full
|
||||
// device name such as "/job:localhost/replica:0/task:0/device:GPU:1".
|
||||
virtual const string& DeviceName() const = 0;
|
||||
|
||||
// Sets the operation device name.
|
||||
//
|
||||
// The given `name` must be parseable by DeviceNameUtils::ParseFullName, and
|
||||
// the result will be used as a constraint for device placement. See the
|
||||
// documentation for DeviceName for more details.
|
||||
//
|
||||
// The value will override the previous value - that is, no "merging" of
|
||||
// existing and given constraints will be performed.
|
||||
virtual Status SetDeviceName(const char* name) = 0;
|
||||
|
||||
virtual Status AddInput(AbstractTensorHandleInterface* input) = 0;
|
||||
|
@ -7,10 +7,27 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
# Currently pybind extension shared objects must use only C API headers since
|
||||
# the C API has static initializers duplicated in the Python bindings. So we
|
||||
# need a second rule that omits .cc files, in
|
||||
# tensorflow/python:_pywrap_parallel_device.
|
||||
filegroup(
|
||||
name = "headers",
|
||||
srcs = ["parallel_device.h"],
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "sources",
|
||||
srcs = ["parallel_device.cc"],
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "parallel_device",
|
||||
srcs = ["parallel_device.cc"],
|
||||
hdrs = ["parallel_device.h"],
|
||||
srcs = [":sources"],
|
||||
hdrs = [":headers"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
@ -27,6 +44,7 @@ tf_cc_test(
|
||||
srcs = ["parallel_device_test.cc"],
|
||||
deps = [
|
||||
":parallel_device",
|
||||
":parallel_device_ops",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_experimental",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
@ -36,3 +54,19 @@ tf_cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
# Note: ParallelDevice-specific ops are experimental and not currently linked in
|
||||
# to TensorFlow by default, just used in a few tests.
|
||||
filegroup(
|
||||
name = "parallel_device_ops_srcs",
|
||||
srcs = ["parallel_device_ops.cc"],
|
||||
visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "parallel_device_ops",
|
||||
srcs = [":parallel_device_ops_srcs"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = ["//tensorflow/core:framework"],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -92,6 +92,10 @@ class ParallelDevice {
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status) const;
|
||||
|
||||
// A parallel tensor with scalar integers numbering component devices.
|
||||
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
|
||||
TF_Status* status) const;
|
||||
|
||||
// Takes a description of a single operation being executed on the
|
||||
// ParallelDevice, and in turn runs one operation per component device with
|
||||
// its corresponding inputs from the input ParallelTensors (or
|
||||
@ -208,6 +212,46 @@ std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
|
||||
status);
|
||||
}
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
|
||||
TFE_Context* context, TF_Status* status) const {
|
||||
// TODO(allenl): We could cache DeviceIDs (keyed by context).
|
||||
std::vector<TensorHandlePtr> components;
|
||||
components.reserve(underlying_devices_.size());
|
||||
for (int device_index = 0; device_index < underlying_devices_.size();
|
||||
++device_index) {
|
||||
int64_t* device_id = new int64_t;
|
||||
*device_id = device_index;
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||
TF_NewTensor(
|
||||
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
|
||||
sizeof(int64_t),
|
||||
[](void* data, size_t, void* arg) {
|
||||
delete reinterpret_cast<int64_t*>(data);
|
||||
},
|
||||
nullptr),
|
||||
TF_DeleteTensor);
|
||||
// TODO(allenl): Here and when executing regular operations, we could hold
|
||||
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
|
||||
// device names repeatedly.
|
||||
OpPtr const_op(TFE_NewOp(context, "Const", status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
|
||||
status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
|
||||
TFE_TensorHandle* device_handle;
|
||||
int num_outputs = 1;
|
||||
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
components.emplace_back(device_handle);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
return ParallelTensor::FromTensorHandles(*this, std::move(components),
|
||||
status);
|
||||
}
|
||||
|
||||
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
|
||||
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
@ -282,6 +326,13 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
|
||||
}
|
||||
result.emplace(std::move(outputs));
|
||||
return result;
|
||||
} else if (operation_name == std::string("DeviceID")) {
|
||||
std::vector<MaybeParallelTensorOwned> result_content;
|
||||
result_content.reserve(1);
|
||||
result_content.push_back(DeviceIDs(context, status));
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
result.emplace(std::move(result_content));
|
||||
return result;
|
||||
}
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
|
||||
maybe_parallel_results(
|
||||
@ -574,23 +625,21 @@ void DeleteParallelDevice(void* device_info) {
|
||||
|
||||
} // namespace
|
||||
|
||||
void RegisterParallelDevice(TFE_Context* context, const char* device_name,
|
||||
const char** underlying_devices,
|
||||
int num_underlying_devices, TF_Status* status) {
|
||||
TFE_CustomDevice custom_device;
|
||||
custom_device.copy_tensor_to_device = &CopyToParallelDevice;
|
||||
custom_device.copy_tensor_from_device = &CopyTensorFromParallelDevice;
|
||||
custom_device.delete_device = &DeleteParallelDevice;
|
||||
custom_device.execute = &ParallelDeviceExecute;
|
||||
void AllocateParallelDevice(const char* device_name,
|
||||
const char* const* underlying_devices,
|
||||
int num_underlying_devices,
|
||||
TFE_CustomDevice* device, void** device_info) {
|
||||
device->copy_tensor_to_device = &CopyToParallelDevice;
|
||||
device->copy_tensor_from_device = &CopyTensorFromParallelDevice;
|
||||
device->delete_device = &DeleteParallelDevice;
|
||||
device->execute = &ParallelDeviceExecute;
|
||||
std::vector<std::string> underlying_devices_vector;
|
||||
underlying_devices_vector.reserve(num_underlying_devices);
|
||||
for (int device_index = 0; device_index < num_underlying_devices;
|
||||
++device_index) {
|
||||
underlying_devices_vector.push_back(underlying_devices[device_index]);
|
||||
}
|
||||
ParallelDevice* d =
|
||||
new ParallelDevice(device_name, underlying_devices_vector);
|
||||
TFE_RegisterCustomDevice(context, custom_device, device_name, d, status);
|
||||
*device_info = new ParallelDevice(device_name, underlying_devices_vector);
|
||||
}
|
||||
|
||||
} // namespace eager
|
||||
|
@ -16,12 +16,14 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
|
||||
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace eager {
|
||||
|
||||
// Register a parallel device named `device_name` which forwards operations to
|
||||
// Allocate a parallel device named `device_name` which forwards operations to
|
||||
// `underlying_devices`, maintaining "parallel tensors" with components placed
|
||||
// on each underlying device.
|
||||
//
|
||||
@ -50,11 +52,12 @@ namespace eager {
|
||||
// TPUReplicatedOutput(input=x, num_replicas=2)` un-packs the parallel tensor
|
||||
// into its components.
|
||||
//
|
||||
// `context` owns the parallel device. `underlying_devices` must stay valid
|
||||
// while the parallel device is in use.
|
||||
void RegisterParallelDevice(TFE_Context* context, const char* device_name,
|
||||
const char** underlying_devices,
|
||||
int num_underlying_devices, TF_Status* status);
|
||||
// The filled `device` struct and the allocated `device_info` struct may be
|
||||
// passed to TFE_RegisterCustomDevice. The `device_name` arguments must match.
|
||||
void AllocateParallelDevice(const char* device_name,
|
||||
const char* const* underlying_devices,
|
||||
int num_underlying_devices,
|
||||
TFE_CustomDevice* device, void** device_info);
|
||||
|
||||
} // namespace eager
|
||||
} // namespace tensorflow
|
||||
|
26
tensorflow/c/eager/parallel_device/parallel_device_ops.cc
Normal file
26
tensorflow/c/eager/parallel_device/parallel_device_ops.cc
Normal file
@ -0,0 +1,26 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
||||
// TODO(allenl): Figure out if we need this op, and if so whether we should move
|
||||
// it to core TF. Right now the eager C API does some checking of op
|
||||
// registrations before calling into custom devices, but we may be able to avoid
|
||||
// that.
|
||||
REGISTER_OP("DeviceID")
|
||||
.Output("device_id: int64")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(tensorflow::shape_inference::ScalarShape);
|
@ -278,14 +278,28 @@ TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
|
||||
}
|
||||
|
||||
// Assert that `handle` is equal to `expected_value`.
|
||||
void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
|
||||
template <typename value_type>
|
||||
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
|
||||
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_EQ(expected_value,
|
||||
*static_cast<float*>(TF_TensorData(value_zero.get())));
|
||||
EXPECT_EQ(expected_value,
|
||||
*static_cast<value_type*>(TF_TensorData(value_zero.get())));
|
||||
}
|
||||
|
||||
template <std::size_t num_devices>
|
||||
void RegisterParallelDevice(
|
||||
TFE_Context* context, const char* device_name,
|
||||
const std::array<const char*, num_devices>& underlying_devices,
|
||||
TF_Status* status) {
|
||||
TFE_CustomDevice device;
|
||||
void* device_info;
|
||||
tensorflow::eager::AllocateParallelDevice(
|
||||
device_name, underlying_devices.data(), underlying_devices.size(),
|
||||
&device, &device_info);
|
||||
TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
|
||||
}
|
||||
|
||||
// Create and modify a variable placed on a parallel device which composes
|
||||
@ -297,9 +311,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::array<const char*, 2> underlying_devices{first_device, second_device};
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context, device_name, underlying_devices.data(),
|
||||
underlying_devices.size(), status.get());
|
||||
RegisterParallelDevice(context, device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a variable handle (uninitialized to start) placed on the parallel
|
||||
@ -331,8 +344,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
AssertScalarFloatEq(components[0].get(), 20.);
|
||||
AssertScalarFloatEq(components[1].get(), 20.);
|
||||
ExpectScalarEq<float>(components[0].get(), 20.);
|
||||
ExpectScalarEq<float>(components[1].get(), 20.);
|
||||
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
@ -361,8 +374,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
AssertScalarFloatEq(components[0].get(), 23.);
|
||||
AssertScalarFloatEq(components[1].get(), 18.);
|
||||
ExpectScalarEq<float>(components[0].get(), 23.);
|
||||
ExpectScalarEq<float>(components[1].get(), 18.);
|
||||
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
@ -371,6 +384,32 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
// Compute the device ID twice and verify the result
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetDevice(op.get(), device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::array<TensorHandlePtr, 2> components;
|
||||
ExtractPerDeviceValues(context, result_handle, &components, status.get());
|
||||
TFE_DeleteTensorHandle(result_handle);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
ExpectScalarEq<int64_t>(components[0].get(), 0);
|
||||
ExpectScalarEq<int64_t>(components[1].get(), 1);
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestBasicCPU) {
|
||||
@ -456,16 +495,14 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::vector<const char*> underlying_devices;
|
||||
const char* first_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
underlying_devices.push_back(first_device_name);
|
||||
const char* second_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:1";
|
||||
underlying_devices.push_back(second_device_name);
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), device_name, underlying_devices.data(),
|
||||
underlying_devices.size(), status.get());
|
||||
std::array<const char*, 2> underlying_devices{first_device_name,
|
||||
second_device_name};
|
||||
RegisterParallelDevice(context.get(), device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
|
||||
@ -488,8 +525,8 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// The value of the original tensor is replicated on each device.
|
||||
AssertScalarFloatEq(components[0].get(), 3.);
|
||||
AssertScalarFloatEq(components[1].get(), 3.);
|
||||
ExpectScalarEq<float>(components[0].get(), 3.);
|
||||
ExpectScalarEq<float>(components[1].get(), 3.);
|
||||
|
||||
// Verify that the mirrors are placed on the component devices.
|
||||
std::string first_device =
|
||||
@ -524,12 +561,11 @@ TEST(PARALLEL_DEVICE, TestDifferentShapes) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::vector<const char*> underlying_devices;
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), device_name, underlying_devices.data(),
|
||||
underlying_devices.size(), status.get());
|
||||
std::array<const char*, 2> underlying_devices{
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:1"};
|
||||
RegisterParallelDevice(context.get(), device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create two vectors with different lengths
|
||||
@ -570,24 +606,22 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
|
||||
// Create a parallel device with two CPUs
|
||||
const char* first_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::vector<const char*> first_underlying_devices{
|
||||
std::array<const char*, 2> first_underlying_devices{
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:1"};
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), first_device_name, first_underlying_devices.data(),
|
||||
first_underlying_devices.size(), status.get());
|
||||
RegisterParallelDevice(context.get(), first_device_name,
|
||||
first_underlying_devices, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a second parallel device with the first parallel device and one
|
||||
// additional CPU.
|
||||
const char* second_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:CUSTOM:1";
|
||||
std::vector<const char*> second_underlying_devices{
|
||||
std::array<const char*, 2> second_underlying_devices{
|
||||
"/job:localhost/replica:0/task:0/device:CUSTOM:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:2"};
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), second_device_name, second_underlying_devices.data(),
|
||||
second_underlying_devices.size(), status.get());
|
||||
RegisterParallelDevice(context.get(), second_device_name,
|
||||
second_underlying_devices, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a tensor on the first parallel device
|
||||
@ -623,7 +657,7 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
|
||||
&second_components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
AssertScalarFloatEq(second_components[1].get(), 9.);
|
||||
ExpectScalarEq<float>(second_components[1].get(), 9.);
|
||||
|
||||
// Verify that the mirrors are placed on the component devices.
|
||||
std::string first_device = TFE_TensorHandleBackingDeviceName(
|
||||
@ -637,8 +671,8 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
|
||||
std::array<TensorHandlePtr, 2> first_components;
|
||||
ExtractPerDeviceValues(context.get(), second_components[0].get(),
|
||||
&first_components, status.get());
|
||||
AssertScalarFloatEq(first_components[0].get(), 3.);
|
||||
AssertScalarFloatEq(first_components[1].get(), 6.);
|
||||
ExpectScalarEq<float>(first_components[0].get(), 3.);
|
||||
ExpectScalarEq<float>(first_components[1].get(), 6.);
|
||||
|
||||
first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
|
||||
status.get());
|
||||
@ -656,11 +690,10 @@ TEST(PARALLEL_DEVICE, TestInvalidPacking) {
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::vector<const char*> underlying_devices;
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), device_name, underlying_devices.data(),
|
||||
underlying_devices.size(), status.get());
|
||||
std::array<const char*, 1> underlying_devices{
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0"};
|
||||
RegisterParallelDevice(context.get(), device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
|
||||
@ -775,12 +808,11 @@ TEST(PARALLEL_DEVICE, TestCollective) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::vector<const char*> underlying_devices;
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), device_name, underlying_devices.data(),
|
||||
underlying_devices.size(), status.get());
|
||||
std::array<const char*, 2> underlying_devices{
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:1"};
|
||||
RegisterParallelDevice(context.get(), device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a tensor on the parallel device
|
||||
@ -801,8 +833,8 @@ TEST(PARALLEL_DEVICE, TestCollective) {
|
||||
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
AssertScalarFloatEq(result_components[0].get(), 3.);
|
||||
AssertScalarFloatEq(result_components[1].get(), 3.);
|
||||
ExpectScalarEq<float>(result_components[0].get(), 3.);
|
||||
ExpectScalarEq<float>(result_components[1].get(), 3.);
|
||||
}
|
||||
|
||||
void RegisterCollectiveMulFunction(TFE_Context* context,
|
||||
@ -867,12 +899,11 @@ TEST(PARALLEL_DEVICE, TestFunction) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::vector<const char*> underlying_devices;
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), device_name, underlying_devices.data(),
|
||||
underlying_devices.size(), status.get());
|
||||
std::array<const char*, 2> underlying_devices{
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:1"};
|
||||
RegisterParallelDevice(context.get(), device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* function_name = "test_reduce_mul";
|
||||
@ -905,8 +936,8 @@ TEST(PARALLEL_DEVICE, TestFunction) {
|
||||
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
AssertScalarFloatEq(result_components[0].get(), 7. * 9.);
|
||||
AssertScalarFloatEq(result_components[1].get(), 7. * 9.);
|
||||
ExpectScalarEq<float>(result_components[0].get(), 7. * 9.);
|
||||
ExpectScalarEq<float>(result_components[1].get(), 7. * 9.);
|
||||
|
||||
std::string first_device = TFE_TensorHandleBackingDeviceName(
|
||||
result_components[0].get(), status.get());
|
||||
|
24
tensorflow/c/eager/tfe_cancellation_manager_internal.h
Normal file
24
tensorflow/c/eager/tfe_cancellation_manager_internal.h
Normal file
@ -0,0 +1,24 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
|
||||
struct TFE_CancellationManager {
|
||||
tensorflow::CancellationManager cancellation_manager;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
|
35
tensorflow/c/eager/tfe_context_internal.h
Normal file
35
tensorflow/c/eager/tfe_context_internal.h
Normal file
@ -0,0 +1,35 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
|
||||
// Wraps a pointer to a context implementation.
|
||||
//
|
||||
// WARNING: Since the underlying object could be ref-counted a user of this
|
||||
// interface cannot destruct the underlying context object. Instead, call
|
||||
// TFE_DeleteContext who calls Release() on the context pointer and deletes
|
||||
// the TFE_Context structure.
|
||||
typedef struct TFE_Context TFE_Context;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractContextInterface, TFE_Context);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
|
37
tensorflow/c/eager/tfe_executor_internal.h
Normal file
37
tensorflow/c/eager/tfe_executor_internal.h
Normal file
@ -0,0 +1,37 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||
|
||||
struct TFE_Executor {
|
||||
explicit TFE_Executor(bool async)
|
||||
: owned_executor(new tensorflow::EagerExecutor(async)) {}
|
||||
|
||||
explicit TFE_Executor(tensorflow::EagerExecutor* executor)
|
||||
: owned_executor(nullptr), unowned_executor(executor) {}
|
||||
|
||||
tensorflow::EagerExecutor* executor() {
|
||||
return owned_executor == nullptr ? unowned_executor : owned_executor.get();
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
|
||||
tensorflow::EagerExecutor* unowned_executor;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_
|
146
tensorflow/c/eager/tfe_monitoring_internal.h
Normal file
146
tensorflow/c/eager/tfe_monitoring_internal.h
Normal file
@ -0,0 +1,146 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
struct TFE_MonitoringCounterCell {
|
||||
tensorflow::monitoring::CounterCell cell;
|
||||
};
|
||||
|
||||
template <int NumLabels>
|
||||
struct TFE_MonitoringCounter {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringCounter(const char* name, const char* description,
|
||||
LabelDesc&&... label) {
|
||||
counter = absl::WrapUnique(tensorflow::monitoring::Counter<NumLabels>::New(
|
||||
name, description, label...));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Counter<NumLabels>> counter;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringIntGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<tensorflow::int64> cell;
|
||||
};
|
||||
struct TFE_MonitoringStringGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<tensorflow::string> cell;
|
||||
};
|
||||
struct TFE_MonitoringBoolGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<bool> cell;
|
||||
};
|
||||
|
||||
template <typename ValueType, int NumLabels>
|
||||
struct TFE_MonitoringGauge {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringGauge(const char* name, const char* description,
|
||||
LabelDesc&&... label) {
|
||||
gauge = absl::WrapUnique(
|
||||
tensorflow::monitoring::Gauge<ValueType, NumLabels>::New(
|
||||
name, description, label...));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Gauge<ValueType, NumLabels>> gauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge<tensorflow::int64, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge<tensorflow::int64, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge<tensorflow::int64, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge<tensorflow::string, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge<tensorflow::string, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge<tensorflow::string, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge<bool, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge<bool, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringBuckets {
|
||||
explicit TFE_MonitoringBuckets(
|
||||
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
|
||||
fn) {
|
||||
create_buckets = fn;
|
||||
}
|
||||
|
||||
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
|
||||
create_buckets;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringSamplerCell {
|
||||
tensorflow::monitoring::SamplerCell cell;
|
||||
};
|
||||
|
||||
template <int NumLabels>
|
||||
struct TFE_MonitoringSampler {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringSampler(
|
||||
const char* name,
|
||||
std::unique_ptr<tensorflow::monitoring::Buckets> buckets,
|
||||
const char* description, LabelDesc&&... label) {
|
||||
sampler = absl::WrapUnique(tensorflow::monitoring::Sampler<NumLabels>::New(
|
||||
{name, description, label...}, std::move(buckets)));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Sampler<NumLabels>> sampler;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_
|
39
tensorflow/c/eager/tfe_op_attrs_internal.h
Normal file
39
tensorflow/c/eager/tfe_op_attrs_internal.h
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
|
||||
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
|
||||
// that sometimes do not require serialization.
|
||||
typedef struct TFE_OpAttrs TFE_OpAttrs;
|
||||
|
||||
typedef struct TFE_Context TFE_Context;
|
||||
typedef struct TFE_Op TFE_Op;
|
||||
|
||||
namespace tensorflow {
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AttrBuilder, TFE_OpAttrs);
|
||||
|
||||
// Set an AttrValue on the op. Doesn't handle the list types.
|
||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||
const tensorflow::AttrValue& default_value,
|
||||
const char* attr_name, TF_Status* status);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
|
36
tensorflow/c/eager/tfe_op_internal.h
Normal file
36
tensorflow/c/eager/tfe_op_internal.h
Normal file
@ -0,0 +1,36 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
|
||||
// Wraps a pointer to an operation implementation.
|
||||
//
|
||||
// WARNING: Since the underlying object could be ref-counted a user of this
|
||||
// interface cannot destruct the underlying operation object. Instead, call
|
||||
// TFE_DeleteOp who calls Release() on the operation pointer and deletes
|
||||
// the TFE_Op structure.
|
||||
typedef struct TFE_Op TFE_Op;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface, TFE_Op);
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface*, TFE_Op*);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
|
30
tensorflow/c/eager/tfe_tensor_debug_info_internal.h
Normal file
30
tensorflow/c/eager/tfe_tensor_debug_info_internal.h
Normal file
@ -0,0 +1,30 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
struct TFE_TensorDebugInfo {
|
||||
explicit TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
|
||||
: dev_dims(dims) {}
|
||||
|
||||
// Fully-padded, minor-to-major.
|
||||
std::vector<tensorflow::int64> dev_dims;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_
|
38
tensorflow/c/eager/tfe_tensorhandle_internal.h
Normal file
38
tensorflow/c/eager/tfe_tensorhandle_internal.h
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
|
||||
// Wraps a pointer to a tensor handle implementation.
|
||||
//
|
||||
// WARNING: Since the underlying object could be ref-counted a user of this
|
||||
// interface cannot destruct the underlying handle object. Instead, call
|
||||
// TFE_DeleteTensorHandle who calls Release() on the handle pointer and deletes
|
||||
// the TFE_TensorHandle structure.
|
||||
typedef struct TFE_TensorHandle TFE_TensorHandle;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface,
|
||||
TFE_TensorHandle);
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface*,
|
||||
TFE_TensorHandle*);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
|
@ -85,17 +85,36 @@ class ModularFileSystemTest : public ::testing::TestWithParam<std::string> {
|
||||
const std::string test_name = tensorflow::str_util::StringReplace(
|
||||
::testing::UnitTest::GetInstance()->current_test_info()->name(), "/",
|
||||
"_", /*replace_all=*/true);
|
||||
if (!cloud_path_.empty()) {
|
||||
// We have to join path for non-local filesystem manually to make sure
|
||||
// that this test will run on Windows since `tensorflow::io::JoinPath`
|
||||
// behaves differently on Windows. `tmp_dir` should be something like
|
||||
// `path/to/tmp/dir/`. After joining path, we will have
|
||||
// /path/to/tmp/dir/tf_fs_rng_name/`
|
||||
root_dir_ = tensorflow::strings::StrCat(
|
||||
"/", tmp_dir_,
|
||||
tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name), "/");
|
||||
} else {
|
||||
root_dir_ = tensorflow::io::JoinPath(
|
||||
::testing::TempDir(),
|
||||
tmp_dir_,
|
||||
tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name));
|
||||
}
|
||||
if (!GetParam().empty()) {
|
||||
root_dir_ = tensorflow::strings::StrCat(GetParam(), "://", cloud_path_,
|
||||
root_dir_);
|
||||
}
|
||||
env_ = Env::Default();
|
||||
}
|
||||
|
||||
void SetUp() override {
|
||||
if (mkdir(root_dir_.c_str(), 0755) != 0) {
|
||||
int error_code = errno;
|
||||
GTEST_SKIP() << "Cannot create working directory: "
|
||||
<< tensorflow::IOError(root_dir_, error_code);
|
||||
FileSystem* fs = nullptr;
|
||||
Status s = env_->GetFileSystemForFile(root_dir_, &fs);
|
||||
if (fs == nullptr || !s.ok())
|
||||
GTEST_SKIP() << "No filesystem registered: " << s;
|
||||
|
||||
s = fs->CreateDir(root_dir_);
|
||||
if (!s.ok()) {
|
||||
GTEST_SKIP() << "Cannot create working directory: " << s;
|
||||
}
|
||||
}
|
||||
|
||||
@ -115,9 +134,10 @@ class ModularFileSystemTest : public ::testing::TestWithParam<std::string> {
|
||||
std::string GetURIForPath(StringPiece path) {
|
||||
const std::string translated_name =
|
||||
tensorflow::io::JoinPath(root_dir_, path);
|
||||
if (GetParam().empty()) return translated_name;
|
||||
|
||||
return tensorflow::strings::StrCat(GetParam(), "://", translated_name);
|
||||
// We have already checked `GetParam().empty()` in
|
||||
// `ModularFileSystemTest()`. root_dir_ should contain `GetParam() + "://"`
|
||||
// if it isn't empty.
|
||||
return translated_name;
|
||||
}
|
||||
|
||||
// Converts absolute paths to paths relative to root_dir_.
|
||||
@ -133,15 +153,28 @@ class ModularFileSystemTest : public ::testing::TestWithParam<std::string> {
|
||||
rng_val_ = distribution(gen);
|
||||
}
|
||||
|
||||
static void SetCloudPath(const std::string& cloud_path) {
|
||||
cloud_path_ = cloud_path;
|
||||
if (cloud_path_.back() == '/') cloud_path_.pop_back();
|
||||
}
|
||||
|
||||
static void SetTmpDir(const std::string& tmp_dir) {
|
||||
tmp_dir_ = tmp_dir.empty() ? ::testing::TempDir() : tmp_dir;
|
||||
}
|
||||
|
||||
protected:
|
||||
Env* env_;
|
||||
|
||||
private:
|
||||
std::string root_dir_;
|
||||
static int rng_val_;
|
||||
static std::string cloud_path_;
|
||||
static std::string tmp_dir_;
|
||||
};
|
||||
|
||||
int ModularFileSystemTest::rng_val_;
|
||||
std::string ModularFileSystemTest::cloud_path_;
|
||||
std::string ModularFileSystemTest::tmp_dir_;
|
||||
|
||||
// As some of the implementations might be missing, the tests should still pass
|
||||
// if the returned `Status` signals the unimplemented state.
|
||||
@ -1729,6 +1762,20 @@ static bool GetURIScheme(const std::string& scheme) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// This function is used for cloud filesystem
|
||||
// `S3` and `GCS` require the `root_dir_` to have bucket name
|
||||
// `HDFS` requires the `root_dir` to have namenode
|
||||
// `root_dir_ = scheme + "://" cloud_path_ + root_dir_`
|
||||
static bool SetCloudPath(const std::string& cloud_path_) {
|
||||
ModularFileSystemTest::SetCloudPath(cloud_path_);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool SetTmpDir(const std::string& tmp_dir_) {
|
||||
ModularFileSystemTest::SetTmpDir(tmp_dir_);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
@ -1741,7 +1788,12 @@ GTEST_API_ int main(int argc, char** argv) {
|
||||
tensorflow::Flag("dso", tensorflow::LoadDSO, "",
|
||||
"Path to shared object to load"),
|
||||
tensorflow::Flag("scheme", tensorflow::GetURIScheme, "",
|
||||
"URI scheme to test")};
|
||||
"URI scheme to test"),
|
||||
tensorflow::Flag("cloud_path", tensorflow::SetCloudPath, "",
|
||||
"Path for cloud filesystem (namenode for hdfs, "
|
||||
"bucketname for s3/gcs)"),
|
||||
tensorflow::Flag("tmp_dir", tensorflow::SetTmpDir, "",
|
||||
"Temporary directory to store test data.")};
|
||||
if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) {
|
||||
std::cout << tensorflow::Flags::Usage(argv[0], flag_list);
|
||||
return -1;
|
||||
|
66
tensorflow/c/experimental/saved_model/README.md
Normal file
66
tensorflow/c/experimental/saved_model/README.md
Normal file
@ -0,0 +1,66 @@
|
||||
# Tensorflow C SavedModel API
|
||||
|
||||
## Overview
|
||||
|
||||
These are the new experimental C SavedModel APIs for loading and running
|
||||
SavedModels in a TF2-idiomatic fashion. See
|
||||
[RFC 207](https://github.com/tensorflow/community/pull/207) for additional
|
||||
context.
|
||||
|
||||
The directory structure is as follows:
|
||||
|
||||
```none
|
||||
saved_model/
|
||||
|
||||
public/
|
||||
|
||||
internal/
|
||||
|
||||
core/
|
||||
|
||||
```
|
||||
|
||||
## saved_model/public
|
||||
|
||||
`saved_model/public` is intended to house *only the public headers* of the
|
||||
SavedModel C API.
|
||||
|
||||
These headers:
|
||||
|
||||
1. declare opaque C types (like `TF_SavedModel`),
|
||||
|
||||
2. declare the functions that operate on these types (like `TF_LoadSavedModel`).
|
||||
|
||||
Once they leave experimental, these APIs should be considered stable for use
|
||||
by external clients.
|
||||
|
||||
These headers are in a separate directory to make it obvious to clients which
|
||||
headers they should depend on, and which headers are implementation details.
|
||||
Separating these public headers by directory also allow future programmatic
|
||||
checks to ensure that TF public headers only `#include` other public TF headers.
|
||||
|
||||
## saved_model/internal
|
||||
|
||||
`saved_model/internal` is the "glue" between the C API and the internal C++
|
||||
implementation.
|
||||
|
||||
Its role is to:
|
||||
|
||||
1. implement the C API functions declared in `saved_model/public`
|
||||
|
||||
2. define the C API types declared in `saved_model/public`
|
||||
|
||||
The files fulfilling 1. are named `*.cc` (eg: `concrete_function.cc`), while
|
||||
the files fulfilling 2. are `*type.h` (eg: `concrete_function_type.h`).
|
||||
|
||||
The headers exposing the internal implementation of the opaque C types are only
|
||||
visible to other implementors of the C API. This is similar to how other
|
||||
TF C API implementations use `tf_status_internal.h` (to extract the underlying
|
||||
`tensorflow::Status`). All other targets in this directory are private.
|
||||
|
||||
## saved_model/core
|
||||
|
||||
`saved_model/core` contains pure C++ "Classes" underlying the C API types
|
||||
in `saved_model/public/`. These are implementation
|
||||
details subject to change, and have limited visibility to implementors only.
|
||||
This is the bottom-most layer of the `C++ -> C -> C++` sandwich.
|
85
tensorflow/c/experimental/saved_model/core/BUILD
Normal file
85
tensorflow/c/experimental/saved_model/core/BUILD
Normal file
@ -0,0 +1,85 @@
|
||||
# Experimental SavedModel C APIs for TensorFlow. See RFC
|
||||
# https://github.com/tensorflow/community/pull/207
|
||||
# Targets in this directory are pure C++ "Classes" underlying the C API types
|
||||
# under tf/c/experimental/saved_model/public/. They are subject to change and
|
||||
# have visibility limited to Tensorflow's implementation only.
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/c:__subpackages__",
|
||||
"//tensorflow/c/experimental/saved_model/internal:__pkg__",
|
||||
"//tensorflow/core:__subpackages__",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function",
|
||||
srcs = [
|
||||
"concrete_function.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"concrete_function.h",
|
||||
],
|
||||
deps = [
|
||||
":function_metadata",
|
||||
"//tensorflow/c/eager:operation_interface",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "function_metadata",
|
||||
hdrs = [
|
||||
"function_metadata.h",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "saved_model_api",
|
||||
hdrs = [
|
||||
"saved_model_api.h",
|
||||
],
|
||||
deps = [
|
||||
":concrete_function",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_saved_model_impl",
|
||||
srcs = [
|
||||
"tf_saved_model_impl.cc",
|
||||
],
|
||||
hdrs = ["tf_saved_model_impl.h"],
|
||||
deps = [
|
||||
":concrete_function",
|
||||
":saved_model_api",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "pywrap_required_hdrs",
|
||||
textual_hdrs = [
|
||||
"concrete_function.h",
|
||||
"function_metadata.h",
|
||||
"saved_model_api.h",
|
||||
],
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "mobile_srcs_only_runtime",
|
||||
srcs = [
|
||||
"concrete_function.cc",
|
||||
"concrete_function.h",
|
||||
"function_metadata.h",
|
||||
"saved_model_api.h",
|
||||
"tf_saved_model_impl.cc",
|
||||
"tf_saved_model_impl.h",
|
||||
],
|
||||
visibility = ["//tensorflow/core:__pkg__"],
|
||||
)
|
@ -0,0 +1,32 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const std::vector<tensorflow::AbstractTensorHandleInterface*>&
|
||||
ConcreteFunction::GetCaptures() const {
|
||||
return captures_;
|
||||
}
|
||||
|
||||
const FunctionMetadata& ConcreteFunction::GetFunctionMetadata() const {
|
||||
return metadata_;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,55 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Note that ConcreteFunctions's lifetimes are effectively bound
|
||||
// to the SavedModel they are loaded from, since they retain pointers
|
||||
// to the TensorHandles owned by the SavedModel, and the FunctionDef
|
||||
// of the SavedModel.
|
||||
// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock
|
||||
// TFRT integration with TF Serving. Do not add more virtual implementations of
|
||||
// this class. Eventually we want to remove this virtual base class indirection
|
||||
// and have only a single implementation.
|
||||
class ConcreteFunction {
|
||||
public:
|
||||
virtual ~ConcreteFunction() = 0;
|
||||
|
||||
// This method returns the "Call" Op used to execute the function.
|
||||
virtual AbstractOperationInterface* GetCallOp() = 0;
|
||||
|
||||
const std::vector<tensorflow::AbstractTensorHandleInterface*>& GetCaptures()
|
||||
const;
|
||||
const FunctionMetadata& GetFunctionMetadata() const;
|
||||
|
||||
private:
|
||||
FunctionMetadata metadata_;
|
||||
std::vector<tensorflow::AbstractTensorHandleInterface*> captures_;
|
||||
FunctionDef* function_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
|
@ -0,0 +1,27 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class FunctionMetadata {
|
||||
// TODO(bmzhao): Fill in with fields as necessary
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_
|
55
tensorflow/c/experimental/saved_model/core/saved_model_api.h
Normal file
55
tensorflow/c/experimental/saved_model/core/saved_model_api.h
Normal file
@ -0,0 +1,55 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock
|
||||
// TFRT integration with TF Serving. Do not add more virtual implementations of
|
||||
// this class. Eventually we want to remove this virtual base class indirection
|
||||
// and have only a single implementation.
|
||||
class SavedModelAPI {
|
||||
public:
|
||||
// Retrieve a function from the TF2 SavedModel, using the "path" to a function
|
||||
// in a TF2 savedmodel.
|
||||
// Note: `function` is a double pointer, so that implementations are
|
||||
// able to return a pointer to an internal member.
|
||||
virtual Status GetFunction(const std::string& function_path,
|
||||
ConcreteFunction** function) = 0;
|
||||
|
||||
// Retrieve a function from a SavedModel, using the key of the
|
||||
// SignatureDef map:
|
||||
// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
|
||||
virtual Status GetSignatureDefFunction(const std::string& signature_def_key,
|
||||
ConcreteFunction** function) = 0;
|
||||
|
||||
virtual std::vector<ConcreteFunction*> ListFunctions() = 0;
|
||||
|
||||
virtual ~SavedModelAPI() = default;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_
|
@ -0,0 +1,60 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h"
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status TFSavedModelAPIImpl::GetFunction(const std::string& function_path,
|
||||
ConcreteFunction** function) {
|
||||
// TODO(bmzhao): Add support for retrieving a function.
|
||||
return errors::Unimplemented(
|
||||
"Retrieving functions is unimplemented currently");
|
||||
}
|
||||
|
||||
Status TFSavedModelAPIImpl::GetSignatureDefFunction(
|
||||
const std::string& signature_def_key, ConcreteFunction** function) {
|
||||
// TODO(bmzhao): Add support for retrieving a signaturedef function.
|
||||
return errors::Unimplemented(
|
||||
"Retrieving functions is unimplemented currently");
|
||||
}
|
||||
|
||||
std::vector<ConcreteFunction*> TFSavedModelAPIImpl::ListFunctions() {
|
||||
std::vector<ConcreteFunction*> result;
|
||||
result.reserve(functions_.size());
|
||||
for (ConcreteFunction& function : functions_) {
|
||||
result.push_back(&function);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Status TFSavedModelAPIImpl::Load(
|
||||
const std::string& directory,
|
||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||
TFSavedModelAPIImpl* out) {
|
||||
// TODO(bmzhao): Add support for loading a TFSavedModelImpl.
|
||||
return errors::Unimplemented(
|
||||
"TFSavedModelAPIImpl loading is unimplemented currently");
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,55 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TFSavedModelAPIImpl : public SavedModelAPI {
|
||||
public:
|
||||
TFSavedModelAPIImpl() = default;
|
||||
|
||||
Status GetFunction(const std::string& function_path,
|
||||
ConcreteFunction** function) override;
|
||||
|
||||
Status GetSignatureDefFunction(const std::string& signature_def_key,
|
||||
ConcreteFunction** function) override;
|
||||
|
||||
static Status Load(
|
||||
const std::string& directory,
|
||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||
TFSavedModelAPIImpl* out);
|
||||
|
||||
std::vector<ConcreteFunction*> ListFunctions() override;
|
||||
|
||||
~TFSavedModelAPIImpl() override = default;
|
||||
|
||||
private:
|
||||
std::vector<ConcreteFunction> functions_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
|
212
tensorflow/c/experimental/saved_model/internal/BUILD
Normal file
212
tensorflow/c/experimental/saved_model/internal/BUILD
Normal file
@ -0,0 +1,212 @@
|
||||
# Experimental Implementation of SavedModel C APIs for TensorFlow. See RFC
|
||||
# https://github.com/tensorflow/community/pull/207
|
||||
# External clients should not worry about this directory; all contents are implementation details.
|
||||
# Code in this directory is intended to form the glue between the C API and the internal C++
|
||||
# implementation by
|
||||
# 1. mapping C API calls onto correponding methods of C++ objects
|
||||
# 2. mapping opaque C types onto C++ classes
|
||||
|
||||
# Note(bmzhao): The *.cc files in this directory form the direct implementation of the
|
||||
# C API functions exposed in tf/c/experimental/saved_model/public/.
|
||||
|
||||
# Note(bmzhao): All *type.h files in this directory are the internal definitions of
|
||||
# the opaque C types. These headers should only be visible to internal tensorflow
|
||||
# implementors.
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
"tf_copts",
|
||||
)
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function",
|
||||
srcs = [
|
||||
"concrete_function.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//tensorflow/c/experimental/saved_model/public:concrete_function.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":concrete_function_type",
|
||||
":function_metadata",
|
||||
":function_metadata_type",
|
||||
":tensorhandle_list",
|
||||
":tensorhandle_list_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_internal",
|
||||
"//tensorflow/c/eager:tfe_op_internal",
|
||||
"//tensorflow/c/experimental/saved_model/core:concrete_function",
|
||||
"//tensorflow/c/experimental/saved_model/core:function_metadata",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function_list",
|
||||
srcs = [
|
||||
"concrete_function_list.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//tensorflow/c/experimental/saved_model/public:concrete_function_list.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":concrete_function",
|
||||
":concrete_function_list_type",
|
||||
":concrete_function_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:concrete_function",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function_list_type",
|
||||
hdrs = [
|
||||
"concrete_function_list_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/experimental/saved_model/core:concrete_function",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function_type",
|
||||
hdrs = [
|
||||
"concrete_function_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:concrete_function",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "function_metadata",
|
||||
srcs = [
|
||||
"function_metadata.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//tensorflow/c/experimental/saved_model/public:function_metadata.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":function_metadata_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:function_metadata",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "function_metadata_type",
|
||||
hdrs = [
|
||||
"function_metadata_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:function_metadata",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "saved_model_api",
|
||||
srcs = [
|
||||
"saved_model_api.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//tensorflow/c/experimental/saved_model/public:saved_model_api.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":concrete_function",
|
||||
":concrete_function_list",
|
||||
":concrete_function_list_type",
|
||||
":concrete_function_type",
|
||||
":saved_model_api_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_status_internal",
|
||||
"//tensorflow/c/eager:tfe_context_internal",
|
||||
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "saved_model_api_type",
|
||||
hdrs = [
|
||||
"saved_model_api_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorhandle_list",
|
||||
srcs = [
|
||||
"tensorhandle_list.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//tensorflow/c/experimental/saved_model/public:tensorhandle_list.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":tensorhandle_list_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorhandle_list_type",
|
||||
hdrs = [
|
||||
"tensorhandle_list_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "saved_model_api_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"saved_model_api_test.cc",
|
||||
],
|
||||
data = [
|
||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
@ -0,0 +1,41 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
|
||||
return tensorflow::wrap(const_cast<tensorflow::FunctionMetadata*>(
|
||||
&tensorflow::unwrap(func)->GetFunctionMetadata()));
|
||||
}
|
||||
|
||||
const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures(
|
||||
TF_ConcreteFunction* func) {
|
||||
return tensorflow::wrap(&tensorflow::unwrap(func)->GetCaptures());
|
||||
}
|
||||
|
||||
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) {
|
||||
return tensorflow::wrap(tensorflow::unwrap(func)->GetCallOp());
|
||||
}
|
||||
|
||||
} // end extern "C"
|
@ -0,0 +1,37 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
size_t TF_ConcreteFunctionListNumOutputs(TF_ConcreteFunctionList* list) {
|
||||
return list->list.size();
|
||||
}
|
||||
|
||||
TF_ConcreteFunction* TF_ConcreteFunctionListGet(TF_ConcreteFunctionList* list,
|
||||
int i) {
|
||||
return tensorflow::wrap(list->list[i]);
|
||||
}
|
||||
|
||||
void TF_DeleteConcreteFunctionList(TF_ConcreteFunctionList* list) {
|
||||
delete list;
|
||||
}
|
||||
|
||||
} // end extern "C"
|
@ -0,0 +1,30 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
|
||||
// Internal structures used by the SavedModel C API. These are likely to change
|
||||
// and should not be depended on.
|
||||
|
||||
struct TF_ConcreteFunctionList {
|
||||
std::vector<tensorflow::ConcreteFunction*> list;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
|
@ -0,0 +1,36 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
|
||||
// Internal structures used by the SavedModel C API. These are likely to change
|
||||
// and should not be depended on.
|
||||
|
||||
// It doesn't make sense to wrap tensorflow::ConcreteFunction* in a separate
|
||||
// struct, since the lifetime of the struct and the raw pointer it wraps would
|
||||
// be different. Therefore TF_ConcreteFunction* = tensorflow::ConcreteFunction*.
|
||||
typedef struct TF_ConcreteFunction TF_ConcreteFunction;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ConcreteFunction, TF_ConcreteFunction)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%{
|
||||
#include "tensorflow/lite/experimental/kernels/hashtable_ops.h"
|
||||
%}
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
|
||||
%include "tensorflow/lite/experimental/kernels/hashtable_ops.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
|
||||
|
||||
// TODO(bmzhao): Add getter functions here as necessary.
|
@ -0,0 +1,30 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
|
||||
typedef struct TF_FunctionMetadata TF_FunctionMetadata;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::FunctionMetadata, TF_FunctionMetadata)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
|
@ -0,0 +1,97 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
|
||||
TF_Status* status) {
|
||||
std::string saved_model_dir(dirname);
|
||||
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> result =
|
||||
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, absl::nullopt,
|
||||
&status->status);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TF_SavedModel{std::move(result)};
|
||||
}
|
||||
|
||||
TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
|
||||
const char* const* tags, int tags_len,
|
||||
TF_Status* status) {
|
||||
std::string saved_model_dir(dirname);
|
||||
|
||||
std::unordered_set<std::string> tagset;
|
||||
for (int i = 0; i < tags_len; ++i) {
|
||||
tagset.insert(std::string(tags[i]));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> result =
|
||||
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, std::move(tagset),
|
||||
&status->status);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TF_SavedModel{std::move(result)};
|
||||
}
|
||||
|
||||
void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; }
|
||||
|
||||
TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
|
||||
const char* function_path,
|
||||
TF_Status* status) {
|
||||
tensorflow::ConcreteFunction* result = nullptr;
|
||||
tensorflow::Status get_function_status =
|
||||
model->saved_model->GetFunction(function_path, &result);
|
||||
status->status.Update(get_function_status);
|
||||
if (!get_function_status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return tensorflow::wrap(result);
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
|
||||
TF_SavedModel* model, const char* signature_def_key, TF_Status* status) {
|
||||
tensorflow::ConcreteFunction* result = nullptr;
|
||||
tensorflow::Status get_function_status =
|
||||
model->saved_model->GetSignatureDefFunction(signature_def_key, &result);
|
||||
status->status.Update(get_function_status);
|
||||
if (!get_function_status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return tensorflow::wrap(result);
|
||||
}
|
||||
|
||||
TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
|
||||
return new TF_ConcreteFunctionList{model->saved_model->ListFunctions()};
|
||||
}
|
||||
|
||||
} // end extern "C"
|
@ -0,0 +1,109 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kTestData[] = "cc/saved_model/testdata";
|
||||
const char* kServeTag[] = {"serve"};
|
||||
|
||||
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
|
||||
return tensorflow::io::JoinPath(tensorflow::testing::TensorFlowSrcRoot(),
|
||||
kTestData, saved_model_dir);
|
||||
}
|
||||
|
||||
// This value parameterized test allows us to test both TFRT
|
||||
// and non TFRT runtimes.
|
||||
// https://github.com/google/googletest/blob/dcc92d0ab6c4ce022162a23566d44f673251eee4/googletest/docs/advanced.md#value-parameterized-tests
|
||||
class CSavedModelAPITest : public ::testing::TestWithParam<bool> {};
|
||||
|
||||
TEST_P(CSavedModelAPITest, LoadsSavedModelWithTags) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TF_DeleteStatus(status);
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||
|
||||
TF_SavedModel* saved_model =
|
||||
TF_LoadSavedModelWithTags(model_dir.c_str(), ctx, kServeTag, 1, status);
|
||||
|
||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||
// like loading a ConcreteFunction. This test at least checks that the
|
||||
// C API builds and can be minimally run.
|
||||
EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED);
|
||||
|
||||
TF_DeleteSavedModel(saved_model);
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST_P(CSavedModelAPITest, LoadsSavedModel) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TF_DeleteStatus(status);
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||
|
||||
TF_SavedModel* saved_model =
|
||||
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
|
||||
|
||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||
// like loading a ConcreteFunction. This test at least checks that the
|
||||
// C API builds and can be minimally run.
|
||||
EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED);
|
||||
|
||||
TF_DeleteSavedModel(saved_model);
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest,
|
||||
::testing::Bool());
|
||||
|
||||
} // namespace
|
@ -0,0 +1,30 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
|
||||
// Internal structures used by the SavedModel C API. These are likely to change
|
||||
// and should not be depended on.
|
||||
|
||||
struct TF_SavedModel {
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> saved_model;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
|
@ -0,0 +1,36 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
size_t TF_TensorHandleListSize(const TF_TensorHandleList* list) {
|
||||
return tensorflow::unwrap(list)->size();
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TF_TensorHandleListGet(const TF_TensorHandleList* list,
|
||||
int i) {
|
||||
return tensorflow::wrap((*tensorflow::unwrap(list))[i]);
|
||||
}
|
||||
|
||||
|
||||
} // end extern "C"
|
@ -0,0 +1,37 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
|
||||
// Internal structures used by the SavedModel C API. These are likely to
|
||||
// change and should not be depended on.
|
||||
|
||||
typedef struct TF_TensorHandleList TF_TensorHandleList;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(
|
||||
std::vector<tensorflow::AbstractTensorHandleInterface*>,
|
||||
TF_TensorHandleList)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
|
70
tensorflow/c/experimental/saved_model/public/BUILD
Normal file
70
tensorflow/c/experimental/saved_model/public/BUILD
Normal file
@ -0,0 +1,70 @@
|
||||
# Experimental SavedModel C APIs for TensorFlow.
|
||||
# See RFC https://github.com/tensorflow/community/pull/207
|
||||
# All headers are on the public surface of Tensorflow's C API.
|
||||
# Once moved out of experimental, these will be stable.
|
||||
# The idea behind a separate public/ directory is to make apparent
|
||||
# which headers are part of TF's public interface (and which headers)
|
||||
# are implementation details. This structure allows us to also perform future
|
||||
# programmatic checks that all "public" headers only include other "public"
|
||||
# headers.
|
||||
|
||||
package(
|
||||
# This is intentionally public
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
# TODO(bmzhao): Remove these exports_files and rules, swap with cc_public_library instead.
|
||||
# cc_public_library would allows us to separate the header dep graph from header+srcs dep graph.
|
||||
exports_files(
|
||||
[
|
||||
"concrete_function.h",
|
||||
"concrete_function_list.h",
|
||||
"function_metadata.h",
|
||||
"saved_model_api.h",
|
||||
"tensorhandle_list.h",
|
||||
],
|
||||
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
|
||||
)
|
||||
|
||||
# The purpose of this header is to provide insulation against
|
||||
# future changes where we rename/move a public header, without
|
||||
# forcing all clients to change their "#includes".
|
||||
cc_library(
|
||||
name = "c_saved_model_api",
|
||||
hdrs = ["c_saved_model_api.h"],
|
||||
deps = [
|
||||
":concrete_function",
|
||||
":concrete_function_list",
|
||||
":function_metadata",
|
||||
":saved_model_api",
|
||||
":tensorhandle_list",
|
||||
],
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "concrete_function",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:concrete_function",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "concrete_function_list",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:concrete_function_list",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "function_metadata",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:function_metadata",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "saved_model_api",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "tensorhandle_list",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:tensorhandle_list",
|
||||
)
|
@ -0,0 +1,27 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
|
@ -0,0 +1,50 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type that corresponds to a Function loaded from a SavedModel.
|
||||
// TODO(bmzhao): Work together w/srbs@ to make sure this composes w/the
|
||||
// C++ Unified Eager/Graph API's AbstractFunction
|
||||
typedef struct TF_ConcreteFunction TF_ConcreteFunction;
|
||||
|
||||
// Returns FunctionMetadata associated with `func`. Metadata's lifetime is
|
||||
// bound to `func`, which is bound to the TF_SavedModel it was loaded from.
|
||||
TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
|
||||
TF_ConcreteFunction* func);
|
||||
|
||||
// Returns a list of TensorHandles implicitly captured by this function.
|
||||
TF_CAPI_EXPORT extern const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures(
|
||||
TF_ConcreteFunction* func);
|
||||
|
||||
// Returns a TFE_Op suitable for executing this function.
|
||||
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp(
|
||||
TF_ConcreteFunction* func);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
|
@ -0,0 +1,47 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type that is acts like a list of TF_ConcreteFunction pointers.
|
||||
typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList;
|
||||
|
||||
// Returns the size of `list`.
|
||||
TF_CAPI_EXPORT extern size_t TF_ConcreteFunctionListSize(
|
||||
TF_ConcreteFunctionList* list);
|
||||
|
||||
// Returns the `i`th TF_ConcreteFunction in the list.
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_ConcreteFunctionListGet(
|
||||
TF_ConcreteFunctionList* list, int i);
|
||||
|
||||
// Deletes `list`.
|
||||
TF_CAPI_EXPORT extern void TF_DeleteConcreteFunctionList(
|
||||
TF_ConcreteFunctionList* list);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
@ -0,0 +1,35 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type used to store any metadata associated with a function.
|
||||
typedef struct TF_FunctionMetadata TF_FunctionMetadata;
|
||||
|
||||
// TODO(bmzhao): Add getters for fields as we determine what metadata
|
||||
// we want to expose.
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_
|
108
tensorflow/c/experimental/saved_model/public/saved_model_api.h
Normal file
108
tensorflow/c/experimental/saved_model/public/saved_model_api.h
Normal file
@ -0,0 +1,108 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type representing a Tensorflow "SavedModel"
|
||||
// (https://www.tensorflow.org/guide/saved_model) that we always pass by pointer
|
||||
// to achieve ABI stability.
|
||||
typedef struct TF_SavedModel TF_SavedModel;
|
||||
|
||||
// Load a SavedModel from `dirname`. We expect the SavedModel to contain a
|
||||
// single Metagraph (as for those exported from TF2's `tf.saved_model.save`).
|
||||
//
|
||||
// Params:
|
||||
// dirname - A directory filepath that the SavedModel is at.
|
||||
// ctx - A TFE_Context containing optional load/TF runtime options.
|
||||
// `ctx` must outlive the returned TF_SavedModel pointer.
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr. Otherwise, returns a newly created
|
||||
// TF_SavedModel instance. It must be deleted by calling TF_DeleteSavedModel.
|
||||
TF_CAPI_EXPORT extern TF_SavedModel* TF_LoadSavedModel(const char* dirname,
|
||||
TFE_Context* ctx,
|
||||
TF_Status* status);
|
||||
|
||||
// Load a SavedModel from `dirname`.
|
||||
//
|
||||
// Params:
|
||||
// dirname - A directory filepath that the SavedModel is at.
|
||||
// ctx - A TFE_Context containing optional load/TF runtime options.
|
||||
// `ctx` must outlive the returned TF_SavedModel pointer.
|
||||
// tags - char* array of SavedModel tags. We will load the metagraph matching
|
||||
// the tags.
|
||||
// tags_len - number of elements in the `tags` array.
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr. Otherwise, returns a newly created
|
||||
// TF_SavedModel instance. It must be deleted by calling TF_DeleteSavedModel.
|
||||
TF_CAPI_EXPORT extern TF_SavedModel* TF_LoadSavedModelWithTags(
|
||||
const char* dirname, TFE_Context* ctx, const char* const* tags,
|
||||
int tags_len, TF_Status* status);
|
||||
|
||||
// Deletes a TF_SavedModel, and frees any resources owned by it.
|
||||
TF_CAPI_EXPORT extern void TF_DeleteSavedModel(TF_SavedModel* model);
|
||||
|
||||
// Retrieve a function from the TF2 SavedModel via function path.
|
||||
//
|
||||
// Params:
|
||||
// model - The TF2 SavedModel to load a function from.
|
||||
// function_path - A string containing the path from the root saved python
|
||||
// object to a tf.function method.
|
||||
// TODO(bmzhao): Add a detailed example of this with a
|
||||
// python tf.module before moving this out of experimental.
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr. Otherwise, returns a
|
||||
// TF_ConcreteFunction instance. The lifetime of this instance is
|
||||
// "conceptually" bound to `model`. Once `model` is deleted, all
|
||||
// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(
|
||||
TF_SavedModel* model, const char* function_path, TF_Status* status);
|
||||
|
||||
// Retrieve a function from the TF SavedModel via a SignatureDef key.
|
||||
//
|
||||
// Params:
|
||||
// model - The SavedModel to load a function from.
|
||||
// signature_def_key - The string key of the SignatureDef map of a SavedModel:
|
||||
// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr. Otherwise, returns a
|
||||
// TF_ConcreteFunction instance. Once `model` is deleted, all
|
||||
// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
|
||||
TF_SavedModel* model, const char* signature_def_key, TF_Status* status);
|
||||
|
||||
// Returns a list of all ConcreteFunctions stored in this SavedModel.
|
||||
// The lifetime of the returned list is bound to `model`.
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunctionList* TF_ListSavedModelFunctions(
|
||||
TF_SavedModel* model);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
|
@ -0,0 +1,43 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type that is acts like a list of TF_ConcreteFunction pointers.
|
||||
typedef struct TF_TensorHandleList TF_TensorHandleList;
|
||||
|
||||
// Returns the size of `list`.
|
||||
TF_CAPI_EXPORT extern size_t TF_TensorHandleListSize(
|
||||
const TF_TensorHandleList* list);
|
||||
|
||||
// Returns the `i`th TFE_TensorHandle in the list.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TF_TensorHandleListGet(
|
||||
const TF_TensorHandleList* list, int i);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
|
@ -156,6 +156,7 @@ cc_library(
|
||||
":array_grad",
|
||||
":data_flow_grad",
|
||||
":image_grad",
|
||||
":manip_grad",
|
||||
":math_grad",
|
||||
":nn_grad",
|
||||
],
|
||||
@ -177,10 +178,11 @@ cc_library_with_android_deps(
|
||||
name = "ops",
|
||||
srcs = ["framework/ops.cc"],
|
||||
hdrs = ["framework/ops.h"],
|
||||
android_deps = ["//tensorflow/core:android_tensorflow_lib"],
|
||||
android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:ops",
|
||||
@ -195,7 +197,7 @@ cc_library_with_android_deps(
|
||||
"framework/scope_internal.h",
|
||||
],
|
||||
hdrs = ["framework/scope.h"],
|
||||
android_deps = ["//tensorflow/core:android_tensorflow_lib"],
|
||||
android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
|
||||
common_deps = [
|
||||
":ops",
|
||||
],
|
||||
@ -235,7 +237,7 @@ cc_library_with_android_deps(
|
||||
name = "client_session",
|
||||
srcs = ["client/client_session.cc"],
|
||||
hdrs = ["client/client_session.h"],
|
||||
android_deps = ["//tensorflow/core:android_tensorflow_lib"],
|
||||
android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
|
||||
common_deps = [
|
||||
":ops",
|
||||
":scope",
|
||||
@ -273,7 +275,7 @@ cc_library_with_android_deps(
|
||||
srcs = ["ops/const_op.cc"],
|
||||
hdrs = ["ops/const_op.h"],
|
||||
android_deps = [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
"//tensorflow/core:portable_tensorflow_lib",
|
||||
],
|
||||
common_deps = [
|
||||
":ops",
|
||||
@ -302,7 +304,7 @@ cc_library_with_android_deps(
|
||||
srcs = ["ops/while_loop.cc"],
|
||||
hdrs = ["ops/while_loop.h"],
|
||||
android_deps = [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
"//tensorflow/core:portable_tensorflow_lib",
|
||||
],
|
||||
common_deps = [
|
||||
":cc_ops",
|
||||
@ -494,6 +496,32 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "manip_grad",
|
||||
srcs = ["gradients/manip_grad.cc"],
|
||||
deps = [
|
||||
":cc_ops",
|
||||
":grad_op_registry",
|
||||
":gradients",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "gradients_manip_grad_test",
|
||||
srcs = ["gradients/manip_grad_test.cc"],
|
||||
deps = [
|
||||
":array_ops",
|
||||
":cc_ops",
|
||||
":gradient_checker",
|
||||
":manip_grad",
|
||||
":testutil",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
# Generates separate libraries for array_ops and math_ops to reduce the dependency count of targets that depend on only these
|
||||
tf_gen_op_wrappers_cc(
|
||||
name = "math_ops",
|
||||
|
78
tensorflow/cc/experimental/base/public/BUILD
Normal file
78
tensorflow/cc/experimental/base/public/BUILD
Normal file
@ -0,0 +1,78 @@
|
||||
# Experimental C++ APIs for TensorFlow.
|
||||
# New TF C++ APIs under the tensorflow::cc namespace aim to guarantee ABI stability.
|
||||
# Users are expected to compile against public c++ headers, and link against
|
||||
# libtensorflow (https://www.tensorflow.org/install/lang_c).
|
||||
# We aim to achieve ABI stability in new C++ APIs by only using types
|
||||
# on the API surface that:
|
||||
# 1. Have a header-only implementation
|
||||
# 2. Are std:: types
|
||||
# 3. Wrap an opaque C type
|
||||
|
||||
package(
|
||||
# This is intentionally public
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "runtime",
|
||||
hdrs = [
|
||||
"runtime.h",
|
||||
],
|
||||
deps = [
|
||||
":status",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "runtime_builder",
|
||||
hdrs = [
|
||||
"runtime_builder.h",
|
||||
],
|
||||
deps = [
|
||||
":runtime",
|
||||
":status",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "status",
|
||||
hdrs = [
|
||||
"status.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_status",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor",
|
||||
hdrs = [
|
||||
"tensor.h",
|
||||
],
|
||||
deps = [
|
||||
":status",
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/c:tf_tensor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorhandle",
|
||||
hdrs = [
|
||||
"tensorhandle.h",
|
||||
],
|
||||
deps = [
|
||||
":runtime",
|
||||
":status",
|
||||
":tensor",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
],
|
||||
)
|
71
tensorflow/cc/experimental/base/public/runtime.h
Normal file
71
tensorflow/cc/experimental/base/public/runtime.h
Normal file
@ -0,0 +1,71 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// Runtime represents an opaque instance of a Tensorflow runtime, with its own
|
||||
// resources, threadpools, etc. Clients are expected to construct a Runtime
|
||||
// object through tensorflow::cc::RuntimeBuilder::Build, after setting any
|
||||
// relevant configuration options. Many Tensorflow functions take a reference to
|
||||
// the runtime as an argument (eg: tensorflow::cc::SavedModelAPI::Load), and
|
||||
// may have different implementations depending on the runtime. For many of
|
||||
// these Runtime-attached objects (such as tensorflow::cc::TensorHandle), the
|
||||
// Runtime must outlive these objects.
|
||||
class Runtime {
|
||||
public:
|
||||
// Runtime is movable, but not copyable.
|
||||
Runtime(Runtime&&) = default;
|
||||
Runtime& operator=(Runtime&&) = default;
|
||||
|
||||
private:
|
||||
friend class RuntimeBuilder;
|
||||
friend class SavedModelAPI;
|
||||
friend class TensorHandle;
|
||||
|
||||
// Wraps a TFE_Context. Takes ownership of ctx.
|
||||
explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {}
|
||||
|
||||
// Deletes the currently wrapped TFE_Context, swaps it with ctx,
|
||||
// and takes ownership of ctx.
|
||||
void Reset(TFE_Context* ctx) { ctx_.reset(ctx); }
|
||||
|
||||
// Returns the TFE_Context that this object wraps. This object
|
||||
// retains ownership of the pointer.
|
||||
TFE_Context* GetTFEContext() const { return ctx_.get(); }
|
||||
|
||||
// Runtime is not copyable
|
||||
Runtime(const Runtime&) = delete;
|
||||
Runtime& operator=(const Runtime&) = delete;
|
||||
|
||||
struct TFEContextDeleter {
|
||||
void operator()(TFE_Context* p) const { TFE_DeleteContext(p); }
|
||||
};
|
||||
std::unique_ptr<TFE_Context, TFEContextDeleter> ctx_;
|
||||
};
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
|
86
tensorflow/cc/experimental/base/public/runtime_builder.h
Normal file
86
tensorflow/cc/experimental/base/public/runtime_builder.h
Normal file
@ -0,0 +1,86 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime.h"
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
|
||||
// Use this to set configuration options, like threadpool size, etc.
|
||||
class RuntimeBuilder {
|
||||
public:
|
||||
RuntimeBuilder() : options_(TFE_NewContextOptions()) {}
|
||||
|
||||
// If `use_tfrt` is true, we will use the new Tensorflow Runtime
|
||||
// (https://blog.tensorflow.org/2020/04/tfrt-new-tensorflow-runtime.html) as
|
||||
// our runtime implementation.
|
||||
RuntimeBuilder& SetUseTFRT(bool use_tfrt);
|
||||
|
||||
// Build a Tensorflow Runtime.
|
||||
//
|
||||
// Params:
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr. Otherwise, returns a
|
||||
// unique_ptr<tensorflow::cc::Runtime>.
|
||||
std::unique_ptr<Runtime> Build(Status* status);
|
||||
|
||||
// RuntimeBuilder is movable, but not copyable.
|
||||
RuntimeBuilder(RuntimeBuilder&&) = default;
|
||||
RuntimeBuilder& operator=(RuntimeBuilder&&) = default;
|
||||
|
||||
private:
|
||||
// RuntimeBuilder is not copyable
|
||||
RuntimeBuilder(const RuntimeBuilder&) = delete;
|
||||
RuntimeBuilder& operator=(const RuntimeBuilder&) = delete;
|
||||
|
||||
struct TFEContextOptionsDeleter {
|
||||
void operator()(TFE_ContextOptions* p) const {
|
||||
TFE_DeleteContextOptions(p);
|
||||
}
|
||||
};
|
||||
std::unique_ptr<TFE_ContextOptions, TFEContextOptionsDeleter> options_;
|
||||
};
|
||||
|
||||
inline RuntimeBuilder& RuntimeBuilder::SetUseTFRT(bool use_tfrt) {
|
||||
TFE_ContextOptionsSetTfrt(options_.get(), use_tfrt);
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<Runtime> RuntimeBuilder::Build(Status* status) {
|
||||
TFE_Context* result = TFE_NewContext(options_.get(), status->GetTFStatus());
|
||||
if (!status->ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
// We can't use std::make_unique here because of its interaction with a
|
||||
// private constructor: https://abseil.io/tips/134
|
||||
return std::unique_ptr<Runtime>(new Runtime(result));
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
|
96
tensorflow/cc/experimental/base/public/status.h
Normal file
96
tensorflow/cc/experimental/base/public/status.h
Normal file
@ -0,0 +1,96 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// Status is a wrapper around an error code and an optional error message.
|
||||
// The set of error codes are defined here:
|
||||
// https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/c/tf_status.h#L39-L60
|
||||
// Many Tensorflow APIs return a Status, or take a Status as an out parameter.
|
||||
// Clients should check for status.ok() after calling these APIs, and either
|
||||
// handle or propagate the error appropriately.
|
||||
// TODO(bmzhao): Add a detailed code example before moving out of experimental.
|
||||
class Status {
|
||||
public:
|
||||
// Create a success status
|
||||
Status() : status_(TF_NewStatus()) {}
|
||||
|
||||
// Return the status code
|
||||
TF_Code code() const;
|
||||
|
||||
// Returns the error message in Status.
|
||||
std::string message() const;
|
||||
|
||||
// Returns the error message in Status.
|
||||
bool ok() const;
|
||||
|
||||
// Record <code, msg> in Status. Any previous information is lost.
|
||||
// A common use is to clear a status: SetStatus(TF_OK, "");
|
||||
void SetStatus(TF_Code code, const std::string& msg);
|
||||
|
||||
// Status is movable, but not copyable.
|
||||
Status(Status&&) = default;
|
||||
Status& operator=(Status&&) = default;
|
||||
|
||||
private:
|
||||
friend class RuntimeBuilder;
|
||||
friend class Runtime;
|
||||
friend class SavedModelAPI;
|
||||
friend class TensorHandle;
|
||||
|
||||
// Wraps a TF_Status*, and takes ownership of it.
|
||||
explicit Status(TF_Status* status) : status_(status) {}
|
||||
|
||||
// Status is not copyable
|
||||
Status(const Status&) = delete;
|
||||
Status& operator=(const Status&) = delete;
|
||||
|
||||
// Returns the TF_Status that this object wraps. This object
|
||||
// retains ownership of the pointer.
|
||||
TF_Status* GetTFStatus() const { return status_.get(); }
|
||||
|
||||
struct TFStatusDeleter {
|
||||
void operator()(TF_Status* p) const { TF_DeleteStatus(p); }
|
||||
};
|
||||
std::unique_ptr<TF_Status, TFStatusDeleter> status_;
|
||||
};
|
||||
|
||||
inline TF_Code Status::code() const { return TF_GetCode(status_.get()); }
|
||||
|
||||
inline std::string Status::message() const {
|
||||
return std::string(TF_Message(status_.get()));
|
||||
}
|
||||
|
||||
inline bool Status::ok() const { return code() == TF_OK; }
|
||||
|
||||
inline void Status::SetStatus(TF_Code code, const std::string& msg) {
|
||||
TF_SetStatus(status_.get(), code, msg.c_str());
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
|
175
tensorflow/cc/experimental/base/public/tensor.h
Normal file
175
tensorflow/cc/experimental/base/public/tensor.h
Normal file
@ -0,0 +1,175 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// Tensor represents an n-dimensional array of values.
|
||||
class Tensor {
|
||||
public:
|
||||
using DeleterCallback = std::function<void(void*, size_t)>;
|
||||
|
||||
// Constructs a Tensor from user provided buffer.
|
||||
//
|
||||
// Params:
|
||||
// dtype - The dtype of the tensor's data.
|
||||
// shape - A shape vector, where each element corresponds to the size of
|
||||
// the tensor's corresponding dimension.
|
||||
// data - Pointer to a buffer of memory to construct a Tensor out of.
|
||||
// len - The length (in bytes) of `data`
|
||||
// deleter - A std::function to be called when the Tensor no longer needs the
|
||||
// memory in `data`. This can be used to free `data`, or
|
||||
// perhaps decrement a refcount associated with `data`, etc.
|
||||
// status - Set to OK on success and an error on failure.
|
||||
// Returns:
|
||||
// If an error occurred, status->ok() will be false, and the returned
|
||||
// Tensor must not be used.
|
||||
// TODO(bmzhao): Add Runtime as an argument to this function so we can swap to
|
||||
// a TFRT backed tensor.
|
||||
// TODO(bmzhao): Add benchmarks on overhead for this function; we can
|
||||
// consider using int64_t* + length rather than vector.
|
||||
static Tensor FromBuffer(TF_DataType dtype, const std::vector<int64_t>& shape,
|
||||
void* data, size_t len, DeleterCallback deleter,
|
||||
Status* status);
|
||||
|
||||
// TODO(bmzhao): In the case we construct a tensor from non-owned memory,
|
||||
// we should offer a way to deep copy the tensor into a new tensor, which
|
||||
// owns the underlying memory. This could be a .deepcopy()/clone() method.
|
||||
|
||||
// TODO(bmzhao): In the future, we want to relax the non-copyability
|
||||
// constraint. To do so, we can add a C API function that acts like
|
||||
// CopyFrom:
|
||||
// https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/core/framework/tensor.h#L301-L311
|
||||
|
||||
// Tensor is movable, but not copyable
|
||||
Tensor(Tensor&&) = default;
|
||||
Tensor& operator=(Tensor&&) = default;
|
||||
|
||||
// Returns the number of dimensions in the tensor. Can be -1, which represents
|
||||
// unknown rank.
|
||||
int dims() const;
|
||||
|
||||
// Returns the number of elements in in demension `d`.
|
||||
// REQUIRES: `0 <= d < dims()`
|
||||
int64_t dim_size(int d) const;
|
||||
|
||||
// Returns a pointer to the underlying data buffer.
|
||||
void* data() const;
|
||||
|
||||
// Returns the data type of the tensor.
|
||||
TF_DataType dtype() const;
|
||||
|
||||
// Returns the number of elements in the tensor. For a tensor with a partially
|
||||
// defined shape, -1 means not fully defined.
|
||||
int64_t num_elements() const;
|
||||
|
||||
// Returns the size of the underlying data in bytes.
|
||||
size_t num_bytes() const;
|
||||
|
||||
private:
|
||||
friend class TensorHandle;
|
||||
friend class Runtime;
|
||||
|
||||
// Wraps a TF_Tensor. Takes ownership of handle.
|
||||
explicit Tensor(TF_Tensor* tensor) : tensor_(tensor) {}
|
||||
|
||||
// Tensor is not copyable
|
||||
Tensor(const Tensor&) = delete;
|
||||
Tensor& operator=(const Tensor&) = delete;
|
||||
|
||||
// Returns the underlying TF_Tensor that this object wraps.
|
||||
// This object retains ownership of the pointer.
|
||||
TF_Tensor* GetTFTensor() const { return tensor_.get(); }
|
||||
|
||||
struct DeleterStruct {
|
||||
std::function<void(void*, size_t)> deleter;
|
||||
};
|
||||
|
||||
static void DeleterFunction(void* memory, size_t len, void* deleter_struct) {
|
||||
DeleterStruct* deleter = reinterpret_cast<DeleterStruct*>(deleter_struct);
|
||||
deleter->deleter(memory, len);
|
||||
delete deleter;
|
||||
}
|
||||
|
||||
struct TFTensorDeleter {
|
||||
void operator()(TF_Tensor* p) const { TF_DeleteTensor(p); }
|
||||
};
|
||||
std::unique_ptr<TF_Tensor, TFTensorDeleter> tensor_;
|
||||
};
|
||||
|
||||
inline void* Tensor::data() const { return TF_TensorData(tensor_.get()); }
|
||||
|
||||
inline int Tensor::dims() const { return TF_NumDims(tensor_.get()); }
|
||||
|
||||
inline int64_t Tensor::dim_size(int d) const {
|
||||
return TF_Dim(tensor_.get(), d);
|
||||
}
|
||||
|
||||
inline TF_DataType Tensor::dtype() const {
|
||||
return TF_TensorType(tensor_.get());
|
||||
}
|
||||
|
||||
inline int64_t Tensor::num_elements() const {
|
||||
return TF_TensorElementCount(tensor_.get());
|
||||
}
|
||||
|
||||
inline size_t Tensor::num_bytes() const {
|
||||
return TF_TensorByteSize(tensor_.get());
|
||||
}
|
||||
|
||||
inline Tensor Tensor::FromBuffer(TF_DataType dtype,
|
||||
const std::vector<int64_t>& shape, void* data,
|
||||
size_t len, DeleterCallback deleter,
|
||||
Status* status) {
|
||||
// Credit to apassos@ for this technique:
|
||||
// Despite the fact that our API takes a std::function deleter, we are able
|
||||
// to maintain ABI stability because:
|
||||
// 1. Only a function pointer is sent across the C API (&DeleterFunction)
|
||||
// 2. DeleterFunction is defined in the same build artifact that constructed
|
||||
// the std::function (so there isn't confusion about std::function ABI).
|
||||
// Note that 2. is satisifed by the fact that this is a header-only API, where
|
||||
// the function implementations are inline.
|
||||
|
||||
DeleterStruct* deleter_struct = new DeleterStruct{deleter};
|
||||
TF_Tensor* tensor = TF_NewTensor(dtype, shape.data(), shape.size(), data, len,
|
||||
&DeleterFunction, deleter_struct);
|
||||
if (tensor == nullptr) {
|
||||
status->SetStatus(TF_INVALID_ARGUMENT,
|
||||
"Failed to create tensor for input buffer");
|
||||
return Tensor(nullptr);
|
||||
}
|
||||
return Tensor(tensor);
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
|
98
tensorflow/cc/experimental/base/public/tensorhandle.h
Normal file
98
tensorflow/cc/experimental/base/public/tensorhandle.h
Normal file
@ -0,0 +1,98 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime.h"
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// An opaque representation of a tensor computed/managed by the Tensorflow
|
||||
// runtime (tensorflow:cc::Runtime). Unlike a tensor, a Tensorhandle may refer
|
||||
// to tensors placed in memory of different devices or remote address spaces.
|
||||
// Note that tensorflow::cc::Runtime MUST outlive all TensorHandles created
|
||||
// from it.
|
||||
class TensorHandle {
|
||||
public:
|
||||
// Unwraps a Tensor from the given TensorHandle. If an error occurred,
|
||||
// status->ok() will be false, and the returned Tensor must not be used.
|
||||
Tensor Resolve(Status* status);
|
||||
|
||||
// Constructs a TensorHandle from a Tensor. If an error occurred,
|
||||
// status->ok() will be false, and the returned TensorHandle must not be used.
|
||||
static TensorHandle FromTensor(const Tensor& tensor, const Runtime& runtime,
|
||||
Status* status);
|
||||
|
||||
// TensorHandle is movable, and not copyable
|
||||
TensorHandle(TensorHandle&&) = default;
|
||||
TensorHandle& operator=(TensorHandle&&) = default;
|
||||
|
||||
private:
|
||||
// Wraps a TFE_TensorHandle. Takes ownership of handle.
|
||||
explicit TensorHandle(TFE_TensorHandle* handle) : handle_(handle) {}
|
||||
|
||||
// TensorHandle is not copyable
|
||||
TensorHandle(const TensorHandle&) = delete;
|
||||
TensorHandle& operator=(const TensorHandle&) = delete;
|
||||
|
||||
// Returns the underlying TFE_TensorHandle that this object wraps.
|
||||
// This object retains ownership of the pointer.
|
||||
TFE_TensorHandle* GetTFETensorHandle() const { return handle_.get(); }
|
||||
|
||||
// Deletes the currently wrapped TFE_TensorHandle, and swaps it with handle,
|
||||
// and takes ownership of handle.
|
||||
void Reset(TFE_TensorHandle* handle) { handle_.reset(handle); }
|
||||
|
||||
struct TFETensorHandleDeleter {
|
||||
void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); }
|
||||
};
|
||||
std::unique_ptr<TFE_TensorHandle, TFETensorHandleDeleter> handle_;
|
||||
};
|
||||
|
||||
inline Tensor TensorHandle::Resolve(Status* status) {
|
||||
TF_Tensor* tensor =
|
||||
TFE_TensorHandleResolve(handle_.get(), status->GetTFStatus());
|
||||
if (!status->ok()) {
|
||||
return Tensor(nullptr);
|
||||
}
|
||||
return Tensor(tensor);
|
||||
}
|
||||
|
||||
inline TensorHandle TensorHandle::FromTensor(const Tensor& tensor,
|
||||
const Runtime& runtime,
|
||||
Status* status) {
|
||||
TFE_TensorHandle* tensor_handle = TFE_NewTensorHandleFromTensor(
|
||||
runtime.GetTFEContext(), tensor.GetTFTensor(), status->GetTFStatus());
|
||||
if (!status->ok()) {
|
||||
return TensorHandle(nullptr);
|
||||
}
|
||||
return TensorHandle(tensor_handle);
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
|
50
tensorflow/cc/experimental/base/tests/BUILD
Normal file
50
tensorflow/cc/experimental/base/tests/BUILD
Normal file
@ -0,0 +1,50 @@
|
||||
# Tests for the C++ header-only base types.
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_types_test_util",
|
||||
testonly = True,
|
||||
hdrs = ["tensor_types_test_util.h"],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_datatype",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tensor_test",
|
||||
srcs = [
|
||||
"tensor_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":tensor_types_test_util",
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/cc/experimental/base/public:status",
|
||||
"//tensorflow/cc/experimental/base/public:tensor",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tensorhandle_test",
|
||||
srcs = [
|
||||
"tensorhandle_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":tensor_types_test_util",
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/cc/experimental/base/public:runtime",
|
||||
"//tensorflow/cc/experimental/base/public:runtime_builder",
|
||||
"//tensorflow/cc/experimental/base/public:status",
|
||||
"//tensorflow/cc/experimental/base/public:tensor",
|
||||
"//tensorflow/cc/experimental/base/public:tensorhandle",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
163
tensorflow/cc/experimental/base/tests/tensor_test.cc
Normal file
163
tensorflow/cc/experimental/base/tests/tensor_test.cc
Normal file
@ -0,0 +1,163 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using tensorflow::experimental::cc::Status;
|
||||
using tensorflow::experimental::cc::Tensor;
|
||||
|
||||
using SimpleTypes = ::testing::Types<
|
||||
tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
|
||||
tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
|
||||
tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
|
||||
|
||||
template <typename T>
|
||||
class ConstructScalarTensorTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(ConstructScalarTensorTest, SimpleTypes);
|
||||
|
||||
// This test constructs a scalar tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
typename TypeParam::type value = 42;
|
||||
Tensor tensor = Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
|
||||
/*data=*/&value,
|
||||
/*len=*/sizeof(value),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 0);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
|
||||
EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), 1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Construct1DTensorTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(Construct1DTensorTest, SimpleTypes);
|
||||
|
||||
// This test constructs a 1D tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
// Shape is Rank 1 vector.
|
||||
std::vector<int64_t> shape;
|
||||
shape.push_back(value.size());
|
||||
|
||||
Tensor tensor = Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 1);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
EXPECT_EQ(tensor_view[2], 0);
|
||||
EXPECT_EQ(tensor_view[3], 1);
|
||||
EXPECT_EQ(tensor_view[4], 4);
|
||||
EXPECT_EQ(tensor_view[5], 29);
|
||||
|
||||
EXPECT_EQ(tensor.num_bytes(),
|
||||
value.size() * sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), value.size());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Construct2DTensorTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(Construct2DTensorTest, SimpleTypes);
|
||||
|
||||
// This test constructs a 2D tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
// Shape is Rank 2 vector with shape 2 x 3.
|
||||
std::vector<int64_t> shape({2, 3});
|
||||
|
||||
Tensor tensor = Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 2);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
EXPECT_EQ(tensor_view[2], 0);
|
||||
EXPECT_EQ(tensor_view[3], 1);
|
||||
EXPECT_EQ(tensor_view[4], 4);
|
||||
EXPECT_EQ(tensor_view[5], 29);
|
||||
|
||||
EXPECT_EQ(tensor.num_bytes(),
|
||||
value.size() * sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), value.size());
|
||||
}
|
||||
|
||||
TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
|
||||
bool done = false;
|
||||
Status status;
|
||||
std::vector<int32_t> data_vector({12, 14, 20, 18, 39, 42, 100});
|
||||
{
|
||||
// data_vector is a rank 1 tensor.
|
||||
std::vector<int64_t> shape;
|
||||
shape.push_back(data_vector.size());
|
||||
|
||||
Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
|
||||
done = true;
|
||||
};
|
||||
|
||||
Tensor tensor =
|
||||
Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
|
||||
/*data=*/data_vector.data(),
|
||||
/*len=*/data_vector.size() * sizeof(int32_t),
|
||||
/*deleter=*/callback, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
}
|
||||
// At this point, tensor has been destroyed, and the deleter callback should
|
||||
// have run.
|
||||
EXPECT_TRUE(done);
|
||||
}
|
||||
|
||||
} // namespace
|
@ -0,0 +1,76 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Each of the following struct types have two members: a kDType that
|
||||
// corresponds to a TF_Datatype enum value, and a typedef "type"
|
||||
// of its corresponding C++ type. These types allow us to write Dtype-agnostic
|
||||
// tests via GoogleTest's TypedTests:
|
||||
// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
|
||||
struct FloatType {
|
||||
using type = float;
|
||||
static constexpr TF_DataType kDType = TF_FLOAT;
|
||||
};
|
||||
|
||||
struct DoubleType {
|
||||
using type = double;
|
||||
static constexpr TF_DataType kDType = TF_DOUBLE;
|
||||
};
|
||||
|
||||
struct Int32Type {
|
||||
using type = int32_t;
|
||||
static constexpr TF_DataType kDType = TF_INT32;
|
||||
};
|
||||
|
||||
struct UINT8Type {
|
||||
using type = uint8_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT8;
|
||||
};
|
||||
|
||||
struct INT8Type {
|
||||
using type = int8_t;
|
||||
static constexpr TF_DataType kDType = TF_INT8;
|
||||
};
|
||||
|
||||
struct INT64Type {
|
||||
using type = int64_t;
|
||||
static constexpr TF_DataType kDType = TF_INT64;
|
||||
};
|
||||
|
||||
struct UINT16Type {
|
||||
using type = uint16_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT16;
|
||||
};
|
||||
|
||||
struct UINT32Type {
|
||||
using type = uint32_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT32;
|
||||
};
|
||||
|
||||
struct UINT64Type {
|
||||
using type = uint64_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT64;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
|
184
tensorflow/cc/experimental/base/tests/tensorhandle_test.cc
Normal file
184
tensorflow/cc/experimental/base/tests/tensorhandle_test.cc
Normal file
@ -0,0 +1,184 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/experimental/base/public/tensorhandle.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
|
||||
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
||||
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
using tensorflow::experimental::cc::Runtime;
|
||||
using tensorflow::experimental::cc::RuntimeBuilder;
|
||||
using tensorflow::experimental::cc::Status;
|
||||
using tensorflow::experimental::cc::Tensor;
|
||||
using tensorflow::experimental::cc::TensorHandle;
|
||||
|
||||
using SimpleTypes = ::testing::Types<
|
||||
tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
|
||||
tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
|
||||
tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
|
||||
|
||||
template <typename T>
|
||||
class ConstructScalarTensorHandleTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes);
|
||||
|
||||
// This test constructs a scalar tensor for each of the types in "SimpleTypes",
|
||||
// then wraps it in a TensorHandle. We then unwrap it back into a Tensor, and
|
||||
// verify the expected dims, dtype, value, num bytes, and num elements.
|
||||
TYPED_TEST(ConstructScalarTensorHandleTest,
|
||||
ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
RuntimeBuilder runtime_builder;
|
||||
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
typename TypeParam::type value = 42;
|
||||
Tensor original_tensor =
|
||||
Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
|
||||
/*data=*/&value,
|
||||
/*len=*/sizeof(value),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TensorHandle handle =
|
||||
TensorHandle::FromTensor(original_tensor, *runtime, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
Tensor tensor = handle.Resolve(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 0);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
|
||||
EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), 1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Construct1DTensorHandleTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes);
|
||||
|
||||
// This test constructs a 1D tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct1DTensorHandleTest,
|
||||
ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
RuntimeBuilder runtime_builder;
|
||||
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
// Shape is Rank 1 vector.
|
||||
std::vector<int64_t> shape;
|
||||
shape.push_back(value.size());
|
||||
|
||||
Tensor original_tensor = Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TensorHandle handle =
|
||||
TensorHandle::FromTensor(original_tensor, *runtime, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
Tensor tensor = handle.Resolve(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 1);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
EXPECT_EQ(tensor_view[2], 0);
|
||||
EXPECT_EQ(tensor_view[3], 1);
|
||||
EXPECT_EQ(tensor_view[4], 4);
|
||||
EXPECT_EQ(tensor_view[5], 29);
|
||||
|
||||
EXPECT_EQ(tensor.num_bytes(),
|
||||
value.size() * sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), value.size());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Construct2DTensorHandleTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes);
|
||||
|
||||
// This test constructs a 2D tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct2DTensorHandleTest,
|
||||
ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
RuntimeBuilder runtime_builder;
|
||||
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
// Shape is Rank 2 vector with shape 2 x 3.
|
||||
std::vector<int64_t> shape({2, 3});
|
||||
|
||||
Tensor original_tensor = Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TensorHandle handle =
|
||||
TensorHandle::FromTensor(original_tensor, *runtime, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
Tensor tensor = handle.Resolve(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 2);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
EXPECT_EQ(tensor_view[2], 0);
|
||||
EXPECT_EQ(tensor_view[3], 1);
|
||||
EXPECT_EQ(tensor_view[4], 4);
|
||||
EXPECT_EQ(tensor_view[5], 29);
|
||||
|
||||
EXPECT_EQ(tensor.num_bytes(),
|
||||
value.size() * sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), value.size());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -13,19 +13,20 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/framework/gradients.h"
|
||||
|
||||
#include <deque>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||
#include "tensorflow/cc/framework/gradients.h"
|
||||
#include "tensorflow/cc/framework/while_gradients.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/while_context.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
|
||||
|
40
tensorflow/cc/gradients/manip_grad.cc
Normal file
40
tensorflow/cc/gradients/manip_grad.cc
Normal file
@ -0,0 +1,40 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||
#include "tensorflow/cc/framework/gradients.h"
|
||||
#include "tensorflow/cc/ops/manip_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
namespace {
|
||||
|
||||
Status RollGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
auto shift = op.input(1);
|
||||
auto axis = op.input(2);
|
||||
auto grad_op = Roll(scope, grad_inputs[0], Neg(scope, shift), axis);
|
||||
grad_outputs->push_back(grad_op);
|
||||
grad_outputs->push_back(NoGradient());
|
||||
grad_outputs->push_back(NoGradient());
|
||||
return scope.status();
|
||||
}
|
||||
REGISTER_GRADIENT_OP("Roll", RollGrad);
|
||||
|
||||
} // namespace
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user