Merge branch 'master' into nhasabni/mkl_tanh
This commit is contained in:
commit
1af63efd38
58
.bazelrc
58
.bazelrc
@ -143,6 +143,11 @@ build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
|
||||
build:mkl --define=build_with_mkl_dnn_v1_only=true
|
||||
build:mkl -c opt
|
||||
|
||||
# config to build OneDNN backend with a user specified threadpool.
|
||||
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
|
||||
build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
|
||||
build:mkl_threadpool --define=build_with_mkldnn_threadpool=true
|
||||
build:mkl_threadpool -c opt
|
||||
# This config refers to building with CUDA available. It does not necessarily
|
||||
# mean that we build CUDA op kernels.
|
||||
build:using_cuda --define=using_cuda=true
|
||||
@ -235,10 +240,15 @@ build:c++17 --cxxopt=-std=c++1z
|
||||
build:c++17 --cxxopt=-stdlib=libc++
|
||||
build:c++1z --config=c++17
|
||||
|
||||
# Enable using platform specific build settings
|
||||
# Enable using platform specific build settings, except when cross-compiling for
|
||||
# mobile platforms.
|
||||
build --enable_platform_specific_config
|
||||
build:android --noenable_platform_specific_config
|
||||
build:ios --noenable_platform_specific_config
|
||||
|
||||
# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
|
||||
build:android --copt=-w
|
||||
build:ios --copt=-w
|
||||
build:linux --copt=-w
|
||||
build:macos --copt=-w
|
||||
build:windows --copt=/w
|
||||
@ -258,6 +268,10 @@ build:macos --define=INCLUDEDIR=$(PREFIX)/include
|
||||
# TF_SYSTEM_LIBS do not work on windows.
|
||||
|
||||
# By default, build TF in C++ 14 mode.
|
||||
build:android --cxxopt=-std=c++14
|
||||
build:android --host_cxxopt=-std=c++14
|
||||
build:ios --cxxopt=-std=c++14
|
||||
build:ios --host_cxxopt=-std=c++14
|
||||
build:linux --cxxopt=-std=c++14
|
||||
build:linux --host_cxxopt=-std=c++14
|
||||
build:macos --cxxopt=-std=c++14
|
||||
@ -372,32 +386,32 @@ build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1
|
||||
test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
|
||||
|
||||
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc --platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc --platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
|
||||
test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
|
||||
|
||||
build:rbe_linux_cuda_nvcc_base --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda_nvcc_base --define=using_cuda_nvcc=true
|
||||
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
|
||||
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
|
||||
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
|
||||
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
|
||||
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
|
||||
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
|
||||
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
|
||||
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
|
||||
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
|
||||
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
|
||||
|
||||
build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
|
@ -1 +1 @@
|
||||
3.0.0
|
||||
3.1.0
|
||||
|
87
.github/bot_config.yml
vendored
Normal file
87
.github/bot_config.yml
vendored
Normal file
@ -0,0 +1,87 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
#
|
||||
# THIS IS A GENERATED DOCKERFILE.
|
||||
#
|
||||
# This file was assembled from multiple pieces, whose use is documented
|
||||
# throughout. Please refer to the TensorFlow dockerfiles documentation
|
||||
# for more information.
|
||||
|
||||
# A list of assignees
|
||||
assignees:
|
||||
- amahendrakar
|
||||
- ravikyram
|
||||
- Saduf2019
|
||||
# A list of assignees for compiler folder
|
||||
compiler_assignees:
|
||||
- joker-eph
|
||||
# Cuda Comment
|
||||
cuda_comment: >
|
||||
From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries:
|
||||
* For TF-GPU - See point 1
|
||||
* For TF-CPU - See point 2
|
||||
-----------------------------------------------------------------------------------------------
|
||||
|
||||
**1. Installing **TensorFlow-GPU** (TF) prebuilt binaries**
|
||||
|
||||
|
||||
Make sure you are using compatible TF and CUDA versions.
|
||||
Please refer following TF version and CUDA version compatibility table.
|
||||
|
||||
| TF | CUDA |
|
||||
|
||||
| :-------------: | :-------------: |
|
||||
|
||||
| 2.1.0 - 2.2.0 | 10.1 |
|
||||
|
||||
| 1.13.1 - 2.0 | 10.0 |
|
||||
|
||||
| 1.5.0 - 1.12.0 | 9.0 |
|
||||
|
||||
* If you have above configuration and using _**Windows**_ platform -
|
||||
* Try adding the CUDA, CUPTI, and cuDNN installation directories to the %PATH% environment variable.
|
||||
* Refer [windows setup guide](https://www.tensorflow.org/install/gpu#windows_setup).
|
||||
* If you have above configuration and using _**Ubuntu/Linux**_ platform -
|
||||
* Try adding the CUDA, CUPTI, and cuDNN installation directories to the $LD_LIBRARY_PATH environment variable.
|
||||
* Refer [linux setup guide](https://www.tensorflow.org/install/gpu#linux_setup).
|
||||
* If error still persists then, apparently your CPU model does not support AVX instruction sets.
|
||||
* Refer [hardware requirements](https://www.tensorflow.org/install/pip#hardware-requirements).
|
||||
|
||||
-----------------------------------------------------------------------------------------------
|
||||
|
||||
**2. Installing **TensorFlow** (TF) CPU prebuilt binaries**
|
||||
|
||||
|
||||
*TensorFlow release binaries version 1.6 and higher are prebuilt with AVX instruction sets.*
|
||||
|
||||
|
||||
Therefore on any CPU that does not have these instruction sets, either CPU or GPU version of TF will fail to load.
|
||||
|
||||
Apparently, your CPU model does not support AVX instruction sets. You can still use TensorFlow with the alternatives given below:
|
||||
|
||||
* Try Google Colab to use TensorFlow.
|
||||
* The easiest way to use TF will be to switch to [google colab](https://colab.sandbox.google.com/notebooks/welcome.ipynb#recent=true). You get pre-installed latest stable TF version. Also you can use ```pip install``` to install any other preferred TF version.
|
||||
* It has an added advantage since you can you easily switch to different hardware accelerators (cpu, gpu, tpu) as per the task.
|
||||
* All you need is a good internet connection and you are all set.
|
||||
* Try to build TF from sources by changing CPU optimization flags.
|
||||
|
||||
*Please let us know if this helps.*
|
||||
|
||||
windows_comment: >
|
||||
From the stack trace it looks like you are hitting windows path length limit.
|
||||
* Try to disable path length limit on Windows 10.
|
||||
* Refer [disable path length limit instructions guide.](https://mspoweruser.com/ntfs-260-character-windows-10/)
|
||||
|
||||
Please let us know if this helps.
|
@ -1,7 +1,11 @@
|
||||
# TensorFlow Code of Conduct
|
||||
|
||||
In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation.
|
||||
|
||||
In the interest of fostering an open and welcoming environment, we as
|
||||
contributors and maintainers pledge to make participation in our project and our
|
||||
community a harassment-free experience for everyone, regardless of age, body
|
||||
size, disability, ethnicity, gender identity and expression, level of
|
||||
experience, nationality, personal appearance, race, religion, or sexual identity
|
||||
and orientation.
|
||||
|
||||
## Our Standards
|
||||
|
||||
|
@ -4,18 +4,23 @@ https://stackoverflow.com/questions/tagged/tensorflow
|
||||
|
||||
If you open a GitHub issue, here is our policy:
|
||||
|
||||
1. It must be a bug, a feature request, or a significant problem with documentation (for small docs fixes please send a PR instead).
|
||||
1. It must be a bug, a feature request, or a significant problem with the
|
||||
documentation (for small docs fixes please send a PR instead).
|
||||
2. The form below must be filled out.
|
||||
3. It shouldn't be a TensorBoard issue. Those go [here](https://github.com/tensorflow/tensorboard/issues).
|
||||
3. It shouldn't be a TensorBoard issue. Those go
|
||||
[here](https://github.com/tensorflow/tensorboard/issues).
|
||||
|
||||
**Here's why we have that policy**: TensorFlow developers respond to issues. We want to focus on work that benefits the whole community, e.g., fixing bugs and adding features. Support only helps individuals. GitHub also notifies thousands of people when issues are filed. We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow.
|
||||
|
||||
------------------------
|
||||
|
||||
### System information
|
||||
- **Have I written custom code (as opposed to using a stock example script provided in TensorFlow)**:
|
||||
|
||||
- **Have I written custom code (as opposed to using a stock example script
|
||||
provided in TensorFlow)**:
|
||||
- **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**:
|
||||
- **Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device**:
|
||||
- **Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue
|
||||
happens on a mobile device**:
|
||||
- **TensorFlow installed from (source or binary)**:
|
||||
- **TensorFlow version (use command below)**:
|
||||
- **Python version**:
|
||||
|
@ -104,7 +104,7 @@ open-source software development:
|
||||
### Official Builds
|
||||
|
||||
Build Type | Status | Artifacts
|
||||
------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
|
||||
------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
|
||||
**Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Linux XLA** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
|
||||
@ -112,8 +112,8 @@ Build Type | Status
|
||||
**Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Android** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
|
||||
**Raspberry Pi 0 and 1** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.html) [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv6l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
|
||||
**Raspberry Pi 2 and 3** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.html) [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv7l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
|
||||
**Raspberry Pi 0 and 1** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
|
||||
**Raspberry Pi 2 and 3** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
|
||||
|
||||
### Community Supported Builds
|
||||
|
||||
@ -142,6 +142,7 @@ Build Type | Status
|
||||
* [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2)
|
||||
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
|
||||
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
|
||||
* [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp)
|
||||
* [TensorFlow Blog](https://blog.tensorflow.org)
|
||||
* [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
|
||||
* [TensorFlow Twitter](https://twitter.com/tensorflow)
|
||||
|
193
RELEASE.md
193
RELEASE.md
@ -1,3 +1,41 @@
|
||||
# Release 2.3.0
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
* `tf.image.extract_glimpse` has been updated to correctly process the case
|
||||
where `centered=False` and `normalized=False`. This is a breaking change as
|
||||
the output is different from (incorrect) previous versions. Note this
|
||||
breaking change only impacts `tf.image.extract_glimpse` and
|
||||
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
|
||||
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
|
||||
exsiting C++ kernel `ExtractGlimpse` does not change as well, so saved
|
||||
models will not be impacted.
|
||||
|
||||
# Release 2.1.1
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
|
||||
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
|
||||
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
|
||||
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
|
||||
* Fixes a versioning bug which causes Keras layers from TF 1.x to be used instead of those from TF 2.x
|
||||
|
||||
# Release 2.0.2
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
|
||||
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
|
||||
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
|
||||
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
|
||||
|
||||
# Release 1.15.3
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
|
||||
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
|
||||
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
|
||||
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
|
||||
|
||||
# Release 2.2.0
|
||||
|
||||
TensorFlow 2.2 discontinues support for Python 2, [previously announced](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ) as following [Python 2's EOL on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update).
|
||||
@ -52,89 +90,150 @@ Coinciding with this change, new releases of [TensorFlow's Docker images](https:
|
||||
* The current TensorFlow release now **requires** [gast](https://pypi.org/project/gast/) version 0.3.3.
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
* `tf.data`:
|
||||
* Removed `autotune_algorithm` from experimental optimization options.
|
||||
* TF Core:
|
||||
* `tf.constant` always creates CPU tensors irrespective of the current device context.
|
||||
* Eager `TensorHandles` maintain a list of mirrors for any copies to local or remote devices. This avoids any redundant copies due to op execution.
|
||||
* For `tf.Tensor` & `tf.Variable`, `.experimental_ref()` is no longer experimental and is available as simply `.ref()`.
|
||||
* `pfor/vectorized_map`: Added support for vectorizing 56 more ops. Vectorizing `tf.cond` is also supported now.
|
||||
* Set as much partial shape as we can infer statically within the gradient impl of the gather op.
|
||||
* Gradient of `tf.while_loop` emits `StatelessWhile` op if `cond` and body functions are stateless. This allows multiple gradients while ops to run in parallel under distribution strategy.
|
||||
* Speed up `GradientTape` in eager mode by auto-generating list of op inputs/outputs which are unused and hence not cached for gradient functions.
|
||||
* `tf.constant` always creates CPU tensors irrespective of the current
|
||||
device context.
|
||||
* Eager `TensorHandles` maintain a list of mirrors for any copies to local
|
||||
or remote devices. This avoids any redundant copies due to op execution.
|
||||
* For `tf.Tensor` & `tf.Variable`, `.experimental_ref()` is no longer
|
||||
experimental and is available as simply `.ref()`.
|
||||
* `pfor/vectorized_map`: Added support for vectorizing 56 more ops.
|
||||
Vectorizing `tf.cond` is also supported now.
|
||||
* Set as much partial shape as we can infer statically within the gradient
|
||||
impl of the gather op.
|
||||
* Gradient of `tf.while_loop` emits `StatelessWhile` op if `cond` and body
|
||||
functions are stateless. This allows multiple gradients while ops to run
|
||||
in parallel under distribution strategy.
|
||||
* Speed up `GradientTape` in eager mode by auto-generating list of op
|
||||
inputs/outputs which are unused and hence not cached for gradient
|
||||
functions.
|
||||
* Support `back_prop=False` in `while_v2` but mark it as deprecated.
|
||||
* Improve error message when attempting to use `None` in data-dependent control flow.
|
||||
* Improve error message when attempting to use `None` in data-dependent
|
||||
control flow.
|
||||
* Add `RaggedTensor.numpy()`.
|
||||
* Update `RaggedTensor.__getitem__` to preserve uniform dimensions & allow indexing into uniform dimensions.
|
||||
* Update `tf.expand_dims` to always insert the new dimension as a non-ragged dimension.
|
||||
* Update `tf.embedding_lookup` to use `partition_strategy` and `max_norm` when `ids` is ragged.
|
||||
* Update `RaggedTensor.__getitem__` to preserve uniform dimensions & allow
|
||||
indexing into uniform dimensions.
|
||||
* Update `tf.expand_dims` to always insert the new dimension as a
|
||||
non-ragged dimension.
|
||||
* Update `tf.embedding_lookup` to use `partition_strategy` and `max_norm`
|
||||
when `ids` is ragged.
|
||||
* Allow `batch_dims==rank(indices)` in `tf.gather`.
|
||||
* Add support for bfloat16 in `tf.print`.
|
||||
* `tf.distribute`:
|
||||
* Support `embedding_column` with variable-length input features for `MultiWorkerMirroredStrategy`.
|
||||
* Support `embedding_column` with variable-length input features for
|
||||
`MultiWorkerMirroredStrategy`.
|
||||
* `tf.keras`:
|
||||
* Added `experimental_aggregate_gradients` argument to `tf.keras.optimizer.Optimizer.apply_gradients`. This allows custom gradient aggregation and processing aggregated gradients in custom training loop.
|
||||
* Added `experimental_aggregate_gradients` argument to
|
||||
`tf.keras.optimizer.Optimizer.apply_gradients`. This allows custom
|
||||
gradient aggregation and processing aggregated gradients in custom
|
||||
training loop.
|
||||
* Allow `pathlib.Path` paths for loading models via Keras API.
|
||||
* `tf.function`/AutoGraph:
|
||||
* AutoGraph is now available in `ReplicaContext.merge_call`, `Strategy.extended.update` and `Strategy.extended.update_non_slot`.
|
||||
* Experimental support for shape invariants has been enabled in `tf.function`. See the API docs for `tf.autograph.experimental.set_loop_options` for additonal info.
|
||||
* AutoGraph error messages now exclude frames corresponding to APIs internal to AutoGraph.
|
||||
* Improve shape inference for `tf.function` input arguments to unlock more Grappler optimizations in TensorFlow 2.x.
|
||||
* Improve automatic control dependency management of resources by allowing resource reads to occur in parallel and synchronizing only on writes.
|
||||
* Fix execution order of multiple stateful calls to `experimental_run_v2` in `tf.function`.
|
||||
* You can now iterate over `RaggedTensors` using a for loop inside `tf.function`.
|
||||
* AutoGraph is now available in `ReplicaContext.merge_call`,
|
||||
`Strategy.extended.update` and `Strategy.extended.update_non_slot`.
|
||||
* Experimental support for shape invariants has been enabled in
|
||||
`tf.function`. See the API docs for
|
||||
`tf.autograph.experimental.set_loop_options` for additonal info.
|
||||
* AutoGraph error messages now exclude frames corresponding to APIs
|
||||
internal to AutoGraph.
|
||||
* Improve shape inference for `tf.function` input arguments to unlock more
|
||||
Grappler optimizations in TensorFlow 2.x.
|
||||
* Improve automatic control dependency management of resources by allowing
|
||||
resource reads to occur in parallel and synchronizing only on writes.
|
||||
* Fix execution order of multiple stateful calls to `experimental_run_v2`
|
||||
in `tf.function`.
|
||||
* You can now iterate over `RaggedTensors` using a for loop inside
|
||||
`tf.function`.
|
||||
* `tf.lite`:
|
||||
* Migrated the `tf.lite` C inference API out of experimental into lite/c.
|
||||
* Add an option to disallow `NNAPI` CPU / partial acceleration on Android 10
|
||||
* TFLite Android AARs now include the C headers and APIs are required to use TFLite from native code.
|
||||
* Refactors the delegate and delegate kernel sources to allow usage in the linter.
|
||||
* Limit delegated ops to actually supported ones if a device name is specified or `NNAPI` CPU Fallback is disabled.
|
||||
* Add an option to disallow `NNAPI` CPU / partial acceleration on Android
|
||||
10
|
||||
* TFLite Android AARs now include the C headers and APIs are required to
|
||||
use TFLite from native code.
|
||||
* Refactors the delegate and delegate kernel sources to allow usage in the
|
||||
linter.
|
||||
* Limit delegated ops to actually supported ones if a device name is
|
||||
specified or `NNAPI` CPU Fallback is disabled.
|
||||
* TFLite now supports `tf.math.reciprocal1` op by lowering to `tf.div op`.
|
||||
* TFLite's unpack op now supports boolean tensor inputs.
|
||||
* Microcontroller and embedded code moved from experimental to main TensorFlow Lite folder
|
||||
* Microcontroller and embedded code moved from experimental to main
|
||||
TensorFlow Lite folder
|
||||
* Check for large TFLite tensors.
|
||||
* Fix GPU delegate crash with C++17.
|
||||
* Add 5D support to TFLite `strided_slice`.
|
||||
* Fix error in delegation of `DEPTH_TO_SPACE` to `NNAPI` causing op not to be accelerated.
|
||||
* Fix segmentation fault when running a model with LSTM nodes using `NNAPI` Delegate
|
||||
* Fix `NNAPI` delegate failure when an operand for Maximum/Minimum operation is a scalar.
|
||||
* Fix `NNAPI` delegate failure when Axis input for reduce operation is a scalar.
|
||||
* Expose option to limit the number of partitions that will be delegated to `NNAPI`.
|
||||
* If a target accelerator is specified, use its feature level to determine operations to delegate instead of SDK version.
|
||||
* Fix error in delegation of `DEPTH_TO_SPACE` to `NNAPI` causing op not to
|
||||
be accelerated.
|
||||
* Fix segmentation fault when running a model with LSTM nodes using
|
||||
`NNAPI` Delegate
|
||||
* Fix `NNAPI` delegate failure when an operand for Maximum/Minimum
|
||||
operation is a scalar.
|
||||
* Fix `NNAPI` delegate failure when Axis input for reduce operation is a
|
||||
scalar.
|
||||
* Expose option to limit the number of partitions that will be delegated
|
||||
to `NNAPI`.
|
||||
* If a target accelerator is specified, use its feature level to determine
|
||||
operations to delegate instead of SDK version.
|
||||
* `tf.random`:
|
||||
* Various random number generation improvements:
|
||||
* Add a fast path for default `random_uniform`
|
||||
* `random_seed` documentation improvement.
|
||||
* `RandomBinomial` broadcasts and appends the sample shape to the left rather than the right.
|
||||
* Added `tf.random.stateless_binomial`, `tf.random.stateless_gamma`, `tf.random.stateless_poisson`
|
||||
* `tf.random.stateless_uniform` now supports unbounded sampling of `int` types.
|
||||
* `RandomBinomial` broadcasts and appends the sample shape to the left
|
||||
rather than the right.
|
||||
* Added `tf.random.stateless_binomial`, `tf.random.stateless_gamma`,
|
||||
`tf.random.stateless_poisson`
|
||||
* `tf.random.stateless_uniform` now supports unbounded sampling of `int`
|
||||
types.
|
||||
* Math and Linear Algebra:
|
||||
* Add `tf.linalg.LinearOperatorTridiag`.
|
||||
* Add `LinearOperatorBlockLowerTriangular`
|
||||
* Add broadcasting support to tf.linalg.triangular_solve[#26204](https://github.com/tensorflow/tensorflow/issues/26204), tf.math.invert_permutation.
|
||||
* Add broadcasting support to
|
||||
tf.linalg.triangular_solve[#26204](https://github.com/tensorflow/tensorflow/issues/26204),
|
||||
tf.math.invert_permutation.
|
||||
* Add `tf.math.sobol_sample` op.
|
||||
* Add `tf.math.xlog1py`.
|
||||
* Add `tf.math.special.{dawsn,expi,fresnel_cos,fresnel_sin,spence}`.
|
||||
* Add a Modified Discrete Cosine Transform (MDCT) and its inverse to `tf.signal`.
|
||||
* Add a Modified Discrete Cosine Transform (MDCT) and its inverse to
|
||||
`tf.signal`.
|
||||
* TPU Enhancements:
|
||||
* Refactor `TpuClusterResolver` to move shared logic to a separate pip package.
|
||||
* Refactor `TpuClusterResolver` to move shared logic to a separate pip
|
||||
package.
|
||||
* Support configuring TPU software version from cloud tpu client.
|
||||
* Allowed TPU embedding weight decay factor to be multiplied by learning rate.
|
||||
* Allowed TPU embedding weight decay factor to be multiplied by learning
|
||||
rate.
|
||||
* XLA Support:
|
||||
* Add standalone XLA AOT runtime target + relevant .cc sources to pip package.
|
||||
* Add check for memory alignment to MemoryAllocation::MemoryAllocation() on 32-bit ARM. This ensures a deterministic early exit instead of a hard to debug bus error later.
|
||||
* `saved_model_cli aot_compile_cpu` allows you to compile saved models to XLA header+object files and include them in your C++ programs.
|
||||
* Add standalone XLA AOT runtime target + relevant .cc sources to pip
|
||||
package.
|
||||
* Add check for memory alignment to MemoryAllocation::MemoryAllocation()
|
||||
on 32-bit ARM. This ensures a deterministic early exit instead of a hard
|
||||
to debug bus error later.
|
||||
* `saved_model_cli aot_compile_cpu` allows you to compile saved models to
|
||||
XLA header+object files and include them in your C++ programs.
|
||||
* Enable `Igamma`, `Igammac` for XLA.
|
||||
* Deterministic Op Functionality:
|
||||
* XLA reduction emitter is deterministic when the environment variable `TF_DETERMINISTIC_OPS` is set to "true" or "1". This extends deterministic `tf.nn.bias_add` back-prop functionality (and therefore also deterministic back-prop of bias-addition in Keras layers) to include when XLA JIT complilation is enabled.
|
||||
* Fix problem, when running on a CUDA GPU and when either environment variable `TF_DETERMINSTIC_OPS` or environment variable `TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer configurations led to an exception with the message "No algorithm worked!"
|
||||
* XLA reduction emitter is deterministic when the environment variable
|
||||
`TF_DETERMINISTIC_OPS` is set to "true" or "1". This extends
|
||||
deterministic `tf.nn.bias_add` back-prop functionality (and therefore
|
||||
also deterministic back-prop of bias-addition in Keras layers) to
|
||||
include when XLA JIT compilation is enabled.
|
||||
* Fix problem, when running on a CUDA GPU and when either environment
|
||||
variable `TF_DETERMINSTIC_OPS` or environment variable
|
||||
`TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer
|
||||
configurations led to an exception with the message "No algorithm
|
||||
worked!"
|
||||
* Tracing and Debugging:
|
||||
* Add source, destination name to `_send` traceme to allow easier debugging.
|
||||
* Add source, destination name to `_send` traceme to allow easier
|
||||
debugging.
|
||||
* Add traceme event to `fastpathexecute`.
|
||||
* Other:
|
||||
* Fix an issue with AUC.reset_states for multi-label AUC [#35852](https://github.com/tensorflow/tensorflow/issues/35852)
|
||||
* Fix the TF upgrade script to not delete files when there is a parsing error and the output mode is `in-place`.
|
||||
* Move `tensorflow/core:framework/*_pyclif` rules to `tensorflow/core/framework:*_pyclif`.
|
||||
* Fix an issue with AUC.reset_states for multi-label AUC
|
||||
[#35852](https://github.com/tensorflow/tensorflow/issues/35852)
|
||||
* Fix the TF upgrade script to not delete files when there is a parsing
|
||||
error and the output mode is `in-place`.
|
||||
* Move `tensorflow/core:framework/*_pyclif` rules to
|
||||
`tensorflow/core/framework:*_pyclif`.
|
||||
|
||||
## Thanks to our Contributors
|
||||
|
||||
|
@ -64,7 +64,7 @@ your model, and we recommend you run the TensorFlow process in a sandbox.
|
||||
|
||||
It is possible to write models that are secure in a sense that they can safely
|
||||
process untrusted inputs assuming there are no bugs. There are two main reasons
|
||||
to not rely on this: first, it is easy to write models which must not be exposed
|
||||
to not rely on this: First, it is easy to write models which must not be exposed
|
||||
to untrusted inputs, and second, there are bugs in any software system of
|
||||
sufficient complexity. Letting users control inputs could allow them to trigger
|
||||
bugs either in TensorFlow or in dependent libraries.
|
||||
@ -149,7 +149,7 @@ attack (or worse). Because TensorFlow behaves correctly, this is not a
|
||||
vulnerability in TensorFlow (although it would be a vulnerability of this
|
||||
hypothetical system).
|
||||
|
||||
As a general rule, it is incorrect behavior for Tensorflow to access memory it
|
||||
As a general rule, it is incorrect behavior for TensorFlow to access memory it
|
||||
does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to
|
||||
such behaviors constitute a vulnerability.
|
||||
|
||||
|
@ -114,6 +114,14 @@ http_archive(
|
||||
],
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "person_detect_data",
|
||||
sha256 = "170542270da256994ce24d1e357f6e84a54fdaf7d28ff2b74725a40b70b082cf",
|
||||
urls = [
|
||||
"https://storage.googleapis.com/download.tensorflow.org/data/tf_lite_micro_person_data_grayscale_2020_05_24.zip",
|
||||
],
|
||||
)
|
||||
|
||||
# Required for dependency @com_github_grpc_grpc
|
||||
|
||||
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
|
||||
|
70
configure.py
70
configure.py
@ -49,7 +49,7 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
|
||||
_TF_WORKSPACE_ROOT = ''
|
||||
_TF_BAZELRC = ''
|
||||
_TF_CURRENT_BAZEL_VERSION = None
|
||||
_TF_MIN_BAZEL_VERSION = '2.0.0'
|
||||
_TF_MIN_BAZEL_VERSION = '3.1.0'
|
||||
_TF_MAX_BAZEL_VERSION = '3.99.0'
|
||||
|
||||
NCCL_LIB_PATHS = [
|
||||
@ -144,7 +144,7 @@ def write_to_bazelrc(line):
|
||||
|
||||
|
||||
def write_action_env_to_bazelrc(var_name, var):
|
||||
write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var)))
|
||||
write_to_bazelrc('build --action_env {}="{}"'.format(var_name, str(var)))
|
||||
|
||||
|
||||
def run_shell(cmd, allow_non_zero=False, stderr=None):
|
||||
@ -205,7 +205,7 @@ def setup_python(environ_cp):
|
||||
# Get PYTHON_BIN_PATH, default is the current running python.
|
||||
default_python_bin_path = sys.executable
|
||||
ask_python_bin_path = ('Please specify the location of python. [Default is '
|
||||
'%s]: ') % default_python_bin_path
|
||||
'{}]: ').format(default_python_bin_path)
|
||||
while True:
|
||||
python_bin_path = get_from_env_or_user_or_default(environ_cp,
|
||||
'PYTHON_BIN_PATH',
|
||||
@ -215,9 +215,10 @@ def setup_python(environ_cp):
|
||||
if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK):
|
||||
break
|
||||
elif not os.path.exists(python_bin_path):
|
||||
print('Invalid python path: %s cannot be found.' % python_bin_path)
|
||||
print('Invalid python path: {} cannot be found.'.format(python_bin_path))
|
||||
else:
|
||||
print('%s is not executable. Is it the python binary?' % python_bin_path)
|
||||
print('{} is not executable. Is it the python binary?'.format(
|
||||
python_bin_path))
|
||||
environ_cp['PYTHON_BIN_PATH'] = ''
|
||||
|
||||
# Convert python path to Windows style before checking lib and version
|
||||
@ -236,7 +237,7 @@ def setup_python(environ_cp):
|
||||
default_python_lib_path = python_lib_paths[0]
|
||||
python_lib_path = get_input(
|
||||
'Please input the desired Python library path to use. '
|
||||
'Default is [%s]\n' % python_lib_paths[0])
|
||||
'Default is [{}]\n'.format(python_lib_paths[0]))
|
||||
if not python_lib_path:
|
||||
python_lib_path = default_python_lib_path
|
||||
environ_cp['PYTHON_LIB_PATH'] = python_lib_path
|
||||
@ -252,7 +253,7 @@ def setup_python(environ_cp):
|
||||
# Set-up env variables used by python_configure.bzl
|
||||
write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path)
|
||||
write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path)
|
||||
write_to_bazelrc('build --python_path=\"%s"' % python_bin_path)
|
||||
write_to_bazelrc('build --python_path=\"{}"'.format(python_bin_path))
|
||||
environ_cp['PYTHON_BIN_PATH'] = python_bin_path
|
||||
|
||||
# If choosen python_lib_path is from a path specified in the PYTHONPATH
|
||||
@ -266,7 +267,7 @@ def setup_python(environ_cp):
|
||||
with open(
|
||||
os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'),
|
||||
'w') as f:
|
||||
f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path)
|
||||
f.write('export PYTHON_BIN_PATH="{}"'.format(python_bin_path))
|
||||
|
||||
|
||||
def reset_tf_configure_bazelrc():
|
||||
@ -320,11 +321,12 @@ def get_var(environ_cp,
|
||||
Raise the error to avoid infinitely looping.
|
||||
"""
|
||||
if not question:
|
||||
question = 'Do you wish to build TensorFlow with %s support?' % query_item
|
||||
question = 'Do you wish to build TensorFlow with {} support?'.format(
|
||||
query_item)
|
||||
if not yes_reply:
|
||||
yes_reply = '%s support will be enabled for TensorFlow.' % query_item
|
||||
yes_reply = '{} support will be enabled for TensorFlow.'.format(query_item)
|
||||
if not no_reply:
|
||||
no_reply = 'No %s' % yes_reply
|
||||
no_reply = 'No {}'.format(yes_reply)
|
||||
|
||||
yes_reply += '\n'
|
||||
no_reply += '\n'
|
||||
@ -368,7 +370,7 @@ def get_var(environ_cp,
|
||||
print(no_reply)
|
||||
var = False
|
||||
else:
|
||||
print('Invalid selection: %s' % user_input_origin)
|
||||
print('Invalid selection: {}'.format(user_input_origin))
|
||||
return var
|
||||
|
||||
|
||||
@ -1009,17 +1011,15 @@ def set_tf_cuda_compute_capabilities(environ_cp):
|
||||
default_cuda_compute_capabilities = native_cuda_compute_capabilities
|
||||
|
||||
ask_cuda_compute_capabilities = (
|
||||
'Please specify a list of comma-separated '
|
||||
'CUDA compute capabilities you want to '
|
||||
'build with.\nYou can find the compute '
|
||||
'capability of your device at: '
|
||||
'https://developer.nvidia.com/cuda-gpus.\nPlease'
|
||||
' note that each additional compute '
|
||||
'capability significantly increases your '
|
||||
'build time and binary size, and that '
|
||||
'TensorFlow only supports compute '
|
||||
'capabilities >= 3.5 [Default is: %s]: ' %
|
||||
default_cuda_compute_capabilities)
|
||||
'Please specify a list of comma-separated CUDA compute capabilities '
|
||||
'you want to build with.\nYou can find the compute capability of your '
|
||||
'device at: https://developer.nvidia.com/cuda-gpus. Each capability '
|
||||
'can be specified as "x.y" or "compute_xy" to include both virtual and'
|
||||
' binary GPU code, or as "sm_xy" to only include the binary '
|
||||
'code.\nPlease note that each additional compute capability '
|
||||
'significantly increases your build time and binary size, and that '
|
||||
'TensorFlow only supports compute capabilities >= 3.5 [Default is: '
|
||||
'%s]: ' % default_cuda_compute_capabilities)
|
||||
tf_cuda_compute_capabilities = get_from_env_or_user_or_default(
|
||||
environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES',
|
||||
ask_cuda_compute_capabilities, default_cuda_compute_capabilities)
|
||||
@ -1031,8 +1031,23 @@ def set_tf_cuda_compute_capabilities(environ_cp):
|
||||
for compute_capability in tf_cuda_compute_capabilities.split(','):
|
||||
m = re.match('[0-9]+.[0-9]+', compute_capability)
|
||||
if not m:
|
||||
# We now support sm_35,sm_50,sm_60,compute_70.
|
||||
sm_compute_match = re.match('(sm|compute)_?([0-9]+[0-9]+)',
|
||||
compute_capability)
|
||||
if not sm_compute_match:
|
||||
print('Invalid compute capability: %s' % compute_capability)
|
||||
all_valid = False
|
||||
else:
|
||||
ver = int(sm_compute_match.group(2))
|
||||
if ver < 30:
|
||||
print(
|
||||
'ERROR: TensorFlow only supports small CUDA compute'
|
||||
' capabilities of sm_30 and higher. Please re-specify the list'
|
||||
' of compute capabilities excluding version %s.' % ver)
|
||||
all_valid = False
|
||||
if ver < 35:
|
||||
print('WARNING: XLA does not support CUDA compute capabilities '
|
||||
'lower than sm_35. Disable XLA when running on older GPUs.')
|
||||
else:
|
||||
ver = float(m.group(0))
|
||||
if ver < 3.0:
|
||||
@ -1223,7 +1238,8 @@ def is_reduced_optimize_huge_functions_available(environ_cp):
|
||||
only, as of 2019-11-19). TensorFlow needs this flag to massively reduce
|
||||
compile times, but until 16.4 is officially released, we can't depend on it.
|
||||
|
||||
See also https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
|
||||
See also
|
||||
https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
|
||||
|
||||
Because it's very annoying to check this manually (to check the MSVC installed
|
||||
versions, you need to use the registry, and it's not clear if Bazel will be
|
||||
@ -1366,8 +1382,13 @@ def main():
|
||||
# environment variables.
|
||||
environ_cp = dict(os.environ)
|
||||
|
||||
try:
|
||||
current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION,
|
||||
_TF_MAX_BAZEL_VERSION)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print('Error checking bazel version: ', e.output.decode('UTF-8').strip())
|
||||
raise e
|
||||
|
||||
_TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)
|
||||
|
||||
reset_tf_configure_bazelrc()
|
||||
@ -1385,7 +1406,6 @@ def main():
|
||||
# Windows.
|
||||
environ_cp['TF_DOWNLOAD_CLANG'] = '0'
|
||||
environ_cp['TF_NEED_MPI'] = '0'
|
||||
environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0'
|
||||
|
||||
if is_macos():
|
||||
environ_cp['TF_NEED_TENSORRT'] = '0'
|
||||
|
@ -524,12 +524,22 @@ package_group(
|
||||
],
|
||||
)
|
||||
|
||||
package_group(name = "ndarray_tensor_allow_list")
|
||||
package_group(
|
||||
name = "ndarray_tensor_allow_list",
|
||||
packages = ["//learning/pathways/..."],
|
||||
)
|
||||
|
||||
# Packages that use composite tensors or dispatch.
|
||||
# TODO(b/154762408) Remove this package group once it's no longer needed.
|
||||
package_group(name = "composite_tensor_whitelist")
|
||||
|
||||
# Packages that use private types symbols, until they are exported.
|
||||
# TODO(b/154650521) Remove.
|
||||
package_group(
|
||||
name = "types_whitelist",
|
||||
packages = ["//learning/deepmind/tensorflow/replicator/..."],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "intel_binary_blob",
|
||||
data = if_mkl_ml(
|
||||
|
@ -85,7 +85,7 @@ tf_cuda_library(
|
||||
],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//tensorflow:chromiumos": [
|
||||
":tf_attrtype",
|
||||
@ -182,7 +182,7 @@ tf_cuda_library(
|
||||
":tf_status_internal",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tf_status",
|
||||
@ -216,10 +216,11 @@ tf_cuda_library(
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/c:__subpackages__",
|
||||
"//tensorflow/compiler/mlir/tensorflow/c:__subpackages__",
|
||||
],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:lib",
|
||||
@ -232,12 +233,13 @@ cc_library(
|
||||
srcs = ["tf_status.cc"],
|
||||
hdrs = ["tf_status.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
deps = [
|
||||
":tf_status_internal",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tf_status_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
}),
|
||||
@ -259,10 +261,15 @@ cc_library(
|
||||
name = "tensor_interface",
|
||||
hdrs = ["tensor_interface.h"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -272,7 +279,7 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
@ -286,16 +293,17 @@ cc_library(
|
||||
srcs = ["tf_tensor.cc"],
|
||||
hdrs = ["tf_tensor.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
deps = [
|
||||
":tensor_interface",
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
":tf_status_helper",
|
||||
":tf_tensor_internal",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -311,14 +319,15 @@ tf_cuda_library(
|
||||
"tf_tensor_internal.h",
|
||||
],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
deps = [
|
||||
":tensor_interface",
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:casts",
|
||||
@ -386,8 +395,14 @@ tf_cuda_library(
|
||||
deps = [
|
||||
":tf_status",
|
||||
":tf_status_internal",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
@ -426,7 +441,7 @@ tf_cuda_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
@ -457,7 +472,7 @@ tf_cuda_library(
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
":c_api_internal",
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":c_api_internal",
|
||||
@ -484,7 +499,7 @@ tf_cuda_library(
|
||||
":tf_status_helper",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
|
@ -589,6 +589,7 @@ void TF_DeleteDeviceList(TF_DeviceList* list) { delete list; }
|
||||
|
||||
TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
|
||||
TF_DeviceList* response = new TF_DeviceList;
|
||||
if (session && session->session)
|
||||
status->status = session->session->ListDevices(&response->response);
|
||||
return response;
|
||||
}
|
||||
@ -596,6 +597,7 @@ TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
|
||||
TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
|
||||
TF_Status* status) {
|
||||
TF_DeviceList* response = new TF_DeviceList;
|
||||
if (session && session->session)
|
||||
status->status = session->session->ListDevices(&response->response);
|
||||
return response;
|
||||
}
|
||||
@ -1384,6 +1386,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
|
||||
cpp_type v; \
|
||||
status->status = \
|
||||
tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \
|
||||
if (!status->status.ok()) return; \
|
||||
*value = static_cast<c_type>(v); \
|
||||
} \
|
||||
void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
|
||||
@ -2178,6 +2181,7 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
|
||||
}
|
||||
return new_session;
|
||||
} else {
|
||||
LOG(ERROR) << status->status;
|
||||
DCHECK_EQ(nullptr, session);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -325,205 +325,6 @@ TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
TFE_Context* TFE_CreateContextFromSession(TF_Session* session,
|
||||
TF_Status* status) {
|
||||
auto* opts = TFE_NewContextOptions();
|
||||
|
||||
// Reduce GPU memory allocation, and set appropriate config options for TFE
|
||||
// context.
|
||||
auto* config = TF_CreateConfig(
|
||||
/*xla*/ false, /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
|
||||
10);
|
||||
TFE_ContextOptionsSetConfig(opts, config->data, config->length, status);
|
||||
if (!status->status.ok()) {
|
||||
CHECK(!config);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto* ctx = TFE_NewContextFromSession(opts, session, status);
|
||||
TF_DeleteBuffer(config);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// TODO: retrieve the device string via TFE_ContextListDevices()
|
||||
static const char DEFAULT_CPU_DEVICE[] =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
|
||||
static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType,
|
||||
int tensor_id, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> queueOp(
|
||||
TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp);
|
||||
TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
// TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler.
|
||||
TFE_OpSetAttrInt(queueOp.get(), "capacity", 1);
|
||||
TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1);
|
||||
auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id);
|
||||
TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(),
|
||||
shared_name.size());
|
||||
TFE_OpSetAttrString(queueOp.get(), "container", "", 0);
|
||||
|
||||
// TODO: consider making this an unknown shape.
|
||||
const int64_t* dims_ptr = nullptr;
|
||||
int num_dims = 0;
|
||||
TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims,
|
||||
/*num_values*/ 0, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
int num_retvals = 1;
|
||||
TFE_TensorHandle* queue = nullptr;
|
||||
TFE_Execute(queueOp.get(), &queue, &num_retvals, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
CHECK_EQ(num_retvals, 1);
|
||||
|
||||
return queue;
|
||||
}
|
||||
|
||||
static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType,
|
||||
TFE_TensorHandle* queue, TFE_TensorHandle* tensor,
|
||||
TF_Status* status) {
|
||||
TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
|
||||
TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
|
||||
if (!status->status.ok()) return;
|
||||
TFE_OpAddInput(op, queue, status);
|
||||
if (!status->status.ok()) return;
|
||||
TFE_OpAddInput(op, tensor, status);
|
||||
if (!status->status.ok()) return;
|
||||
TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1);
|
||||
TFE_OpSetAttrInt(op, "timeout_ms", -1);
|
||||
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status);
|
||||
if (!status->status.ok()) return;
|
||||
CHECK_EQ(num_retvals, 0);
|
||||
}
|
||||
|
||||
static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx,
|
||||
TF_DataType inputType,
|
||||
TFE_TensorHandle* queue,
|
||||
TF_Status* status) {
|
||||
TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
|
||||
TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
TFE_OpAddInput(op, queue, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1);
|
||||
TFE_OpSetAttrInt(op, "timeout_ms", -1);
|
||||
TFE_TensorHandle* ret;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &ret, &num_retvals, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
CHECK_EQ(num_retvals, 1);
|
||||
return ret;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id,
|
||||
TF_DataType inputType,
|
||||
TF_Status* status) {
|
||||
assert(session);
|
||||
VLOG(1) << "Dequeuing data tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
auto* ret = createTFEDequeue(ctx, inputType, queue, status);
|
||||
return ret;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
|
||||
TF_DataType inputType,
|
||||
TF_Status* status) {
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
auto* ret = createTFEDequeue(ctx, inputType, queue, status);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id,
|
||||
TFE_TensorHandle* tensor, TF_Status* status) {
|
||||
assert(session);
|
||||
VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TF_DataType inputType = TFE_TensorHandleDataType(tensor);
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
createTFEEnqueue(ctx, inputType, queue, tensor, status);
|
||||
}
|
||||
|
||||
void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status) {
|
||||
VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
|
||||
|
||||
TF_DataType inputType = TFE_TensorHandleDataType(tensor);
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
createTFEEnqueue(ctx, inputType, queue, tensor, status);
|
||||
}
|
||||
|
||||
void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id,
|
||||
TFE_TensorHandle* tensor, TF_Status* status) {
|
||||
VLOG(1) << "Enqueuing variant tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
|
||||
TF_Status* status) {
|
||||
VLOG(1) << "Dequeuing variant tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
return createTFEDequeue(ctx, TF_VARIANT, queue, status);
|
||||
}
|
||||
|
||||
void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
|
||||
status->status = tensorflow::errors::Internal(errMsg);
|
||||
}
|
||||
@ -622,10 +423,9 @@ void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
|
||||
void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
|
||||
const TF_DataType* values, int num_values) {
|
||||
auto iter = builder->attr_names.insert(attr_name).first;
|
||||
builder->Set(
|
||||
(*iter).c_str(),
|
||||
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
|
||||
reinterpret_cast<const tensorflow::DataType*>(values), num_values));
|
||||
builder->Set(*iter, tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
|
||||
reinterpret_cast<const tensorflow::DataType*>(values),
|
||||
num_values));
|
||||
}
|
||||
|
||||
void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,
|
||||
|
@ -146,48 +146,6 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
|
||||
// Create a serialized tensorflow.ServerDef proto.
|
||||
TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status);
|
||||
|
||||
// TODO: remove this API in favor of the next one.
|
||||
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession(
|
||||
const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status);
|
||||
|
||||
// Creates from `session` a new eager context to run a graph function or
|
||||
// sends/recvs, so that these concurrent TFE executions can share (via
|
||||
// `session` and its associated device mgr) the same set of fifo queue resource
|
||||
// ops, used for host<->TF tensor transfers. This way the sends/recvs calls and
|
||||
// graph function execution can access the same fifo queue resource handles
|
||||
// (associated with devices managed by the device manager, which can be obtained
|
||||
// from `session`).
|
||||
//
|
||||
// TODO: Remove this function once we migrate away from using session.
|
||||
TF_CAPI_EXPORT extern TFE_Context* TFE_CreateContextFromSession(
|
||||
TF_Session* session, TF_Status* status);
|
||||
|
||||
// TODO: Retire this API in favor of the next one.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensor(
|
||||
TF_Session* session, int tensor_id, TF_DataType inputType,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(
|
||||
TFE_Context* ctx, int tensor_id, TF_DataType inputType, TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensor(TF_Session* session,
|
||||
int tensor_id,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensorFromCtx(
|
||||
TFE_Context* ctx, int tensor_id, TFE_TensorHandle* tensor,
|
||||
TF_Status* status);
|
||||
|
||||
// TODO: consider folding the 2 APIs below into the ones above.
|
||||
TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session,
|
||||
int tensor_id,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
|
||||
TF_Session* session, int tensor_id, TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
|
||||
const char* errMsg);
|
||||
|
||||
|
@ -35,7 +35,7 @@ tf_cuda_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":context_interface",
|
||||
@ -144,6 +144,24 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "c_api_unified_internal",
|
||||
hdrs = [
|
||||
"c_api_unified_experimental_internal.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/core/platform:casts",
|
||||
"//tensorflow/core/platform:types",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_handle_interface",
|
||||
hdrs = ["tensor_handle_interface.h"],
|
||||
@ -319,6 +337,7 @@ tf_cuda_cc_test(
|
||||
tags = [
|
||||
"noguitar", # TODO(b/155445984): flaky
|
||||
#"guitar",
|
||||
"notap", # TODO(b/156981931): flaky
|
||||
"multi_gpu",
|
||||
],
|
||||
deps = [
|
||||
@ -349,7 +368,10 @@ tf_cuda_cc_test(
|
||||
# TODO(b/136478427): Figure out how to correctly shut the server down
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = ["noasan"], # leaks gRPC server instances
|
||||
tags = [
|
||||
"no_windows",
|
||||
"noasan", # leaks gRPC server instances
|
||||
],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
@ -357,10 +379,46 @@ tf_cuda_cc_test(
|
||||
":c_api_test_util",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime:function_optimization_registry",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "c_api_distributed_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"c_api_distributed_test.cc",
|
||||
],
|
||||
# TODO(b/136478427): Figure out how to correctly shut the server down
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = [
|
||||
"no_windows",
|
||||
"noasan", # leaks gRPC server instances
|
||||
],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
":c_api_test_util",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime:function_optimization_registry",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
@ -376,7 +434,10 @@ tf_cuda_cc_test(
|
||||
# TODO(b/136478427): Figure out how to correctly shut the server down
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = ["noasan"], # leaks gRPC server instances
|
||||
tags = [
|
||||
"no_windows",
|
||||
"noasan", # leaks gRPC server instances
|
||||
],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
@ -412,7 +473,7 @@ tf_cuda_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":c_api",
|
||||
@ -448,6 +509,8 @@ tf_cuda_library(
|
||||
"//conditions:default": [],
|
||||
}) + [
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core/distributed_runtime/eager:eager_client",
|
||||
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
|
||||
@ -505,6 +568,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
|
@ -102,6 +102,15 @@ string DeviceName(const tensorflow::Device* d) {
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context,
|
||||
const tensorflow::ServerDef& server_def) {
|
||||
if (server_def.job_name() != context->HostCPU()->parsed_name().job) {
|
||||
return false;
|
||||
}
|
||||
return server_def.default_session_config().SerializeAsString() ==
|
||||
context->session_options().config.SerializeAsString();
|
||||
}
|
||||
|
||||
tensorflow::Status AddRemoteDevicesToMgr(
|
||||
const std::vector<string>& added_remote_workers,
|
||||
tensorflow::WorkerCacheInterface* worker_cache,
|
||||
@ -469,10 +478,15 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server;
|
||||
if (reset_context) {
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
|
||||
const tensorflow::DeviceMgr* device_mgr =
|
||||
AreLocalDevicesCompatible(context, server_def)
|
||||
? context->local_device_mgr()
|
||||
: nullptr;
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
|
||||
server_def, {device_mgr}, &new_server));
|
||||
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
|
||||
ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
|
||||
} else {
|
||||
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
|
||||
&curr_remote_workers));
|
||||
@ -727,24 +741,6 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
tensorflow::GetDefaultCustomKernelCreator()));
|
||||
}
|
||||
|
||||
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
|
||||
TF_Session* sess, TF_Status* status) {
|
||||
const tensorflow::DeviceMgr* device_mgr = nullptr;
|
||||
status->status = sess->session->LocalDeviceManager(&device_mgr);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr);
|
||||
|
||||
return tensorflow::wrap(new tensorflow::EagerContext(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
|
||||
/*device_mgr_owned*/ false, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator()));
|
||||
}
|
||||
|
||||
void TFE_DeleteContext(TFE_Context* ctx) {
|
||||
if (ctx == nullptr) {
|
||||
return;
|
||||
@ -899,9 +895,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||
#if defined(IS_MOBILE_PLATFORM)
|
||||
status->status = tensorflow::Status::OK();
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->SyncExecutors();
|
||||
status->status = tensorflow::unwrap(ctx)->AsyncWait();
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
@ -924,7 +918,7 @@ extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
||||
context->GetDevicePlacementPolicy());
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
|
||||
tensorflow::Tensor tensor;
|
||||
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
@ -1403,23 +1397,17 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
|
||||
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
|
||||
return;
|
||||
}
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->AddFunctionDef(function_def);
|
||||
status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def);
|
||||
}
|
||||
|
||||
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->AddFunctionDef(function->fdef);
|
||||
status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function->fdef);
|
||||
}
|
||||
|
||||
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->RemoveFunction(name);
|
||||
status->status = tensorflow::unwrap(ctx)->RemoveFunction(name);
|
||||
}
|
||||
|
||||
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
|
||||
|
@ -137,7 +137,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
// placed in memory of different devices or remote address spaces.
|
||||
typedef struct TFE_TensorHandle TFE_TensorHandle;
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t,
|
||||
TF_Status* status);
|
||||
// Indicates that the caller will not be using `h` any more.
|
||||
TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
|
||||
|
@ -30,24 +30,11 @@ namespace {
|
||||
|
||||
using ::tensorflow::string;
|
||||
|
||||
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
|
||||
tensorflow::ServerDef server_def;
|
||||
server_def.set_protocol("grpc");
|
||||
server_def.set_job_name(job_name);
|
||||
server_def.set_task_index(0);
|
||||
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
||||
tensorflow::JobDef* job_def = cluster_def->add_job();
|
||||
job_def->set_name(job_name);
|
||||
for (int i = 0; i < num_tasks; i++) {
|
||||
void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
|
||||
tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{i, tensorflow::strings::StrCat("localhost:", port)});
|
||||
}
|
||||
return server_def;
|
||||
}
|
||||
|
||||
tensorflow::ServerDef GetServerDef(int num_tasks) {
|
||||
return GetServerDef("localhost", num_tasks);
|
||||
job_def->mutable_tasks()->at(task_index) =
|
||||
tensorflow::strings::StrCat("localhost:", port);
|
||||
}
|
||||
|
||||
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
|
||||
@ -101,6 +88,22 @@ void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
// Read the value of variable `var` and save it into `out_value`.
|
||||
void ReadVariable(TFE_Context* ctx, TFE_TensorHandle* var,
|
||||
TFE_TensorHandle** out_value) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op, var, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, out_value, &num_retvals, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(op);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteChangeServerDef(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
@ -243,6 +246,102 @@ TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
|
||||
TestRemoteExecuteUpdateServerDef(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteUpdateServerDefResourceAccess(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
const char dev0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char dev1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
|
||||
TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
|
||||
EXPECT_NE(var_handle0, nullptr);
|
||||
TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
|
||||
EXPECT_NE(var_handle1, nullptr);
|
||||
|
||||
TFE_TensorHandle* value_handle = nullptr;
|
||||
ReadVariable(ctx, var_handle1, &value_handle);
|
||||
CheckTFE_TensorHandleHasFloats(value_handle, {2});
|
||||
TFE_DeleteTensorHandle(value_handle);
|
||||
|
||||
// Start a new worker to replace task:1
|
||||
ReplaceTaskInServerDef(&server_def, 1);
|
||||
server_def.set_task_index(1);
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
// Update server def to replace the remote device with the device info on the
|
||||
// new worker (different incarnation ID).
|
||||
server_def.set_task_index(0);
|
||||
string serialized_update = server_def.SerializeAsString();
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
|
||||
serialized_update.size(), status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// The device of var_handle0 is local device which is the same before and
|
||||
// after cluster update. Remove resource with valid device should succeed.
|
||||
TFE_Op* op = TFE_NewOp(ctx, "DestroyResourceOp", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(op, var_handle0, status);
|
||||
TFE_OpSetDevice(op, dev0_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(op);
|
||||
|
||||
// The device of var_handle1 is remote device, which was replaced during
|
||||
// cluster update. Removing resource with invalid device should fail
|
||||
// gracefully (i.e., with error status) instead of crashing with segfaults.
|
||||
op = TFE_NewOp(ctx, "DestroyResourceOp", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(op, var_handle1, status);
|
||||
TFE_OpSetDevice(op, dev1_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(op);
|
||||
|
||||
TFE_DeleteTensorHandle(var_handle0);
|
||||
TFE_DeleteTensorHandle(var_handle1);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccess) {
|
||||
TestRemoteExecuteUpdateServerDefResourceAccess(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccessAsync) {
|
||||
TestRemoteExecuteUpdateServerDefResourceAccess(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
|
||||
// Fail fast on GetStatus requests so we can get errors instead of timeout
|
||||
// when updating cluster with non-exsitent worker
|
||||
@ -282,6 +381,7 @@ void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{2, tensorflow::strings::StrCat("localhost:", port)});
|
||||
server_def.set_task_index(0);
|
||||
string serialized_update = server_def.SerializeAsString();
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
|
||||
serialized_update.size(), status);
|
||||
@ -310,4 +410,70 @@ TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
|
||||
TestRemoteExecuteUpdateServerDefWithFailures(true);
|
||||
}
|
||||
|
||||
void TestConnectToCluster(bool keep_localhost_for_first_connect) {
|
||||
// Fail fast on GetStatus requests so we can get errors instead of timeout
|
||||
// when updating cluster with non-exsitent worker
|
||||
tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
|
||||
|
||||
const string first_name =
|
||||
keep_localhost_for_first_connect ? "localhost" : "abc";
|
||||
tensorflow::ServerDef server_def = GetServerDef(first_name, 1);
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
const string dev0_name = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
|
||||
EXPECT_NE(var_handle0, nullptr);
|
||||
|
||||
tensorflow::Status status2;
|
||||
EXPECT_EQ(tensorflow::unwrap(var_handle0)->DeviceName(&status2), dev0_name);
|
||||
|
||||
// Rename local device
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
const string dev1_name =
|
||||
absl::StrCat("/job:", first_name, "/replica:0/task:0/device:CPU:0");
|
||||
TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
|
||||
EXPECT_NE(var_handle1, nullptr);
|
||||
EXPECT_EQ(tensorflow::unwrap(var_handle1)->DeviceName(&status2), dev1_name);
|
||||
|
||||
// Another renaming of local device
|
||||
const string second_name = "def";
|
||||
server_def.set_job_name(second_name);
|
||||
server_def.mutable_cluster()->mutable_job(0)->set_name(second_name);
|
||||
(*server_def.mutable_cluster()->mutable_job(0)->mutable_tasks())[0] =
|
||||
absl::StrCat(second_name, ":",
|
||||
tensorflow::testing::PickUnusedPortOrDie());
|
||||
|
||||
serialized = server_def.SerializeAsString();
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
const string dev2_name = "/job:def/replica:0/task:0/device:CPU:0";
|
||||
TFE_TensorHandle* var_handle2 = TestVariable(ctx, 2.0, dev2_name);
|
||||
EXPECT_NE(var_handle2, nullptr);
|
||||
EXPECT_EQ(tensorflow::unwrap(var_handle2)->DeviceName(&status2), dev2_name);
|
||||
|
||||
TFE_DeleteTensorHandle(var_handle0);
|
||||
TFE_DeleteTensorHandle(var_handle1);
|
||||
TFE_DeleteTensorHandle(var_handle2);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
tensorflow::unsetenv("GRPC_FAIL_FAST");
|
||||
}
|
||||
|
||||
TEST(CAPI, ConnectToClusterLocalhostFirst) { TestConnectToCluster(false); }
|
||||
|
||||
TEST(CAPI, ConnectToClusterRenameFirst) { TestConnectToCluster(true); }
|
||||
|
||||
} // namespace
|
||||
|
506
tensorflow/c/eager/c_api_distributed_test.cc
Normal file
506
tensorflow/c/eager/c_api_distributed_test.cc
Normal file
@ -0,0 +1,506 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using ::tensorflow::string;
|
||||
|
||||
// Add the values of three variables on three different tasks.
|
||||
string AddVariablesFunction() {
|
||||
tensorflow::FunctionDef def;
|
||||
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
" signature {"
|
||||
" name: 'AddVariablesFunction'"
|
||||
" input_arg {"
|
||||
" name: 'var'"
|
||||
" type: DT_RESOURCE"
|
||||
" }"
|
||||
" output_arg {"
|
||||
" name: 'sum'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read0'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var'"
|
||||
" device: '/job:localhost/replica:0/task:0/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read1'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var'"
|
||||
" device: '/job:localhost/replica:0/task:1/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read2'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var'"
|
||||
" device: '/job:localhost/replica:0/task:2/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'add1'"
|
||||
" op: 'Add'"
|
||||
" input: 'read0:value:0'"
|
||||
" input: 'read1:value:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'add2'"
|
||||
" op: 'Add'"
|
||||
" input: 'add1:z:0'"
|
||||
" input: 'read2:value:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" ret {"
|
||||
" key: 'sum'"
|
||||
" value: 'add2:z:0'"
|
||||
" }",
|
||||
&def));
|
||||
return def.SerializeAsString();
|
||||
}
|
||||
|
||||
void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(op, var_handle, status);
|
||||
TFE_TensorHandle* is_initialized[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &is_initialized[0], &num_retvals, status);
|
||||
CHECK_EQ(1, num_retvals);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
|
||||
bool initialized = false;
|
||||
memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t));
|
||||
EXPECT_EQ(initialized, true);
|
||||
TF_DeleteTensor(t);
|
||||
TFE_DeleteTensorHandle(is_initialized[0]);
|
||||
TFE_DeleteOp(op);
|
||||
delete status;
|
||||
}
|
||||
|
||||
void TestFunctionWithPackedInput(const bool remote) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server1->Start().ok());
|
||||
|
||||
server_def.set_task_index(2);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server2->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||
|
||||
// Create one variable per task.
|
||||
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
|
||||
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
|
||||
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
|
||||
|
||||
// Add a sync point in order to make sure that variables have been initialized
|
||||
// before the function execution starts.
|
||||
// TODO(b/155789951): Remove once b/155789951 is fixed.
|
||||
VarIsInitialized(ctx, h1);
|
||||
VarIsInitialized(ctx, h2);
|
||||
|
||||
// Pack 3 variable handles into one TFE_TensorHandle.
|
||||
int num_replicas = 3;
|
||||
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
|
||||
TFE_TensorHandle* packed_handle =
|
||||
TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
|
||||
EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
|
||||
EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);
|
||||
|
||||
const string composite_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
|
||||
EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
|
||||
composite_device_name);
|
||||
EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
|
||||
composite_device_name);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// Register and run a function which returns the sum of 3 variables.
|
||||
const string function_def = AddVariablesFunction();
|
||||
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
||||
status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(func, packed_handle, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
if (remote) {
|
||||
TFE_OpSetDevice(func, task1_name, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(func, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
ASSERT_EQ(1, num_retvals);
|
||||
TFE_DeleteOp(func);
|
||||
TFE_DeleteTensorHandle(packed_handle);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
float sum = 0;
|
||||
EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
|
||||
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
EXPECT_EQ(sum, 6.0);
|
||||
|
||||
TFE_DeleteTensorHandle(h0);
|
||||
TFE_DeleteTensorHandle(h1);
|
||||
TFE_DeleteTensorHandle(h2);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server1.release();
|
||||
worker_server2.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, TestLocalFunctionWithPackedInput) {
|
||||
TestFunctionWithPackedInput(/*remote=*/false);
|
||||
}
|
||||
|
||||
TEST(CAPI, TestRemoteFunctionWithPackedInput) {
|
||||
TestFunctionWithPackedInput(/*remote=*/true);
|
||||
}
|
||||
|
||||
string VariableAddFunction() {
|
||||
tensorflow::FunctionDef def;
|
||||
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
" signature {"
|
||||
" name: 'VariableAddFunction'"
|
||||
" input_arg {"
|
||||
" name: 'var0'"
|
||||
" type: DT_RESOURCE"
|
||||
" }"
|
||||
" output_arg {"
|
||||
" name: 'var0_value'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read0'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'add'"
|
||||
" op: 'Add'"
|
||||
" input: 'read0:value:0'"
|
||||
" input: 'read0:value:0'"
|
||||
" device: '/job:localhost/task:1/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'identity'"
|
||||
" op: 'Identity'"
|
||||
" input: 'add:z:0'"
|
||||
" device: '/job:localhost/task:0/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" ret {"
|
||||
" key: 'var0_value'"
|
||||
" value: 'identity:output:0'"
|
||||
" }",
|
||||
&def));
|
||||
return def.SerializeAsString();
|
||||
}
|
||||
|
||||
class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
|
||||
public:
|
||||
FunctionErrorInjectionPass(string error_node, string error_device)
|
||||
: error_node_(error_node), error_device_(error_device) {}
|
||||
tensorflow::Status Run(const tensorflow::DeviceSet& device_set,
|
||||
const tensorflow::ConfigProto& config_proto,
|
||||
std::unique_ptr<tensorflow::Graph>* graph,
|
||||
tensorflow::FunctionLibraryDefinition* flib_def,
|
||||
std::vector<std::string>* control_ret_node_names,
|
||||
bool* control_rets_updated) override {
|
||||
// Inject failure to function instantiation if finding a node that contains
|
||||
// the given node name (error_node_) and requested device (error_device_).
|
||||
for (const auto node : graph->get()->nodes()) {
|
||||
if (node->name().find(error_node_) != string::npos &&
|
||||
node->requested_device() == error_device_) {
|
||||
return tensorflow::errors::Internal("Injected graph pass error.");
|
||||
}
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
const string error_node_;
|
||||
const string error_device_;
|
||||
};
|
||||
|
||||
void TestDistributedFunctionCancellation(bool inject_error) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server1->Start().ok());
|
||||
server_def.set_task_index(2);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server2->Start().ok());
|
||||
const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||
|
||||
if (inject_error) {
|
||||
// Inject a function optimization pass failure when it sees the 'read0' op
|
||||
// having a requested device `dev2_name`. During execution:
|
||||
// * task:0 processes the main function `VariableAddFunction` and places
|
||||
// the read0 op on task:2
|
||||
// * task:0 partitions the main function with a subgraph containing read0
|
||||
// sent to task:2
|
||||
// * task:2 graph pass reports an error when it sees read0 with dev2_name
|
||||
tensorflow::function_optimization_registration::
|
||||
FunctionOptimizationPassRegistration register_test_pass(
|
||||
std::make_unique<FunctionErrorInjectionPass>("read0", dev2_name));
|
||||
}
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
|
||||
EXPECT_NE(var_handle, nullptr);
|
||||
|
||||
const string function_def = VariableAddFunction();
|
||||
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
||||
status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(func, var_handle, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(func, &retvals[0], &num_retvals, status);
|
||||
|
||||
if (inject_error) {
|
||||
ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
|
||||
} else {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(1, num_retvals);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
float sum = 0;
|
||||
ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
|
||||
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
ASSERT_EQ(sum, 4.0);
|
||||
}
|
||||
|
||||
TFE_DeleteOp(func);
|
||||
TFE_DeleteTensorHandle(var_handle);
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server1.release();
|
||||
worker_server2.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, DistributedFunctionNoError) {
|
||||
TestDistributedFunctionCancellation(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, DistributedFunctionCancelledOnError) {
|
||||
TestDistributedFunctionCancellation(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
|
||||
TFE_DEVICE_PLACEMENT_EXPLICIT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Use large matrices so that RPCs don't return before we get a chance
|
||||
// to call TFE_DeleteContext.
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
auto* h0_task1 =
|
||||
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
auto* h1_task1 =
|
||||
TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
|
||||
TFE_OpSetDevice(matmul, remote_device_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
TFE_DeleteTensorHandle(h0_task0);
|
||||
TFE_DeleteTensorHandle(h1_task0);
|
||||
TFE_DeleteTensorHandle(h0_task1);
|
||||
TFE_DeleteTensorHandle(h1_task1);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
|
||||
TestRemoteExecuteDeleteContextWithOutstandingRPC(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
|
||||
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
|
||||
}
|
||||
} // namespace
|
@ -657,3 +657,17 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
|
||||
std::move(tensor_handles), context, &handle);
|
||||
return tensorflow::wrap(handle);
|
||||
}
|
||||
|
||||
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetAllowSoftPlacement(enable);
|
||||
}
|
||||
|
||||
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetLogDevicePlacement(enable);
|
||||
}
|
||||
|
@ -549,6 +549,18 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
|
||||
TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
|
||||
TF_Status* status);
|
||||
|
||||
// Configure soft device placement policy for the eager executor. Note this
|
||||
// policy is applied to any subsequent op executions.
|
||||
TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx,
|
||||
unsigned char enable,
|
||||
TF_Status* status);
|
||||
|
||||
// Configure device placement policy logging for the eager executor. Note this
|
||||
// policy is applied to any subsequent op executions.
|
||||
TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
|
||||
unsigned char enable,
|
||||
TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -19,37 +19,22 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using ::tensorflow::string;
|
||||
|
||||
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
|
||||
tensorflow::ServerDef server_def;
|
||||
server_def.set_protocol("grpc");
|
||||
server_def.set_job_name(job_name);
|
||||
server_def.set_task_index(0);
|
||||
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
||||
tensorflow::JobDef* job_def = cluster_def->add_job();
|
||||
job_def->set_name(job_name);
|
||||
for (int i = 0; i < num_tasks; i++) {
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{i, tensorflow::strings::StrCat("localhost:", port)});
|
||||
}
|
||||
return server_def;
|
||||
}
|
||||
|
||||
tensorflow::ServerDef GetServerDef(int num_tasks) {
|
||||
return GetServerDef("localhost", num_tasks);
|
||||
}
|
||||
|
||||
void TestRemoteExecute(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
@ -351,260 +336,4 @@ TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
|
||||
/*heavy_load_on_streaming_rpc=*/true);
|
||||
}
|
||||
|
||||
// Add the values of three variables on three different tasks.
|
||||
string AddVariablesFunction() {
|
||||
tensorflow::FunctionDef def;
|
||||
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
" signature {"
|
||||
" name: 'AddVariablesFunction'"
|
||||
" input_arg {"
|
||||
" name: 'var'"
|
||||
" type: DT_RESOURCE"
|
||||
" }"
|
||||
" output_arg {"
|
||||
" name: 'sum'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read0'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var'"
|
||||
" device: '/job:localhost/replica:0/task:0/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read1'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var'"
|
||||
" device: '/job:localhost/replica:0/task:1/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read2'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var'"
|
||||
" device: '/job:localhost/replica:0/task:2/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'add1'"
|
||||
" op: 'Add'"
|
||||
" input: 'read0:value:0'"
|
||||
" input: 'read1:value:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'add2'"
|
||||
" op: 'Add'"
|
||||
" input: 'add1:z:0'"
|
||||
" input: 'read2:value:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" ret {"
|
||||
" key: 'sum'"
|
||||
" value: 'add2:z:0'"
|
||||
" }",
|
||||
&def));
|
||||
return def.SerializeAsString();
|
||||
}
|
||||
|
||||
TEST(CAPI, TestFunctionWithPackedInput) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server1->Start().ok());
|
||||
|
||||
server_def.set_task_index(2);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server2->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||
|
||||
// Create one variable per task.
|
||||
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
|
||||
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
|
||||
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
|
||||
|
||||
// Pack 3 variable handles into one TFE_TensorHandle.
|
||||
int num_replicas = 3;
|
||||
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
|
||||
TFE_TensorHandle* packed_handle =
|
||||
TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
|
||||
EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
|
||||
EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);
|
||||
|
||||
const string composite_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
|
||||
EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
|
||||
composite_device_name);
|
||||
EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
|
||||
composite_device_name);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// Register and run a function which returns the sum of 3 variables.
|
||||
const string function_def = AddVariablesFunction();
|
||||
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
||||
status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(func, packed_handle, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(func, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
ASSERT_EQ(1, num_retvals);
|
||||
TFE_DeleteOp(func);
|
||||
TFE_DeleteTensorHandle(packed_handle);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
float sum = 0;
|
||||
EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
|
||||
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
EXPECT_EQ(sum, 6.0);
|
||||
|
||||
TFE_DeleteTensorHandle(h0);
|
||||
TFE_DeleteTensorHandle(h1);
|
||||
TFE_DeleteTensorHandle(h2);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server1.release();
|
||||
worker_server2.release();
|
||||
}
|
||||
|
||||
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
|
||||
TFE_DEVICE_PLACEMENT_EXPLICIT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Use large matrices so that RPCs don't return before we get a chance
|
||||
// to call TFE_DeleteContext.
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
auto* h0_task1 =
|
||||
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
auto* h1_task1 =
|
||||
TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
|
||||
TFE_OpSetDevice(matmul, remote_device_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
TFE_DeleteTensorHandle(h0_task0);
|
||||
TFE_DeleteTensorHandle(h1_task0);
|
||||
TFE_DeleteTensorHandle(h0_task1);
|
||||
TFE_DeleteTensorHandle(h1_task1);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
|
||||
TestRemoteExecuteDeleteContextWithOutstandingRPC(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
|
||||
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
|
||||
}
|
||||
} // namespace
|
||||
|
@ -1203,6 +1203,8 @@ void BM_ReadVariable(int iters) {
|
||||
CHECK_EQ(0, TFE_TensorHandleNumDims(h, status));
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
h = nullptr;
|
||||
TFE_OpAddInput(op, var_handle, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
}
|
||||
tensorflow::testing::StopTiming();
|
||||
TFE_DeleteOp(op);
|
||||
|
@ -18,7 +18,9 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
@ -150,6 +152,7 @@ TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &var_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
CHECK_EQ(1, num_retvals);
|
||||
@ -295,3 +298,23 @@ bool GetDeviceName(TFE_Context* ctx, string* device_name,
|
||||
TF_DeleteDeviceList(devices);
|
||||
return false;
|
||||
}
|
||||
|
||||
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
|
||||
tensorflow::ServerDef server_def;
|
||||
server_def.set_protocol("grpc");
|
||||
server_def.set_job_name(job_name);
|
||||
server_def.set_task_index(0);
|
||||
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
||||
tensorflow::JobDef* job_def = cluster_def->add_job();
|
||||
job_def->set_name(job_name);
|
||||
for (int i = 0; i < num_tasks; i++) {
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{i, tensorflow::strings::StrCat("localhost:", port)});
|
||||
}
|
||||
return server_def;
|
||||
}
|
||||
|
||||
tensorflow::ServerDef GetServerDef(int num_tasks) {
|
||||
return GetServerDef("localhost", num_tasks);
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
// Return a tensor handle containing a float scalar
|
||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value);
|
||||
@ -72,4 +73,11 @@ TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
|
||||
bool GetDeviceName(TFE_Context* ctx, tensorflow::string* device_name,
|
||||
const char* device_type);
|
||||
|
||||
// Create a ServerDef with the given `job_name` and add `num_tasks` tasks in it.
|
||||
tensorflow::ServerDef GetServerDef(const tensorflow::string& job_name,
|
||||
int num_tasks);
|
||||
|
||||
// Create a ServerDef with job name "localhost" and add `num_tasks` tasks in it.
|
||||
tensorflow::ServerDef GetServerDef(int num_tasks);
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
|
||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
@ -26,6 +28,51 @@ using tensorflow::string;
|
||||
using tensorflow::internal::OutputList;
|
||||
using tensorflow::internal::unwrap;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
typedef absl::flat_hash_map<std::string, FactoryFunction> FactoriesMap;
|
||||
|
||||
static FactoriesMap& GetFactories() {
|
||||
static FactoriesMap* factories = new FactoriesMap;
|
||||
return *factories;
|
||||
}
|
||||
|
||||
static const char* default_factory = "<unset>";
|
||||
|
||||
void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
|
||||
assert((!GetFactories().count(name)) ||
|
||||
(GetFactories()[name] == factory) &&
|
||||
"Duplicate tracing factory registration");
|
||||
GetFactories()[name] = factory;
|
||||
}
|
||||
|
||||
void SetDefaultTracingEngine(const char* name) { default_factory = name; }
|
||||
|
||||
static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
|
||||
TF_Status* s) {
|
||||
auto entry = GetFactories().find(default_factory);
|
||||
if (entry != GetFactories().end()) return entry->second(fn_name, s);
|
||||
string msg = absl::StrCat(
|
||||
"No tracing engine factory has been registered with the key '",
|
||||
default_factory, "' (available: ");
|
||||
// Ensure deterministic (sorted) order in the error message
|
||||
std::set<string> factories_sorted;
|
||||
for (const auto& factory : GetFactories())
|
||||
factories_sorted.insert(factory.first);
|
||||
const char* comma = "";
|
||||
for (const string& factory : factories_sorted) {
|
||||
msg += comma + factory;
|
||||
comma = ", ";
|
||||
}
|
||||
msg += ")";
|
||||
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // end namespace internal
|
||||
} // end namespace tensorflow
|
||||
|
||||
// =============================================================================
|
||||
// Public C API entry points
|
||||
//
|
||||
@ -36,6 +83,28 @@ using tensorflow::internal::unwrap;
|
||||
//
|
||||
// =============================================================================
|
||||
|
||||
void TF_SetTracingImplementation(const char* name) {
|
||||
tensorflow::internal::SetDefaultTracingEngine(name);
|
||||
}
|
||||
|
||||
// Creates a new TensorFlow function, it is an execution context attached to a
|
||||
// given tracing context.
|
||||
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) {
|
||||
return wrap(tensorflow::internal::CreateTracingExecutionContext(fn_name, s));
|
||||
}
|
||||
|
||||
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
|
||||
TF_OutputList* outputs, TF_Status* s) {
|
||||
auto* func = wrap(unwrap(ctx)->Finalize(unwrap(outputs), s));
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
return func;
|
||||
}
|
||||
|
||||
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
|
||||
TF_DataType dtype, TF_Status* s) {
|
||||
return wrap(unwrap(func)->AddParameter(dtype, s));
|
||||
}
|
||||
|
||||
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
|
||||
|
||||
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
|
||||
@ -58,6 +127,10 @@ int TF_OutputListNumOutputs(TF_OutputList* o) {
|
||||
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
|
||||
return wrap(unwrap(o)->outputs[i]);
|
||||
}
|
||||
void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
|
||||
TF_Status* s) {
|
||||
unwrap(o)->outputs.push_back(unwrap(tensor));
|
||||
}
|
||||
|
||||
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
|
||||
TF_Status* s) {
|
||||
|
@ -49,15 +49,26 @@ typedef struct TF_AbstractOp TF_AbstractOp;
|
||||
// setting functional attributes of other composite ops e.g. control flow.
|
||||
typedef struct TF_AbstractFunction TF_AbstractFunction;
|
||||
|
||||
// Creates a context for tracing the execution of operations into a function.
|
||||
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s);
|
||||
// This allows the client to swap the implementation of the tracing engine.
|
||||
// Any future call to TF_CreateFunction will use the implementation defined
|
||||
// here.
|
||||
void TF_SetTracingImplementation(const char* name);
|
||||
|
||||
// Creates a new TensorFlow function. A Function is an execution context, and as
|
||||
// such it can trace operations through TF_ExecuteOperation. After completing
|
||||
// tracing, a function can be obtained by TF_FinalizeFunction.
|
||||
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* status);
|
||||
|
||||
// Creates a context for eager execution of operations.
|
||||
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
|
||||
TF_Status* s);
|
||||
|
||||
void TF_DeleteExecutionContext(TF_ExecutionContext*);
|
||||
|
||||
// Add a new parameter to a TensorFlow Function.
|
||||
// TODO(aminim): what about shape?
|
||||
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
|
||||
TF_DataType dtype, TF_Status* s);
|
||||
|
||||
// Create an operation suitable to use with the provided context. The operation
|
||||
// requires its type (e.g. "AddV2") to be set independently.
|
||||
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
|
||||
@ -77,19 +88,21 @@ void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
|
||||
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
|
||||
|
||||
// TF_OutputList holds the list of TF_AbstractTensor that results from executing
|
||||
// an operation.
|
||||
// It just lets us not specify the number of outputs of an operation
|
||||
// beforehand. This forces a memory allocation in the runtime, which is bad, but
|
||||
// it allows for generic code.
|
||||
// TODO(aminim): the description above isn't clear with respect to
|
||||
// TF_OutputListNumOutputs and the current eager implementation which requires
|
||||
// the number of outputs to be set by the client.
|
||||
// an operation, or provided to create a function.
|
||||
// When executing an operation in an eager context, the expected number of
|
||||
// outputs must be set beforehand with `TF_OutputListSetNumOutputs`.
|
||||
typedef struct TF_OutputList TF_OutputList;
|
||||
TF_OutputList* TF_NewOutputList();
|
||||
void TF_DeleteOutputList(TF_OutputList* o);
|
||||
void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
|
||||
// Prepare tracing to the expected number of output for an operation.
|
||||
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, TF_Status*);
|
||||
// Return the number of outputs in the list.
|
||||
int TF_OutputListNumOutputs(TF_OutputList* o);
|
||||
// Return the `i`th output in the list.
|
||||
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
|
||||
// Append a tensor at the end of the output list, growing its size by one.
|
||||
void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
|
||||
TF_Status*);
|
||||
|
||||
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
|
||||
// capture some inputs and then add a node in the graph. The output tensors are
|
||||
@ -100,13 +113,12 @@ void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
||||
TF_ExecutionContext* ctx, TF_Status* s);
|
||||
|
||||
// Creates a new TF_AbstractFunction from the current tracing states in the
|
||||
// context. The returned TF_GraphToFunction must be deleted by the client.
|
||||
// context. The provided `ctx` is consumed by this API call and deleted.
|
||||
// The returned TF_AbstractFunction must be deleted by the client,
|
||||
// TODO(aminim): clarify the contract on the state of the context after this
|
||||
// call.
|
||||
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
||||
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||
const TF_AbstractTensor* inputs, int num_outputs,
|
||||
const TF_AbstractTensor* outputs, TF_Status* status);
|
||||
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
|
||||
TF_OutputList*, TF_Status*);
|
||||
|
||||
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
|
||||
|
||||
|
@ -123,6 +123,17 @@ class EagerContext : public ExecutionContext {
|
||||
}
|
||||
}
|
||||
|
||||
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Can't add function parameter on an eager context.");
|
||||
return nullptr;
|
||||
}
|
||||
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Can't use finalize function on an eager context.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
|
||||
auto* func = afunc->GetTfFunction(s);
|
||||
if (!func) {
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
@ -114,12 +115,14 @@ struct GraphFunction : public AbstractFunction {
|
||||
static constexpr AbstractFunctionKind kKind = kGraphFunc;
|
||||
};
|
||||
|
||||
// GraphContext wraps a TF_Graph and manages the "execution" of operation, i.e.
|
||||
// adding them to the graph.
|
||||
// GraphContext wraps a TF_Graph modeling a single function and manages the
|
||||
// "execution" of operation, i.e. adding them to the function.
|
||||
class GraphContext : public ExecutionContext {
|
||||
public:
|
||||
GraphContext()
|
||||
: ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
|
||||
explicit GraphContext(const char* name)
|
||||
: ExecutionContext(kKind),
|
||||
graph_(new TF_Graph(), TF_DeleteGraph),
|
||||
name_(name) {}
|
||||
|
||||
AbstractOp* CreateOperation() override {
|
||||
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
||||
@ -136,6 +139,10 @@ class GraphContext : public ExecutionContext {
|
||||
return;
|
||||
}
|
||||
auto* tf_opdesc = graph_op->op_.release();
|
||||
if (tf_opdesc == nullptr) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, "AbstractOp is incomplete.");
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
auto* graph_tensor = dyncast<GraphTensor>(inputs[i]);
|
||||
if (!graph_tensor) {
|
||||
@ -164,24 +171,38 @@ class GraphContext : public ExecutionContext {
|
||||
}
|
||||
}
|
||||
|
||||
TF_Function* ToFunction(const char* fn_name, int num_inputs,
|
||||
const GraphTensor* inputs, int num_outputs,
|
||||
const GraphTensor* outputs, TF_Status* status) const {
|
||||
std::vector<TF_Output> graph_inputs;
|
||||
graph_inputs.resize(num_inputs);
|
||||
std::vector<TF_Output> graph_outputs;
|
||||
graph_outputs.resize(num_outputs);
|
||||
for (int i = 0; i < num_inputs; i++) {
|
||||
graph_inputs[i] = inputs[i].output;
|
||||
}
|
||||
for (int i = 0; i < num_outputs; i++) {
|
||||
graph_outputs[i] = outputs[i].output;
|
||||
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
|
||||
TF_OperationDescription* opdesc =
|
||||
TF_NewOperation(graph_.get(), "Placeholder",
|
||||
absl::StrCat("_input_", inputs_.size()).c_str());
|
||||
TF_SetAttrType(opdesc, "dtype", dtype);
|
||||
auto* operation = TF_FinishOperation(opdesc, s);
|
||||
if (!s->status.ok()) return nullptr;
|
||||
|
||||
inputs_.push_back(TF_Output{operation, 0});
|
||||
return new GraphTensor(inputs_.back(), this);
|
||||
}
|
||||
|
||||
return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr,
|
||||
graph_inputs.size(), graph_inputs.data(),
|
||||
graph_outputs.size(), graph_outputs.data(),
|
||||
nullptr, nullptr, fn_name, status);
|
||||
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
|
||||
std::unique_ptr<GraphFunction> func(new GraphFunction);
|
||||
std::vector<TF_Output> graph_outputs;
|
||||
graph_outputs.reserve(outputs->outputs.size());
|
||||
for (AbstractTensor* abstract_output : outputs->outputs) {
|
||||
GraphTensor* output = dyncast<GraphTensor>(abstract_output);
|
||||
if (!output) {
|
||||
TF_SetStatus(s, TF_UNIMPLEMENTED,
|
||||
"Returning a non-graph tensor from a function has not "
|
||||
"been implemented yet.");
|
||||
return nullptr;
|
||||
}
|
||||
graph_outputs.push_back(output->output);
|
||||
}
|
||||
|
||||
func->func = TF_GraphToFunction(
|
||||
graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
|
||||
graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
|
||||
if (TF_GetCode(s) != TF_OK) return nullptr;
|
||||
return func.release();
|
||||
}
|
||||
|
||||
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
|
||||
@ -195,54 +216,20 @@ class GraphContext : public ExecutionContext {
|
||||
|
||||
private:
|
||||
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
|
||||
std::vector<TF_Output> inputs_;
|
||||
const char* name_;
|
||||
};
|
||||
|
||||
// Helper that converts the graph currently held in the context into a function.
|
||||
static AbstractFunction* ExecutionContextToFunction(
|
||||
const ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||
const AbstractTensor* inputs, int num_outputs,
|
||||
const AbstractTensor* outputs, TF_Status* status) {
|
||||
auto* graph_ctx = dyncast<const GraphContext>(fn_body);
|
||||
if (graph_ctx == nullptr) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"fn_body is not a TF_GraphContext.");
|
||||
return nullptr;
|
||||
}
|
||||
auto* graph_inputs = dyncast<const GraphTensor>(inputs);
|
||||
if (!graph_inputs) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "inputs aren't GraphTensors.");
|
||||
return nullptr;
|
||||
}
|
||||
auto* graph_outputs = dyncast<const GraphTensor>(outputs);
|
||||
if (!graph_outputs) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors.");
|
||||
return nullptr;
|
||||
}
|
||||
GraphFunction* func = new GraphFunction;
|
||||
func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs,
|
||||
num_outputs, graph_outputs, status);
|
||||
return func;
|
||||
static ExecutionContext* GraphTracingFactory(const char* name, TF_Status* s) {
|
||||
return new GraphContext(name);
|
||||
}
|
||||
|
||||
// Register the tracing implemented in this file as the default tracing engine.
|
||||
static bool register_tracing = [] {
|
||||
RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
|
||||
SetDefaultTracingEngine("graphdef");
|
||||
return true;
|
||||
}();
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
// =============================================================================
|
||||
// Public C API entry points
|
||||
// These are only the entry points specific to the Graph API.
|
||||
// =============================================================================
|
||||
|
||||
using tensorflow::internal::unwrap;
|
||||
|
||||
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) {
|
||||
return wrap(new tensorflow::internal::GraphContext());
|
||||
}
|
||||
|
||||
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
||||
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||
const TF_AbstractTensor* inputs, int num_outputs,
|
||||
const TF_AbstractTensor* outputs, TF_Status* status) {
|
||||
return wrap(ExecutionContextToFunction(unwrap(fn_body), fn_name, num_inputs,
|
||||
unwrap(inputs), num_outputs,
|
||||
unwrap(outputs), status));
|
||||
}
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
@ -57,7 +58,7 @@ T* dyncast(S source) {
|
||||
// GraphContext and vice-versa).
|
||||
class AbstractTensor {
|
||||
protected:
|
||||
enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor };
|
||||
enum AbstractTensorKind { kMlirTensor, kGraphTensor, kEagerTensor };
|
||||
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
|
||||
|
||||
public:
|
||||
@ -100,7 +101,7 @@ class AbstractFunction {
|
||||
// on a given context, with the same or different input tensors.
|
||||
class AbstractOp {
|
||||
protected:
|
||||
enum AbstractOpKind { kGraphOp, kEagerOp };
|
||||
enum AbstractOpKind { kMlirOp, kGraphOp, kEagerOp };
|
||||
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
|
||||
|
||||
public:
|
||||
@ -128,7 +129,7 @@ class AbstractOp {
|
||||
// eager implementation or to a graph implementation.
|
||||
struct ExecutionContext {
|
||||
protected:
|
||||
enum ExecutionContextKind { kGraphContext, kEagerContext };
|
||||
enum ExecutionContextKind { kMlirContext, kGraphContext, kEagerContext };
|
||||
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
|
||||
|
||||
public:
|
||||
@ -148,6 +149,17 @@ struct ExecutionContext {
|
||||
// Creates an empty AbstractOperation suitable to use with this context.
|
||||
virtual AbstractOp* CreateOperation() = 0;
|
||||
|
||||
// Add a function parameter and return the corresponding tensor.
|
||||
// This is only valid with an ExecutionContext obtained from a TracingContext,
|
||||
// it'll always error out with an eager context.
|
||||
virtual AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) = 0;
|
||||
|
||||
// Finalize this context and make a function out of it. The context is in a
|
||||
// invalid state after this call and must be destroyed.
|
||||
// This is only valid with an ExecutionContext obtained from a TracingContext,
|
||||
// it'll always error out with an eager context.
|
||||
virtual AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) = 0;
|
||||
|
||||
// Registers a functions with this context, after this the function is
|
||||
// available to be called/referenced by its name in this context.
|
||||
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
|
||||
@ -156,6 +168,11 @@ struct ExecutionContext {
|
||||
const ExecutionContextKind k;
|
||||
};
|
||||
|
||||
typedef ExecutionContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
|
||||
void SetDefaultTracingEngine(const char* name);
|
||||
void RegisterTracingEngineFactory(const ::tensorflow::string& name,
|
||||
FactoryFunction factory);
|
||||
|
||||
// Create utilities to wrap/unwrap: this convert from the C opaque types to the
|
||||
// C++ implementation, and back.
|
||||
#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \
|
||||
|
@ -29,7 +29,12 @@ using tensorflow::string;
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(UnifiedCAPI, TestBasicEager) {
|
||||
class UnifiedCAPI : public ::testing::TestWithParam<const char*> {
|
||||
protected:
|
||||
void SetUp() override { TF_SetTracingImplementation(GetParam()); }
|
||||
};
|
||||
|
||||
TEST_P(UnifiedCAPI, TestBasicEager) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -81,33 +86,18 @@ TEST(UnifiedCAPI, TestBasicEager) {
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
}
|
||||
|
||||
TEST(UnifiedCAPI, TestBasicGraph) {
|
||||
TEST_P(UnifiedCAPI, TestBasicGraph) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
// Start a new function / execution context.
|
||||
string fn_name = "double";
|
||||
TF_ExecutionContext* graph_ctx =
|
||||
TF_CreateFunction(fn_name.c_str(), status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
|
||||
auto* placeholder_t =
|
||||
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build inputs and outputs.
|
||||
TF_OutputList* placeholder_outputs = TF_NewOutputList();
|
||||
|
||||
// Execute.
|
||||
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
|
||||
graph_ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
|
||||
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
|
||||
|
||||
// Delete placeholder op.
|
||||
TF_DeleteAbstractOp(placeholder_op);
|
||||
|
||||
// Build an abstract operation.
|
||||
auto* add_op = TF_NewAbstractOp(graph_ctx);
|
||||
@ -123,17 +113,13 @@ TEST(UnifiedCAPI, TestBasicGraph) {
|
||||
// Execute.
|
||||
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractTensor* output_t = TF_OutputListGet(add_outputs, 0);
|
||||
|
||||
// Clean up operation and inputs.
|
||||
TF_DeleteAbstractOp(add_op);
|
||||
|
||||
string fn_name = "double";
|
||||
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
|
||||
graph_ctx, fn_name.c_str(), 1, placeholder_t, 1, output_t, status.get());
|
||||
TF_AbstractFunction* func =
|
||||
TF_FinalizeFunction(graph_ctx, add_outputs, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_DeleteAbstractTensor(placeholder_t);
|
||||
TF_DeleteAbstractTensor(output_t);
|
||||
|
||||
// Build eager context.
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -174,17 +160,160 @@ TEST(UnifiedCAPI, TestBasicGraph) {
|
||||
ASSERT_EQ(*f_value, 4.0);
|
||||
|
||||
TF_DeleteOutputList(add_outputs);
|
||||
TF_DeleteOutputList(placeholder_outputs);
|
||||
TF_DeleteAbstractOp(fn_op);
|
||||
TF_DeleteAbstractTensor(input_t);
|
||||
TF_DeleteAbstractTensor(final_result);
|
||||
TF_DeleteTensor(f_t);
|
||||
TF_DeleteAbstractFunction(func);
|
||||
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||
}
|
||||
|
||||
TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_Status* s = status.get();
|
||||
|
||||
// Start a new function / execution context.
|
||||
string fn_name = "two_adds";
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Create a first "Add" computing `arg0 + arg1`.
|
||||
TF_AbstractTensor* add_output1;
|
||||
{
|
||||
// Build an abstract operation, inputs and output.
|
||||
auto* add_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(add_op, "Add", s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_AbstractOpSetOpName(add_op, "my_add1", s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_AbstractTensor* inputs[2] = {arg0, arg1};
|
||||
TF_OutputList* add_outputs = TF_NewOutputList();
|
||||
// Trace the operation now (create a node in the graph).
|
||||
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_DeleteAbstractOp(add_op);
|
||||
// Extract the resulting tensor.
|
||||
add_output1 = TF_OutputListGet(add_outputs, 0);
|
||||
TF_DeleteOutputList(add_outputs);
|
||||
}
|
||||
|
||||
// Same with a second "Add" computing `arg1 + arg1`.
|
||||
TF_AbstractTensor* add_output2;
|
||||
{
|
||||
// Build an abstract operation, inputs and output.
|
||||
auto* add_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(add_op, "Add", s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_AbstractOpSetOpName(add_op, "my_add2", s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_AbstractTensor* inputs[2] = {arg1, arg1};
|
||||
TF_OutputList* add_outputs = TF_NewOutputList();
|
||||
// Trace the operation now (create a node in the graph).
|
||||
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_DeleteAbstractOp(add_op);
|
||||
// Extract the resulting tensor.
|
||||
add_output2 = TF_OutputListGet(add_outputs, 0);
|
||||
TF_DeleteOutputList(add_outputs);
|
||||
}
|
||||
|
||||
// Finalize the function by providing the returned values.
|
||||
TF_AbstractFunction* func;
|
||||
{
|
||||
// We want to return the output of both add operations, create a new list
|
||||
// and populate it.
|
||||
TF_OutputList* func_outputs = TF_NewOutputList();
|
||||
TF_OutputListPushBack(func_outputs, add_output1, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_OutputListPushBack(func_outputs, add_output2, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
func = TF_FinalizeFunction(graph_ctx, func_outputs, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_DeleteOutputList(func_outputs);
|
||||
}
|
||||
|
||||
/**
|
||||
* We traced so far this function:
|
||||
*
|
||||
* def two_adds(a, b):
|
||||
* my_add1 = a + b
|
||||
* my_add2 = b + b
|
||||
* return my_add1, my_add2
|
||||
*
|
||||
* Now we will execute this function with an eager context:
|
||||
*
|
||||
* output1, output2 = two_adds(2.0, 3.0)
|
||||
*
|
||||
* and check that we got 5.0 and 6.0 as results.
|
||||
*/
|
||||
|
||||
// Build eager context.
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TF_ExecutionContext* eager_execution_ctx =
|
||||
TF_NewEagerExecutionContext(opts, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Build the abstract op to run the function.
|
||||
TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
|
||||
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Build two abstract input tensors as function arguments.
|
||||
std::vector<TF_AbstractTensor*> func_args;
|
||||
{
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(eager_execution_ctx);
|
||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
|
||||
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
input_eager = TestScalarTensorHandle(eager_ctx, 3.0f);
|
||||
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
}
|
||||
|
||||
TF_OutputList* func_outputs = TF_NewOutputList();
|
||||
TF_OutputListSetNumOutputs(func_outputs, 2, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs,
|
||||
eager_execution_ctx, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_DeleteAbstractOp(fn_op);
|
||||
for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t);
|
||||
|
||||
ASSERT_EQ(2, TF_OutputListNumOutputs(func_outputs));
|
||||
float results[2];
|
||||
for (int idx = 0; idx < 2; ++idx) {
|
||||
TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
|
||||
TFE_TensorHandle* handle = TF_AbstractTensorGetEagerTensor(result, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_Tensor* f_t = TFE_TensorHandleResolve(handle, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
results[idx] = *static_cast<float*>(TF_TensorData(f_t));
|
||||
TF_DeleteTensor(f_t);
|
||||
}
|
||||
ASSERT_EQ(results[0], 5.0);
|
||||
ASSERT_EQ(results[1], 6.0);
|
||||
|
||||
for (int idx = 0; idx < 2; ++idx) {
|
||||
TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
|
||||
TF_DeleteAbstractTensor(result);
|
||||
}
|
||||
TF_DeleteOutputList(func_outputs);
|
||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||
TF_DeleteAbstractFunction(func);
|
||||
}
|
||||
|
||||
TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
@ -193,18 +322,15 @@ TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
|
||||
ctx, nullptr, 0, nullptr, 0, nullptr, status.get());
|
||||
TF_AbstractFunction* func = TF_FinalizeFunction(ctx, nullptr, status.get());
|
||||
ASSERT_EQ(nullptr, func);
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
}
|
||||
|
||||
TEST(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
|
||||
TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
@ -222,10 +348,10 @@ TEST(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
}
|
||||
|
||||
TEST(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
|
||||
TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
@ -243,7 +369,7 @@ TEST(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
}
|
||||
|
||||
TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
TEST_P(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
// Build an Eager context.
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
@ -273,7 +399,8 @@ TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build a Graph context.
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Execute eager op using graph context.
|
||||
@ -289,10 +416,11 @@ TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
}
|
||||
|
||||
TEST(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||
TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
@ -349,5 +477,8 @@ TEST(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
|
||||
::testing::Values("graphdef", "mlir"));
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -101,6 +101,17 @@ class AbstractContextInterface {
|
||||
// Destroy the step resource container for a training step.
|
||||
virtual void EndStep() = 0;
|
||||
|
||||
// Block until all pending nodes are finished.
|
||||
virtual Status AsyncWait() = 0;
|
||||
|
||||
// Add a function (serialized FunctionDef protocol buffer) so that it can
|
||||
// be executed as an op. Return error if the function with the same name
|
||||
// already exists.
|
||||
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
|
||||
// Remove a function. 'func' argument is the name of a previously added
|
||||
// FunctionDef. The name is in fdef.signature.name.
|
||||
virtual Status RemoveFunction(const string& func) = 0;
|
||||
|
||||
protected:
|
||||
virtual ~AbstractContextInterface() {}
|
||||
};
|
||||
|
@ -12,38 +12,98 @@ package(
|
||||
# need a second rule that omits .cc files, in
|
||||
# tensorflow/python:_pywrap_parallel_device.
|
||||
filegroup(
|
||||
name = "headers",
|
||||
name = "lib_headers",
|
||||
srcs = ["parallel_device_lib.h"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "lib_sources",
|
||||
srcs = ["parallel_device_lib.cc"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "device_headers",
|
||||
srcs = ["parallel_device.h"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "device_sources",
|
||||
srcs = ["parallel_device.cc"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "headers",
|
||||
srcs = [
|
||||
":device_headers",
|
||||
":lib_headers",
|
||||
],
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "sources",
|
||||
srcs = ["parallel_device.cc"],
|
||||
srcs = [
|
||||
":device_sources",
|
||||
":lib_sources",
|
||||
],
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "parallel_device",
|
||||
srcs = [":sources"],
|
||||
hdrs = [":headers"],
|
||||
srcs = [":device_sources"],
|
||||
hdrs = [":device_headers"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":parallel_device_lib",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "parallel_device_lib",
|
||||
srcs = [":lib_sources"],
|
||||
hdrs = [":lib_headers"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "parallel_device_testlib",
|
||||
testonly = 1,
|
||||
srcs = ["parallel_device_testlib.cc"],
|
||||
hdrs = ["parallel_device_testlib.h"],
|
||||
deps = [
|
||||
":parallel_device",
|
||||
":parallel_device_ops",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_experimental",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "parallel_device_test",
|
||||
srcs = ["parallel_device_test.cc"],
|
||||
deps = [
|
||||
":parallel_device",
|
||||
":parallel_device_ops",
|
||||
":parallel_device_testlib",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_experimental",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
@ -53,3 +113,40 @@ tf_cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "parallel_device_remote_test",
|
||||
srcs = ["parallel_device_remote_test.cc"],
|
||||
# TODO(b/136478427): Enable global heap checking when servers shut down
|
||||
# cleanly.
|
||||
args = ["--heap_check=local"],
|
||||
deps = [
|
||||
":parallel_device",
|
||||
":parallel_device_ops",
|
||||
":parallel_device_testlib",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_experimental",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
],
|
||||
)
|
||||
|
||||
# Note: ParallelDevice-specific ops are experimental and not currently linked in
|
||||
# to TensorFlow by default, just used in a few tests.
|
||||
filegroup(
|
||||
name = "parallel_device_ops_srcs",
|
||||
srcs = ["parallel_device_ops.cc"],
|
||||
visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "parallel_device_ops",
|
||||
srcs = [":parallel_device_ops_srcs"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = ["//tensorflow/core:framework"],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -23,25 +23,13 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace eager {
|
||||
namespace parallel_device {
|
||||
namespace {
|
||||
|
||||
// Functor for making unique_ptrs slightly more ergonomic. Using
|
||||
// decltype(delete_fn) in the unique_ptr's second template argument requires
|
||||
// passing a function pointer to delete_fn when constructing the unique_ptr.
|
||||
class TensorHandleDeleter {
|
||||
public:
|
||||
void operator()(TFE_TensorHandle* to_delete) const {
|
||||
TFE_DeleteTensorHandle(to_delete);
|
||||
}
|
||||
};
|
||||
|
||||
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
|
||||
|
||||
class OpDeleter {
|
||||
public:
|
||||
void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
|
||||
@ -49,180 +37,46 @@ class OpDeleter {
|
||||
|
||||
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
|
||||
|
||||
class ExecutorDeleter {
|
||||
public:
|
||||
void operator()(TFE_Executor* to_delete) const {
|
||||
TFE_DeleteExecutor(to_delete);
|
||||
}
|
||||
};
|
||||
|
||||
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
|
||||
|
||||
class ParallelTensor;
|
||||
|
||||
using MaybeParallelTensorOwned =
|
||||
absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
|
||||
|
||||
using MaybeParallelTensorUnowned =
|
||||
absl::variant<ParallelTensor*, TFE_TensorHandle*>;
|
||||
|
||||
// Creates a vector of `count` new executors (threads).
|
||||
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
|
||||
std::vector<ExecutorPtr> executors;
|
||||
executors.reserve(count);
|
||||
for (int i = 0; i < count; ++i) {
|
||||
executors.emplace_back(TFE_NewExecutor(true /* is_async */));
|
||||
}
|
||||
return executors;
|
||||
}
|
||||
|
||||
// A representation of the custom device passed in and out of the TFE custom
|
||||
// device APIs, providing context about the parallel device to
|
||||
// ParallelDeviceExecute.
|
||||
class ParallelDevice {
|
||||
// A ParallelDevice on its own is not registered with a TFE_Context, and so has
|
||||
// no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
|
||||
// name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
|
||||
// placed on the parallel device.
|
||||
class NamedParallelDevice {
|
||||
public:
|
||||
ParallelDevice(const std::string& name,
|
||||
const std::vector<std::string>& devices);
|
||||
|
||||
// Helper to copy a tensor handle from another device once for each component
|
||||
// of the ParallelDevice.
|
||||
//
|
||||
// Sets a bad status and returns a nullptr if `tensor` is already on the
|
||||
// ParallelDevice, or if the individual copies fail.
|
||||
std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status) const;
|
||||
|
||||
// Takes a description of a single operation being executed on the
|
||||
// ParallelDevice, and in turn runs one operation per component device with
|
||||
// its corresponding inputs from the input ParallelTensors (or
|
||||
// implicitly-mirrored tensors on other devices). Wraps the resulting
|
||||
// per-device and per-output TFE_TensorHandles into one ParallelTensor per
|
||||
// output of the original operation.
|
||||
//
|
||||
// `inputs` are either ParallelTensors, i.e. already on the ParallelDevice, or
|
||||
// un-replicated TFE_TensorHandles on other devices. TPUReplicatedInput
|
||||
// requires non-parallel tensors, and TPUReplicatedOutput requires a parallel
|
||||
// tensor, but other operations will implicitly broadcast non-parallel input
|
||||
// tensors across the ParallelDevice's component devices.
|
||||
//
|
||||
// Two special-cased operations, TPUReplicatedInput and TPUReplicatedOutput,
|
||||
// pack and un-pack parallel tensors respectively. Only TPUReplicatedOutput
|
||||
// causes `Execute` to return non-parallel tensors.
|
||||
//
|
||||
// Attributes are forwarded to executed operations unmodified.
|
||||
//
|
||||
// The returned optional has a value if and only if `status` evaluates to
|
||||
// TF_OK.
|
||||
absl::optional<std::vector<MaybeParallelTensorOwned>> Execute(
|
||||
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
int expected_max_outputs, TF_Status* status) const;
|
||||
|
||||
// Implements the parallel case for `Execute`, where all of the outputs of the
|
||||
// operation are ParallelTensors, and all inputs are either ParallelTensors or
|
||||
// should be implicitly broadcast. This means the operation is not
|
||||
// TPUReplicatedInput or TPUReplicatedOutput.
|
||||
//
|
||||
// The returned optional has a value if and only if `status` evaluates to
|
||||
// TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
|
||||
// if sanity checks on dtypes/metadata fail.
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
|
||||
ExecuteParallelOperation(TFE_Context* context,
|
||||
std::vector<MaybeParallelTensorUnowned> inputs,
|
||||
const char* operation_name,
|
||||
const TFE_OpAttrs* attributes,
|
||||
int expected_max_outputs, TF_Status* status) const;
|
||||
|
||||
const std::string& device_name() const { return device_name_; }
|
||||
NamedParallelDevice(const std::string& name,
|
||||
std::unique_ptr<ParallelDevice> parallel_device)
|
||||
: device_name_(name), parallel_device_(std::move(parallel_device)) {}
|
||||
const std::string& name() const { return device_name_; }
|
||||
const ParallelDevice& device() const { return *parallel_device_; }
|
||||
|
||||
private:
|
||||
// The name of the parallel device
|
||||
// (e.g. "/job:localhost/replica:0/task:0/device:CUSTOM:0")
|
||||
const std::string device_name_;
|
||||
// A sequence of device names, indicating which devices replicated operations
|
||||
// are forwarded to.
|
||||
const std::vector<std::string> underlying_devices_;
|
||||
// A sequence of TFE_Executors, one per device, for executing operations in
|
||||
// parallel.
|
||||
const std::vector<ExecutorPtr> executors_;
|
||||
std::string device_name_;
|
||||
std::unique_ptr<ParallelDevice> parallel_device_;
|
||||
};
|
||||
|
||||
// The internal representation of a TFE_TensorHandle placed on a
|
||||
// ParallelDevice. Contains a tuple of tensors, one on each of the
|
||||
// `underlying_devices_` of the ParallelDevice.
|
||||
class ParallelTensor {
|
||||
public:
|
||||
// Construct a ParallelTensor from TensorHandles placed on the component
|
||||
// devices of a ParallelDevice.
|
||||
static std::unique_ptr<ParallelTensor> FromTensorHandles(
|
||||
absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
|
||||
const ParallelDevice& parallel_device,
|
||||
std::vector<TensorHandlePtr> components, TF_Status* status);
|
||||
|
||||
// Helper to wrap a ParallelTensor into a TFE_TensorHandle which contains it.
|
||||
static TensorHandlePtr AsTensorHandle(TFE_Context* context,
|
||||
std::unique_ptr<ParallelTensor> t,
|
||||
TF_Status* status);
|
||||
|
||||
size_t num_tensors() const { return tensors_.size(); }
|
||||
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
|
||||
|
||||
private:
|
||||
ParallelTensor(const ParallelDevice& device,
|
||||
std::vector<TensorHandlePtr> tensors,
|
||||
std::vector<int64_t> shape, const TF_DataType dtype)
|
||||
: device_(device),
|
||||
tensors_(std::move(tensors)),
|
||||
shape_(std::move(shape)),
|
||||
dtype_(dtype) {}
|
||||
|
||||
const ParallelDevice& device_;
|
||||
const std::vector<TensorHandlePtr> tensors_;
|
||||
const std::vector<int64_t> shape_;
|
||||
const TF_DataType dtype_;
|
||||
};
|
||||
|
||||
ParallelDevice::ParallelDevice(const std::string& name,
|
||||
const std::vector<std::string>& devices)
|
||||
: device_name_(name),
|
||||
underlying_devices_(devices),
|
||||
executors_(MakeExecutors(underlying_devices_.size())) {}
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
|
||||
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
|
||||
const char* current_device = TFE_TensorHandleDeviceName(tensor, status);
|
||||
if (device_name_ == current_device) {
|
||||
std::string message(absl::StrCat(
|
||||
"Tried to copy a TensorHandle to its existing device: ", device_name_));
|
||||
TF_SetStatus(status, TF_INTERNAL, message.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<TensorHandlePtr> components;
|
||||
components.reserve(underlying_devices_.size());
|
||||
for (const std::string& underlying_device_name : underlying_devices_) {
|
||||
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
|
||||
tensor, context, underlying_device_name.c_str(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
components.emplace_back(t);
|
||||
}
|
||||
return ParallelTensor::FromTensorHandles(*this, std::move(components),
|
||||
status);
|
||||
}
|
||||
|
||||
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
|
||||
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
int expected_max_outputs, TF_Status* status) const {
|
||||
const std::string& parallel_device_name, TFE_Context* context,
|
||||
std::vector<MaybeParallelTensorUnowned> inputs, const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int expected_max_outputs,
|
||||
TF_Status* status) {
|
||||
absl::optional<std::vector<MaybeParallelTensorOwned>> result;
|
||||
// TODO(allenl): We should remove "TPU" from these op names at the very least,
|
||||
// or consider other ways of packing/unpacking parallel tensors.
|
||||
if (operation_name == std::string("TPUReplicatedInput")) {
|
||||
// Special-cased operation for packing per-device tensors into one parallel
|
||||
// tensor.
|
||||
if (inputs.size() != underlying_devices_.size()) {
|
||||
if (inputs.size() != parallel_device.num_underlying_devices()) {
|
||||
std::string message(absl::StrCat(
|
||||
"The parallel device ", device_name_, " expected ",
|
||||
underlying_devices_.size(), " inputs to TPUReplicatedInput, but got ",
|
||||
inputs.size()));
|
||||
"The parallel device ", parallel_device_name, " expected ",
|
||||
parallel_device.num_underlying_devices(),
|
||||
" inputs to TPUReplicatedInput, but got ", inputs.size()));
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
|
||||
return result;
|
||||
}
|
||||
@ -245,7 +99,7 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
|
||||
std::vector<MaybeParallelTensorOwned> result_content;
|
||||
result_content.reserve(1);
|
||||
result_content.push_back(ParallelTensor::FromTensorHandles(
|
||||
*this, std::move(components), status));
|
||||
parallel_device, std::move(components), status));
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
result.emplace(std::move(result_content));
|
||||
return result;
|
||||
@ -256,10 +110,10 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
|
||||
TFE_OpAddAttrs(op.get(), attributes);
|
||||
int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
if (expected_outputs != underlying_devices_.size()) {
|
||||
if (expected_outputs != parallel_device.num_underlying_devices()) {
|
||||
std::string message(absl::StrCat(
|
||||
"The parallel device ", device_name_, " expected ",
|
||||
underlying_devices_.size(),
|
||||
"The parallel device ", parallel_device_name, " expected ",
|
||||
parallel_device.num_underlying_devices(),
|
||||
" outputs for TPUReplicatedOutput, but got ", expected_outputs));
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
|
||||
return result;
|
||||
@ -282,10 +136,40 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
|
||||
}
|
||||
result.emplace(std::move(outputs));
|
||||
return result;
|
||||
} else if (operation_name == std::string("DeviceID")) {
|
||||
std::vector<MaybeParallelTensorOwned> result_content;
|
||||
result_content.reserve(1);
|
||||
result_content.push_back(parallel_device.DeviceIDs(context, status));
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
result.emplace(std::move(result_content));
|
||||
return result;
|
||||
}
|
||||
std::vector<ParallelTensor*> parallel_inputs;
|
||||
std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
|
||||
parallel_inputs.reserve(inputs.size());
|
||||
implicitly_broadcast_tensors.reserve(inputs.size()); // not tight
|
||||
for (const auto& input : inputs) {
|
||||
if (absl::holds_alternative<TFE_TensorHandle*>(input)) {
|
||||
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
|
||||
// to each parallel operation.
|
||||
//
|
||||
// TODO(allenl): There may be smarter ways to do this copy in some
|
||||
// cases, i.e. with a collective broadcast. We'll need to be careful
|
||||
// about things that are taken as inputs on the host or on their
|
||||
// existing device (for multi-device functions).
|
||||
std::unique_ptr<ParallelTensor> parallel_tensor(
|
||||
parallel_device.CopyToParallelDevice(
|
||||
context, absl::get<TFE_TensorHandle*>(input), status));
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
parallel_inputs.push_back(parallel_tensor.get());
|
||||
implicitly_broadcast_tensors.emplace_back(std::move(parallel_tensor));
|
||||
} else {
|
||||
parallel_inputs.push_back(absl::get<ParallelTensor*>(input));
|
||||
}
|
||||
}
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
|
||||
maybe_parallel_results(
|
||||
ExecuteParallelOperation(context, std::move(inputs), operation_name,
|
||||
parallel_device.Execute(context, parallel_inputs, operation_name,
|
||||
attributes, expected_max_outputs, status));
|
||||
if (!maybe_parallel_results.has_value()) return result;
|
||||
std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
|
||||
@ -300,144 +184,6 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
|
||||
return result;
|
||||
}
|
||||
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
|
||||
ParallelDevice::ExecuteParallelOperation(
|
||||
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
int expected_max_outputs, TF_Status* status) const {
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
|
||||
// Compute per-device per-output tensors
|
||||
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
|
||||
per_device_output_tensors.reserve(underlying_devices_.size());
|
||||
// TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
|
||||
// setting the thread-local executor like this.
|
||||
TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
|
||||
auto reset_executor = gtl::MakeCleanup([context, previous_executor]() {
|
||||
TFE_ContextSetExecutorForThread(context, previous_executor);
|
||||
TFE_DeleteExecutor(previous_executor);
|
||||
});
|
||||
int first_op_output_count;
|
||||
for (int device_index = 0; device_index < underlying_devices_.size();
|
||||
++device_index) {
|
||||
TFE_Executor* executor = executors_[device_index].get();
|
||||
// Note that the `reset_executor` cleanup sets the thread's executor back to
|
||||
// the value before this function ran.
|
||||
TFE_ContextSetExecutorForThread(context, executor);
|
||||
OpPtr op(TFE_NewOp(context, operation_name, status));
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
|
||||
status);
|
||||
TFE_OpAddAttrs(op.get(), attributes);
|
||||
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
|
||||
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
|
||||
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
|
||||
// to each parallel operation.
|
||||
//
|
||||
// TODO(allenl): There may be smarter ways to do this copy in some
|
||||
// cases, i.e. with a collective broadcast. We'll need to be careful
|
||||
// about things that are taken as inputs on the host or on their
|
||||
// existing device (for multi-device functions).
|
||||
TFE_OpAddInput(op.get(),
|
||||
absl::get<TFE_TensorHandle*>(inputs[input_index]),
|
||||
status);
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
} else {
|
||||
// Parallel tensors are divided between operations by device.
|
||||
TFE_OpAddInput(op.get(),
|
||||
absl::get<ParallelTensor*>(inputs[input_index])
|
||||
->tensor(device_index),
|
||||
status);
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
}
|
||||
}
|
||||
std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
|
||||
int real_num_outputs = expected_max_outputs;
|
||||
// For nested devices, the inner device sees the async executor we've
|
||||
// set. Inner parallel devices will just overwrite this with their own and
|
||||
// then set it back to ours before returning. This means parallel devices
|
||||
// which consist of several aliased parallel devices would hypothetically
|
||||
// deadlock if the outer parallel device ran one collective with a group
|
||||
// size equal to the total number of aliased physical devices. Currently
|
||||
// physical devices cannot participate in a single collective reduction
|
||||
// multiple times, so this would fail earlier.
|
||||
//
|
||||
// TODO(allenl): Keep a map from outer executor to list of inner executors
|
||||
// rather than a single list of executors so aliased nested parallel devices
|
||||
// don't re-use an executor.
|
||||
TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
|
||||
if (device_index == 0) {
|
||||
first_op_output_count = real_num_outputs;
|
||||
} else {
|
||||
if (real_num_outputs != first_op_output_count) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Parallel ops produced different numbers of tensors.");
|
||||
return result;
|
||||
}
|
||||
}
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
std::vector<TensorHandlePtr> this_outputs;
|
||||
this_outputs.reserve(real_num_outputs);
|
||||
for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
|
||||
this_outputs.emplace_back(op_outputs[output_num]);
|
||||
}
|
||||
per_device_output_tensors.push_back(std::move(this_outputs));
|
||||
}
|
||||
// For each output of the original operation, pack the per-device
|
||||
// TensorHandles we've computed into a single parallel TensorHandle.
|
||||
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
|
||||
per_device_outputs.reserve(first_op_output_count);
|
||||
for (int i = 0; i < first_op_output_count; ++i) {
|
||||
std::vector<TensorHandlePtr> components;
|
||||
components.reserve(underlying_devices_.size());
|
||||
for (int j = 0; j < underlying_devices_.size(); ++j) {
|
||||
components.push_back(std::move(per_device_output_tensors[j][i]));
|
||||
}
|
||||
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
|
||||
*this, std::move(components), status));
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
}
|
||||
result.emplace(std::move(per_device_outputs));
|
||||
return result;
|
||||
}
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
|
||||
const ParallelDevice& parallel_device,
|
||||
std::vector<TensorHandlePtr> components, TF_Status* status) {
|
||||
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
|
||||
std::vector<int64_t> shape(
|
||||
TFE_TensorHandleNumDims(components[0].get(), status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
|
||||
// Verify that the TensorHandle's shape and dtype match all of the component
|
||||
// shapes and dtypes.
|
||||
for (TensorHandlePtr& component : components) {
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
if (tensor_dim != shape[i]) {
|
||||
// TODO(allenl): Allow shapes to differ.
|
||||
TF_SetStatus(status, TF_UNIMPLEMENTED,
|
||||
"Components of a ParallelTensor must currently all have "
|
||||
"the same shape");
|
||||
return nullptr;
|
||||
}
|
||||
if (TFE_TensorHandleDataType(component.get()) != dtype) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Components of a ParallelTensor must all have "
|
||||
"the same dtype");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
|
||||
parallel_device, std::move(components), std::move(shape), dtype));
|
||||
}
|
||||
|
||||
// Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how
|
||||
// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
|
||||
// reference counts drop to zero.
|
||||
@ -445,17 +191,18 @@ void ParallelTensorDeallocator(void* data, size_t len, void* arg) {
|
||||
delete reinterpret_cast<ParallelTensor*>(data);
|
||||
}
|
||||
|
||||
TensorHandlePtr ParallelTensor::AsTensorHandle(
|
||||
TFE_Context* context, std::unique_ptr<ParallelTensor> t,
|
||||
TF_Status* status) {
|
||||
TensorHandlePtr ParallelTensorToTensorHandle(
|
||||
const std::string& parallel_device_name, TFE_Context* context,
|
||||
std::unique_ptr<ParallelTensor> t, TF_Status* status) {
|
||||
// The resulting TensorHandle owns an opaque pointer to "device memory", which
|
||||
// for a ParallelDevice is really a ParallelTensor. When the TensorHandle is
|
||||
// deleted, it will call ParallelTensorDeallocator to free the struct.
|
||||
ParallelTensor* t_released = t.release();
|
||||
const std::vector<int64_t>& shape(t_released->shape());
|
||||
return TensorHandlePtr(TFE_NewTensorHandleFromDeviceMemory(
|
||||
context, t_released->device_.device_name().c_str(), t_released->dtype_,
|
||||
t_released->shape_.data(), t_released->shape_.size(), t_released, 1,
|
||||
&ParallelTensorDeallocator, nullptr, status));
|
||||
context, parallel_device_name.c_str(), t_released->dtype(), shape.data(),
|
||||
shape.size(), t_released, 1, &ParallelTensorDeallocator, nullptr,
|
||||
status));
|
||||
}
|
||||
|
||||
// For TFE_CustomDevice::copy_tensor_to_device in the parallel device
|
||||
@ -471,12 +218,14 @@ TensorHandlePtr ParallelTensor::AsTensorHandle(
|
||||
TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status, void* device_info) {
|
||||
ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
|
||||
NamedParallelDevice* named_device =
|
||||
reinterpret_cast<NamedParallelDevice*>(device_info);
|
||||
const ParallelDevice& dev = named_device->device();
|
||||
std::unique_ptr<ParallelTensor> parallel_tensor(
|
||||
dev->CopyToParallelDevice(context, tensor, status));
|
||||
dev.CopyToParallelDevice(context, tensor, status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return ParallelTensor::AsTensorHandle(context, std::move(parallel_tensor),
|
||||
status)
|
||||
return ParallelTensorToTensorHandle(named_device->name(), context,
|
||||
std::move(parallel_tensor), status)
|
||||
.release();
|
||||
}
|
||||
|
||||
@ -510,14 +259,15 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
|
||||
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* status,
|
||||
void* device_info) {
|
||||
ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
|
||||
NamedParallelDevice* named_device =
|
||||
reinterpret_cast<NamedParallelDevice*>(device_info);
|
||||
std::vector<MaybeParallelTensorUnowned> typed_inputs;
|
||||
typed_inputs.reserve(num_inputs);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
const char* tensor_handle_device =
|
||||
TFE_TensorHandleDeviceName(inputs[i], status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
if (dev->device_name() == tensor_handle_device) {
|
||||
if (named_device->name() == tensor_handle_device) {
|
||||
// We assume that any tensors already placed on this device are
|
||||
// ParallelTensors.
|
||||
typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
|
||||
@ -529,8 +279,9 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
|
||||
}
|
||||
|
||||
absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
|
||||
dev->Execute(context, std::move(typed_inputs), operation_name, attributes,
|
||||
*num_outputs, status));
|
||||
ExecuteWithSpecialOps(named_device->device(), named_device->name(),
|
||||
context, std::move(typed_inputs), operation_name,
|
||||
attributes, *num_outputs, status));
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
if (!maybe_typed_outputs.has_value()) {
|
||||
TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
|
||||
@ -551,8 +302,8 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
|
||||
if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
|
||||
outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
|
||||
} else {
|
||||
outputs[i] = ParallelTensor::AsTensorHandle(
|
||||
context,
|
||||
outputs[i] = ParallelTensorToTensorHandle(
|
||||
named_device->name(), context,
|
||||
std::move(absl::get<std::unique_ptr<ParallelTensor>>(
|
||||
typed_output)),
|
||||
status)
|
||||
@ -569,7 +320,7 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
|
||||
// device_info is passed in using a C-style generic. It must always be a
|
||||
// ParallelDevice.
|
||||
void DeleteParallelDevice(void* device_info) {
|
||||
delete reinterpret_cast<ParallelDevice*>(device_info);
|
||||
delete reinterpret_cast<NamedParallelDevice*>(device_info);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -588,8 +339,10 @@ void AllocateParallelDevice(const char* device_name,
|
||||
++device_index) {
|
||||
underlying_devices_vector.push_back(underlying_devices[device_index]);
|
||||
}
|
||||
*device_info = new ParallelDevice(device_name, underlying_devices_vector);
|
||||
std::unique_ptr<ParallelDevice> parallel_device(
|
||||
new ParallelDevice(underlying_devices_vector));
|
||||
*device_info =
|
||||
new NamedParallelDevice{device_name, std::move(parallel_device)};
|
||||
}
|
||||
|
||||
} // namespace eager
|
||||
} // namespace parallel_device
|
||||
} // namespace tensorflow
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace eager {
|
||||
namespace parallel_device {
|
||||
|
||||
// Allocate a parallel device named `device_name` which forwards operations to
|
||||
// `underlying_devices`, maintaining "parallel tensors" with components placed
|
||||
@ -59,7 +59,7 @@ void AllocateParallelDevice(const char* device_name,
|
||||
int num_underlying_devices,
|
||||
TFE_CustomDevice* device, void** device_info);
|
||||
|
||||
} // namespace eager
|
||||
} // namespace parallel_device
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
|
||||
|
235
tensorflow/c/eager/parallel_device/parallel_device_lib.cc
Normal file
235
tensorflow/c/eager/parallel_device/parallel_device_lib.cc
Normal file
@ -0,0 +1,235 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
|
||||
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace parallel_device {
|
||||
namespace {
|
||||
|
||||
class OpDeleter {
|
||||
public:
|
||||
void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
|
||||
};
|
||||
|
||||
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
|
||||
|
||||
// Creates a vector of `count` new executors (threads).
|
||||
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
|
||||
std::vector<ExecutorPtr> executors;
|
||||
executors.reserve(count);
|
||||
for (int i = 0; i < count; ++i) {
|
||||
executors.emplace_back(TFE_NewExecutor(true /* is_async */));
|
||||
}
|
||||
return executors;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
|
||||
: underlying_devices_(devices),
|
||||
executors_(MakeExecutors(underlying_devices_.size())) {}
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
|
||||
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
|
||||
std::vector<TensorHandlePtr> components;
|
||||
components.reserve(underlying_devices_.size());
|
||||
for (const std::string& underlying_device_name : underlying_devices_) {
|
||||
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
|
||||
tensor, context, underlying_device_name.c_str(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
components.emplace_back(t);
|
||||
}
|
||||
return ParallelTensor::FromTensorHandles(*this, std::move(components),
|
||||
status);
|
||||
}
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
|
||||
TFE_Context* context, TF_Status* status) const {
|
||||
// TODO(allenl): We could cache DeviceIDs (keyed by context).
|
||||
std::vector<TensorHandlePtr> components;
|
||||
components.reserve(underlying_devices_.size());
|
||||
for (int device_index = 0; device_index < underlying_devices_.size();
|
||||
++device_index) {
|
||||
int64_t* device_id = new int64_t;
|
||||
*device_id = device_index;
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||
TF_NewTensor(
|
||||
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
|
||||
sizeof(int64_t),
|
||||
[](void* data, size_t, void* arg) {
|
||||
delete reinterpret_cast<int64_t*>(data);
|
||||
},
|
||||
nullptr),
|
||||
TF_DeleteTensor);
|
||||
// TODO(allenl): Here and when executing regular operations, we could hold
|
||||
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
|
||||
// device names repeatedly.
|
||||
OpPtr const_op(TFE_NewOp(context, "Const", status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
|
||||
status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
|
||||
TFE_TensorHandle* device_handle;
|
||||
int num_outputs = 1;
|
||||
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
components.emplace_back(device_handle);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
return ParallelTensor::FromTensorHandles(*this, std::move(components),
|
||||
status);
|
||||
}
|
||||
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
|
||||
ParallelDevice::Execute(TFE_Context* context,
|
||||
const std::vector<ParallelTensor*>& inputs,
|
||||
const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int expected_max_outputs,
|
||||
TF_Status* status) const {
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
|
||||
// Compute per-device per-output tensors
|
||||
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
|
||||
per_device_output_tensors.reserve(underlying_devices_.size());
|
||||
// TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
|
||||
// setting the thread-local executor like this.
|
||||
TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
|
||||
auto reset_executor =
|
||||
tensorflow::gtl::MakeCleanup([context, previous_executor]() {
|
||||
TFE_ContextSetExecutorForThread(context, previous_executor);
|
||||
TFE_DeleteExecutor(previous_executor);
|
||||
});
|
||||
int first_op_output_count;
|
||||
for (int device_index = 0; device_index < underlying_devices_.size();
|
||||
++device_index) {
|
||||
TFE_Executor* executor = executors_[device_index].get();
|
||||
// Note that the `reset_executor` cleanup sets the thread's executor back to
|
||||
// the value before this function ran.
|
||||
TFE_ContextSetExecutorForThread(context, executor);
|
||||
OpPtr op(TFE_NewOp(context, operation_name, status));
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
|
||||
status);
|
||||
TFE_OpAddAttrs(op.get(), attributes);
|
||||
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
|
||||
// Parallel tensors are divided between operations by device.
|
||||
TFE_OpAddInput(op.get(), inputs[input_index]->tensor(device_index),
|
||||
status);
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
}
|
||||
std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
|
||||
int real_num_outputs = expected_max_outputs;
|
||||
// For nested devices, the inner device sees the async executor we've
|
||||
// set. Inner parallel devices will just overwrite this with their own and
|
||||
// then set it back to ours before returning. This means parallel devices
|
||||
// which consist of several aliased parallel devices would hypothetically
|
||||
// deadlock if the outer parallel device ran one collective with a group
|
||||
// size equal to the total number of aliased physical devices. Currently
|
||||
// physical devices cannot participate in a single collective reduction
|
||||
// multiple times, so this would fail earlier.
|
||||
//
|
||||
// TODO(allenl): Keep a map from outer executor to list of inner executors
|
||||
// rather than a single list of executors so aliased nested parallel devices
|
||||
// don't re-use an executor.
|
||||
TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
|
||||
if (device_index == 0) {
|
||||
first_op_output_count = real_num_outputs;
|
||||
} else {
|
||||
if (real_num_outputs != first_op_output_count) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Parallel ops produced different numbers of tensors.");
|
||||
return result;
|
||||
}
|
||||
}
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
std::vector<TensorHandlePtr> this_outputs;
|
||||
this_outputs.reserve(real_num_outputs);
|
||||
for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
|
||||
this_outputs.emplace_back(op_outputs[output_num]);
|
||||
}
|
||||
per_device_output_tensors.push_back(std::move(this_outputs));
|
||||
}
|
||||
for (int device_index = 0; device_index < underlying_devices_.size();
|
||||
++device_index) {
|
||||
TFE_Executor* executor = executors_[device_index].get();
|
||||
// TODO(b/157523095): Syncing the executor here shouldn't be
|
||||
// necessary. Currently async+remote is missing cross-executor
|
||||
// coordination.
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
}
|
||||
// For each output of the original operation, pack the per-device
|
||||
// TensorHandles we've computed into a single parallel TensorHandle.
|
||||
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
|
||||
per_device_outputs.reserve(first_op_output_count);
|
||||
for (int i = 0; i < first_op_output_count; ++i) {
|
||||
std::vector<TensorHandlePtr> components;
|
||||
components.reserve(underlying_devices_.size());
|
||||
for (int j = 0; j < underlying_devices_.size(); ++j) {
|
||||
components.push_back(std::move(per_device_output_tensors[j][i]));
|
||||
}
|
||||
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
|
||||
*this, std::move(components), status));
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
}
|
||||
result.emplace(std::move(per_device_outputs));
|
||||
return result;
|
||||
}
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
|
||||
const ParallelDevice& parallel_device,
|
||||
std::vector<TensorHandlePtr> components, TF_Status* status) {
|
||||
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
|
||||
std::vector<int64_t> shape(
|
||||
TFE_TensorHandleNumDims(components[0].get(), status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
|
||||
// Verify that the TensorHandle's shape and dtype match all of the component
|
||||
// shapes and dtypes.
|
||||
for (TensorHandlePtr& component : components) {
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
if (tensor_dim != shape[i]) {
|
||||
// TODO(allenl): Allow shapes to differ.
|
||||
TF_SetStatus(status, TF_UNIMPLEMENTED,
|
||||
"Components of a ParallelTensor must currently all have "
|
||||
"the same shape");
|
||||
return nullptr;
|
||||
}
|
||||
if (TFE_TensorHandleDataType(component.get()) != dtype) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Components of a ParallelTensor must all have "
|
||||
"the same dtype");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
|
||||
parallel_device, std::move(components), std::move(shape), dtype));
|
||||
}
|
||||
|
||||
} // namespace parallel_device
|
||||
} // namespace tensorflow
|
137
tensorflow/c/eager/parallel_device/parallel_device_lib.h
Normal file
137
tensorflow/c/eager/parallel_device/parallel_device_lib.h
Normal file
@ -0,0 +1,137 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
|
||||
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace parallel_device {
|
||||
|
||||
// Functor for making unique_ptrs slightly more ergonomic. Using
|
||||
// decltype(delete_fn) in the unique_ptr's second template argument requires
|
||||
// passing a function pointer to delete_fn when constructing the unique_ptr.
|
||||
class TensorHandleDeleter {
|
||||
public:
|
||||
void operator()(TFE_TensorHandle* to_delete) const {
|
||||
TFE_DeleteTensorHandle(to_delete);
|
||||
}
|
||||
};
|
||||
|
||||
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
|
||||
|
||||
class ExecutorDeleter {
|
||||
public:
|
||||
void operator()(TFE_Executor* to_delete) const {
|
||||
TFE_DeleteExecutor(to_delete);
|
||||
}
|
||||
};
|
||||
|
||||
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
|
||||
|
||||
class ParallelTensor;
|
||||
|
||||
// Forwards operations to `devices`, maintaining ParallelTensor with components
|
||||
// placed on each underlying device.
|
||||
class ParallelDevice {
|
||||
public:
|
||||
explicit ParallelDevice(const std::vector<std::string>& devices);
|
||||
|
||||
// Helper to copy a tensor handle from another device once for each component
|
||||
// of the ParallelDevice.
|
||||
//
|
||||
// Sets a bad status and returns a nullptr if `tensor` is already on the
|
||||
// ParallelDevice, or if the individual copies fail.
|
||||
std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status) const;
|
||||
|
||||
// A parallel tensor with scalar integers numbering component devices.
|
||||
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
|
||||
TF_Status* status) const;
|
||||
|
||||
// The number of devices operations run on.
|
||||
size_t num_underlying_devices() const { return underlying_devices_.size(); }
|
||||
|
||||
// Takes a description of a single operation being executed on the
|
||||
// ParallelDevice, and in turn runs one operation per component device with
|
||||
// its corresponding inputs from the input ParallelTensors. Wraps the
|
||||
// resulting per-device and per-output TFE_TensorHandles into one
|
||||
// ParallelTensor per output of the original operation.
|
||||
//
|
||||
// Attributes are forwarded to executed operations unmodified.
|
||||
//
|
||||
// The returned optional has a value if and only if `status` evaluates to
|
||||
// TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
|
||||
// if sanity checks on dtypes/metadata fail.
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
|
||||
TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
int expected_max_outputs, TF_Status* status) const;
|
||||
|
||||
private:
|
||||
// A sequence of device names, indicating which devices replicated operations
|
||||
// are forwarded to.
|
||||
const std::vector<std::string> underlying_devices_;
|
||||
// A sequence of TFE_Executors, one per device, for executing operations in
|
||||
// parallel.
|
||||
const std::vector<ExecutorPtr> executors_;
|
||||
};
|
||||
|
||||
// Contains a tuple of tensors, one on each of the `underlying_devices_` of the
|
||||
// ParallelDevice.
|
||||
class ParallelTensor {
|
||||
public:
|
||||
// Construct a ParallelTensor from TensorHandles placed on the component
|
||||
// devices of a ParallelDevice.
|
||||
static std::unique_ptr<ParallelTensor> FromTensorHandles(
|
||||
const ParallelDevice& parallel_device,
|
||||
std::vector<TensorHandlePtr> components, TF_Status* status);
|
||||
|
||||
size_t num_tensors() const { return tensors_.size(); }
|
||||
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
|
||||
|
||||
// A generalization of the shapes of the underlying tensors.
|
||||
const std::vector<int64_t>& shape() const { return shape_; }
|
||||
TF_DataType dtype() const { return dtype_; }
|
||||
|
||||
private:
|
||||
ParallelTensor(const ParallelDevice& device,
|
||||
std::vector<TensorHandlePtr> tensors,
|
||||
std::vector<int64_t> shape, const TF_DataType dtype)
|
||||
: device_(device),
|
||||
tensors_(std::move(tensors)),
|
||||
shape_(std::move(shape)),
|
||||
dtype_(dtype) {}
|
||||
|
||||
const ParallelDevice& device_;
|
||||
const std::vector<TensorHandlePtr> tensors_;
|
||||
const std::vector<int64_t> shape_;
|
||||
const TF_DataType dtype_;
|
||||
};
|
||||
|
||||
} // namespace parallel_device
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
|
26
tensorflow/c/eager/parallel_device/parallel_device_ops.cc
Normal file
26
tensorflow/c/eager/parallel_device/parallel_device_ops.cc
Normal file
@ -0,0 +1,26 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
||||
// TODO(allenl): Figure out if we need this op, and if so whether we should move
|
||||
// it to core TF. Right now the eager C API does some checking of op
|
||||
// registrations before calling into custom devices, but we may be able to avoid
|
||||
// that.
|
||||
REGISTER_OP("DeviceID")
|
||||
.Output("device_id: int64")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(tensorflow::shape_inference::ScalarShape);
|
@ -0,0 +1,147 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
|
||||
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
tensorflow::ServerDef GetServerDef(const std::string& job_name, int num_tasks) {
|
||||
tensorflow::ServerDef server_def;
|
||||
server_def.set_protocol("grpc");
|
||||
server_def.set_job_name(job_name);
|
||||
server_def.set_task_index(0);
|
||||
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
||||
tensorflow::JobDef* job_def = cluster_def->add_job();
|
||||
job_def->set_name(job_name);
|
||||
for (int i = 0; i < num_tasks; i++) {
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{i, tensorflow::strings::StrCat("localhost", ":", port)});
|
||||
}
|
||||
return server_def;
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestRemoteBasic) {
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
tensorflow::ServerDef server_def = GetServerDef("worker", 3);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
std::string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server1->Start().ok());
|
||||
|
||||
server_def.set_task_index(2);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server2->Start().ok());
|
||||
|
||||
TFE_ContextSetServerDef(context.get(), 0, serialized.data(),
|
||||
serialized.size(), status.get());
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
BasicTestsForTwoDevices(context.get(),
|
||||
"/job:worker/replica:0/task:1/device:CPU:0",
|
||||
"/job:worker/replica:0/task:2/device:CPU:0");
|
||||
|
||||
worker_server1.release();
|
||||
worker_server2.release();
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestAsyncCopyOff) {
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
tensorflow::ServerDef server_def = GetServerDef("worker", 3);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
std::string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server1->Start().ok());
|
||||
|
||||
server_def.set_task_index(2);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server2->Start().ok());
|
||||
|
||||
TFE_ContextSetServerDef(context.get(), 0, serialized.data(),
|
||||
serialized.size(), status.get());
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
const char* first_device = "/job:worker/replica:0/task:1/device:CPU:0";
|
||||
const char* second_device = "/job:worker/replica:0/task:2/device:CPU:0";
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::array<const char*, 2> underlying_devices{first_device, second_device};
|
||||
RegisterParallelDevice(context.get(), device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::array<TFE_TensorHandle*, 2> in_components{value_one.get(),
|
||||
value_two.get()};
|
||||
TensorHandlePtr combined_value = CreatePerDeviceValues(
|
||||
context.get(), in_components, device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Loop to make synchronization failures more deterministic
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
TensorHandlePtr multiply_result(
|
||||
Multiply(context.get(), combined_value.get(), combined_value.get(),
|
||||
status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::array<TensorHandlePtr, 2> out_components;
|
||||
ExtractPerDeviceValues(context.get(), multiply_result.get(),
|
||||
&out_components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
ExpectScalarEq<float>(out_components[0].get(), 9.);
|
||||
ExpectScalarEq<float>(out_components[1].get(), 4.);
|
||||
}
|
||||
|
||||
worker_server1.release();
|
||||
worker_server2.release();
|
||||
}
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
// NOTE(allenl): These tests currently go through TFE_Execute and so are
|
||||
@ -28,363 +29,6 @@ limitations under the License.
|
||||
// correspond fairly well to the implementation, but testing the C++ directly is
|
||||
// another option.
|
||||
|
||||
// Functor for making unique_ptr to TFE_TensorHandle slightly more
|
||||
// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
|
||||
// template argument requires passing a function pointer to
|
||||
// TFE_DeleteTensorHandle when constructing the unique_ptr.
|
||||
class TensorHandleDeleter {
|
||||
public:
|
||||
void operator()(TFE_TensorHandle* to_delete) {
|
||||
TFE_DeleteTensorHandle(to_delete);
|
||||
}
|
||||
};
|
||||
|
||||
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
|
||||
|
||||
// A helper for performing common operations on variables. A much more
|
||||
// restricted stand-in for tf.Variable in Python.
|
||||
class Variable {
|
||||
public:
|
||||
// Construct a Variable from a resource-dtype TFE_TensorHandle and an
|
||||
// indication of the dtype of the variable's value.
|
||||
//
|
||||
// Note that creating this resource-dtype handle can fail, so `Create` is a
|
||||
// separate static method which returns a status.
|
||||
Variable(TFE_TensorHandle* handle, TF_DataType type)
|
||||
: handle_(handle), type_(type) {}
|
||||
|
||||
// Helper for constructing a resource handle and wrapping it in a `Variable`
|
||||
// object.
|
||||
static Variable* Create(TFE_Context* context, TF_DataType type,
|
||||
const int64_t* dims, const int num_dims,
|
||||
const char* device, TF_Status* status);
|
||||
// Dereferences the backing buffer for the variable. Note that since this can
|
||||
// fail (it runs operations), it must be called explicitly and the resulting
|
||||
// `status` checked.
|
||||
void Destroy(TFE_Context* context, TF_Status* status);
|
||||
|
||||
// Reads from the variable.
|
||||
TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
|
||||
// Assigns a new value to the variable.
|
||||
void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
|
||||
// Adds `value` to the existing value of the variable.
|
||||
void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
|
||||
TF_Status* status);
|
||||
|
||||
private:
|
||||
// Helper for running any single-argument assignment ops (Assign, AssignAdd,
|
||||
// AssignSub, ...).
|
||||
void GeneralAssignment(const char* op_name, TFE_Context* context,
|
||||
TFE_TensorHandle* value, TF_Status* status);
|
||||
|
||||
// The a handle for the resource-dtype tensor pointing to the variable's
|
||||
// buffer.
|
||||
TFE_TensorHandle* handle_;
|
||||
// The dtype of the variable's buffer (input dtype for assignments, output
|
||||
// dtype of read operations).
|
||||
TF_DataType type_;
|
||||
};
|
||||
|
||||
Variable* Variable::Create(TFE_Context* context, TF_DataType type,
|
||||
const int64_t* dims, const int num_dims,
|
||||
const char* device, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op.get(), "dtype", type);
|
||||
TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
|
||||
TFE_OpSetAttrString(op.get(), "container", "", 0);
|
||||
// Use the special GUID for no buffer sharing
|
||||
//
|
||||
// TODO(allenl): Should we provide a better API for this? AFAIK this is the
|
||||
// only reasonable way to make variables with no aliasing using the eager C
|
||||
// API.
|
||||
std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
|
||||
TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
|
||||
no_sharing.length());
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &var_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return new Variable(var_handle, type);
|
||||
}
|
||||
|
||||
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
|
||||
// Free the backing buffer for the variable.
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpAddInput(op.get(), handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
// Delete the variable handle itself.
|
||||
TFE_DeleteTensorHandle(handle_);
|
||||
}
|
||||
|
||||
TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpAddInput(op.get(), handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op.get(), "dtype", type_);
|
||||
int num_retvals = 1;
|
||||
TFE_TensorHandle* var_value = nullptr;
|
||||
TFE_Execute(op.get(), &var_value, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return TensorHandlePtr(var_value);
|
||||
}
|
||||
|
||||
void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
|
||||
TFE_TensorHandle* value, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetAttrType(op.get(), "dtype", type_);
|
||||
TFE_OpAddInput(op.get(), handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpAddInput(op.get(), value, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
}
|
||||
|
||||
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
|
||||
TF_Status* status) {
|
||||
GeneralAssignment("AssignAddVariableOp", context, value, status);
|
||||
}
|
||||
|
||||
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
|
||||
TF_Status* status) {
|
||||
GeneralAssignment("AssignVariableOp", context, value, status);
|
||||
}
|
||||
|
||||
// Passed to `TF_NewTensor` to indicate how an array of floats should be
|
||||
// deleted.
|
||||
static void FloatDeallocator(void* data, size_t, void* arg) {
|
||||
delete[] static_cast<float*>(data);
|
||||
}
|
||||
|
||||
// Creates a TFE_TensorHandle with value `v`.
|
||||
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
|
||||
const int num_bytes = sizeof(float);
|
||||
float* values = new float[1];
|
||||
values[0] = v;
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||
TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
|
||||
nullptr),
|
||||
TF_DeleteTensor);
|
||||
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
|
||||
}
|
||||
|
||||
// Creates a rank-one TFE_TensorHandle with value `v`.
|
||||
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
|
||||
TF_Status* status) {
|
||||
const int num_bytes = v.size() * sizeof(float);
|
||||
float* values = new float[v.size()];
|
||||
memcpy(values, v.data(), num_bytes);
|
||||
int64_t dims = v.size();
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||
TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
|
||||
&FloatDeallocator, nullptr),
|
||||
TF_DeleteTensor);
|
||||
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
|
||||
}
|
||||
|
||||
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
|
||||
template <std::size_t num_replicas>
|
||||
void ExtractPerDeviceValues(
|
||||
TFE_Context* context, TFE_TensorHandle* input,
|
||||
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
|
||||
TFE_OpAddInput(op.get(), input, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* device = TFE_TensorHandleDeviceName(input, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
TFE_TensorHandle* result_handles[num_replicas];
|
||||
int num_retvals = num_replicas;
|
||||
TFE_Execute(op.get(), result_handles, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
for (int i = 0; i < num_replicas; ++i) {
|
||||
(*components)[i].reset(result_handles[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
|
||||
template <std::size_t num_replicas>
|
||||
TensorHandlePtr CreatePerDeviceValues(
|
||||
TFE_Context* context,
|
||||
const std::array<TFE_TensorHandle*, num_replicas>& components,
|
||||
const char* device, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrInt(op.get(), "N", num_replicas);
|
||||
for (int i = 0; i < num_replicas; ++i) {
|
||||
TFE_OpAddInput(op.get(), components[i], status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
TFE_TensorHandle* result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return TensorHandlePtr(result_handle);
|
||||
}
|
||||
|
||||
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
|
||||
TFE_TensorHandle* second, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpAddInput(op.get(), first, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpAddInput(op.get(), second, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
const char* first_device = TFE_TensorHandleDeviceName(first, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetDevice(op.get(), first_device, status);
|
||||
|
||||
TFE_TensorHandle* result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return TensorHandlePtr(result_handle);
|
||||
}
|
||||
|
||||
// Assert that `handle` is equal to `expected_value`.
|
||||
void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
|
||||
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_EQ(expected_value,
|
||||
*static_cast<float*>(TF_TensorData(value_zero.get())));
|
||||
}
|
||||
|
||||
template <std::size_t num_devices>
|
||||
void RegisterParallelDevice(
|
||||
TFE_Context* context, const char* device_name,
|
||||
const std::array<const char*, num_devices>& underlying_devices,
|
||||
TF_Status* status) {
|
||||
TFE_CustomDevice device;
|
||||
void* device_info;
|
||||
tensorflow::eager::AllocateParallelDevice(
|
||||
device_name, underlying_devices.data(), underlying_devices.size(),
|
||||
&device, &device_info);
|
||||
TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
|
||||
}
|
||||
|
||||
// Create and modify a variable placed on a parallel device which composes
|
||||
// `first_device` and `second_device`.
|
||||
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
const char* second_device) {
|
||||
// Register the custom device
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::array<const char*, 2> underlying_devices{first_device, second_device};
|
||||
RegisterParallelDevice(context, device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a variable handle (uninitialized to start) placed on the parallel
|
||||
// device.
|
||||
std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
|
||||
to_delete->Destroy(context, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
delete to_delete;
|
||||
};
|
||||
std::unique_ptr<Variable, decltype(variable_deleter)> variable(
|
||||
Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
|
||||
status.get()),
|
||||
variable_deleter);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Assign an initial value to the variable, implicitly mirroring it to each
|
||||
// component device.
|
||||
{
|
||||
TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
variable->Assign(context, initial_value.get(), status.get());
|
||||
}
|
||||
|
||||
// Read from the variable and verify that we have a parallel tensor.
|
||||
{
|
||||
TensorHandlePtr read = variable->Read(context, status.get());
|
||||
std::array<TensorHandlePtr, 2> components;
|
||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
AssertScalarFloatEq(components[0].get(), 20.);
|
||||
AssertScalarFloatEq(components[1].get(), 20.);
|
||||
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
|
||||
// Add a parallel tensor with different values on each device to the variable.
|
||||
{
|
||||
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
|
||||
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
|
||||
std::array<TFE_TensorHandle*, 2> components{value_one.get(),
|
||||
value_two.get()};
|
||||
TensorHandlePtr combined_value =
|
||||
CreatePerDeviceValues(context, components, device_name, status.get());
|
||||
variable->AssignAdd(context, combined_value.get(), status.get());
|
||||
}
|
||||
|
||||
// Read the variable and verify that each component has the right modified
|
||||
// value.
|
||||
{
|
||||
TensorHandlePtr read = variable->Read(context, status.get());
|
||||
std::array<TensorHandlePtr, 2> components;
|
||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
AssertScalarFloatEq(components[0].get(), 23.);
|
||||
AssertScalarFloatEq(components[1].get(), 18.);
|
||||
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestBasicCPU) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
@ -498,8 +142,8 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// The value of the original tensor is replicated on each device.
|
||||
AssertScalarFloatEq(components[0].get(), 3.);
|
||||
AssertScalarFloatEq(components[1].get(), 3.);
|
||||
ExpectScalarEq<float>(components[0].get(), 3.);
|
||||
ExpectScalarEq<float>(components[1].get(), 3.);
|
||||
|
||||
// Verify that the mirrors are placed on the component devices.
|
||||
std::string first_device =
|
||||
@ -630,7 +274,7 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
|
||||
&second_components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
AssertScalarFloatEq(second_components[1].get(), 9.);
|
||||
ExpectScalarEq<float>(second_components[1].get(), 9.);
|
||||
|
||||
// Verify that the mirrors are placed on the component devices.
|
||||
std::string first_device = TFE_TensorHandleBackingDeviceName(
|
||||
@ -644,8 +288,8 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
|
||||
std::array<TensorHandlePtr, 2> first_components;
|
||||
ExtractPerDeviceValues(context.get(), second_components[0].get(),
|
||||
&first_components, status.get());
|
||||
AssertScalarFloatEq(first_components[0].get(), 3.);
|
||||
AssertScalarFloatEq(first_components[1].get(), 6.);
|
||||
ExpectScalarEq<float>(first_components[0].get(), 3.);
|
||||
ExpectScalarEq<float>(first_components[1].get(), 6.);
|
||||
|
||||
first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
|
||||
status.get());
|
||||
@ -806,8 +450,8 @@ TEST(PARALLEL_DEVICE, TestCollective) {
|
||||
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
AssertScalarFloatEq(result_components[0].get(), 3.);
|
||||
AssertScalarFloatEq(result_components[1].get(), 3.);
|
||||
ExpectScalarEq<float>(result_components[0].get(), 3.);
|
||||
ExpectScalarEq<float>(result_components[1].get(), 3.);
|
||||
}
|
||||
|
||||
void RegisterCollectiveMulFunction(TFE_Context* context,
|
||||
@ -909,8 +553,8 @@ TEST(PARALLEL_DEVICE, TestFunction) {
|
||||
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
AssertScalarFloatEq(result_components[0].get(), 7. * 9.);
|
||||
AssertScalarFloatEq(result_components[1].get(), 7. * 9.);
|
||||
ExpectScalarEq<float>(result_components[0].get(), 7. * 9.);
|
||||
ExpectScalarEq<float>(result_components[1].get(), 7. * 9.);
|
||||
|
||||
std::string first_device = TFE_TensorHandleBackingDeviceName(
|
||||
result_components[0].get(), status.get());
|
||||
|
308
tensorflow/c/eager/parallel_device/parallel_device_testlib.cc
Normal file
308
tensorflow/c/eager/parallel_device/parallel_device_testlib.cc
Normal file
@ -0,0 +1,308 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
// NOTE(allenl): These tests currently go through TFE_Execute and so are
|
||||
// integration testing rather than purely testing the parallel device. They
|
||||
// correspond fairly well to the implementation, but testing the C++ directly is
|
||||
// another option.
|
||||
|
||||
|
||||
Variable* Variable::Create(TFE_Context* context, TF_DataType type,
|
||||
const int64_t* dims, const int num_dims,
|
||||
const char* device, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op.get(), "dtype", type);
|
||||
TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
|
||||
TFE_OpSetAttrString(op.get(), "container", "", 0);
|
||||
// Use the special GUID for no buffer sharing
|
||||
//
|
||||
// TODO(allenl): Should we provide a better API for this? AFAIK this is the
|
||||
// only reasonable way to make variables with no aliasing using the eager C
|
||||
// API.
|
||||
std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
|
||||
TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
|
||||
no_sharing.length());
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &var_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return new Variable(var_handle, type);
|
||||
}
|
||||
|
||||
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
|
||||
// Free the backing buffer for the variable.
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpAddInput(op.get(), handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
// Delete the variable handle itself.
|
||||
TFE_DeleteTensorHandle(handle_);
|
||||
}
|
||||
|
||||
TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpAddInput(op.get(), handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op.get(), "dtype", type_);
|
||||
int num_retvals = 1;
|
||||
TFE_TensorHandle* var_value = nullptr;
|
||||
TFE_Execute(op.get(), &var_value, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return TensorHandlePtr(var_value);
|
||||
}
|
||||
|
||||
void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
|
||||
TFE_TensorHandle* value, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetAttrType(op.get(), "dtype", type_);
|
||||
TFE_OpAddInput(op.get(), handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpAddInput(op.get(), value, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
}
|
||||
|
||||
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
|
||||
TF_Status* status) {
|
||||
GeneralAssignment("AssignAddVariableOp", context, value, status);
|
||||
}
|
||||
|
||||
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
|
||||
TF_Status* status) {
|
||||
GeneralAssignment("AssignVariableOp", context, value, status);
|
||||
}
|
||||
|
||||
// Passed to `TF_NewTensor` to indicate how an array of floats should be
|
||||
// deleted.
|
||||
static void FloatDeallocator(void* data, size_t, void* arg) {
|
||||
delete[] static_cast<float*>(data);
|
||||
}
|
||||
|
||||
// Creates a TFE_TensorHandle with value `v`.
|
||||
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
|
||||
const int num_bytes = sizeof(float);
|
||||
float* values = new float[1];
|
||||
values[0] = v;
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||
TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
|
||||
nullptr),
|
||||
TF_DeleteTensor);
|
||||
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
|
||||
}
|
||||
|
||||
// Creates a rank-one TFE_TensorHandle with value `v`.
|
||||
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
|
||||
TF_Status* status) {
|
||||
const int num_bytes = v.size() * sizeof(float);
|
||||
float* values = new float[v.size()];
|
||||
memcpy(values, v.data(), num_bytes);
|
||||
int64_t dims = v.size();
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||
TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
|
||||
&FloatDeallocator, nullptr),
|
||||
TF_DeleteTensor);
|
||||
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
|
||||
}
|
||||
|
||||
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
|
||||
template <std::size_t num_replicas>
|
||||
void ExtractPerDeviceValues(
|
||||
TFE_Context* context, TFE_TensorHandle* input,
|
||||
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
|
||||
TFE_OpAddInput(op.get(), input, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* device = TFE_TensorHandleDeviceName(input, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
TFE_TensorHandle* result_handles[num_replicas];
|
||||
int num_retvals = num_replicas;
|
||||
TFE_Execute(op.get(), result_handles, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
for (int i = 0; i < num_replicas; ++i) {
|
||||
(*components)[i].reset(result_handles[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
|
||||
TFE_TensorHandle* second, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpAddInput(op.get(), first, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpAddInput(op.get(), second, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
const char* first_device = TFE_TensorHandleDeviceName(first, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetDevice(op.get(), first_device, status);
|
||||
|
||||
TFE_TensorHandle* result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return TensorHandlePtr(result_handle);
|
||||
}
|
||||
|
||||
// Create and modify a variable placed on a parallel device which composes
|
||||
// `first_device` and `second_device`.
|
||||
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
const char* second_device) {
|
||||
// Register the custom device
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::array<const char*, 2> underlying_devices{first_device, second_device};
|
||||
RegisterParallelDevice(context, device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a variable handle (uninitialized to start) placed on the parallel
|
||||
// device.
|
||||
std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
|
||||
to_delete->Destroy(context, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
delete to_delete;
|
||||
};
|
||||
std::unique_ptr<Variable, decltype(variable_deleter)> variable(
|
||||
Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
|
||||
status.get()),
|
||||
variable_deleter);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Assign an initial value to the variable, implicitly mirroring it to each
|
||||
// component device.
|
||||
{
|
||||
TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
variable->Assign(context, initial_value.get(), status.get());
|
||||
}
|
||||
|
||||
// Read from the variable and verify that we have a parallel tensor.
|
||||
{
|
||||
TensorHandlePtr read = variable->Read(context, status.get());
|
||||
std::array<TensorHandlePtr, 2> components;
|
||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
ExpectScalarEq<float>(components[0].get(), 20.);
|
||||
ExpectScalarEq<float>(components[1].get(), 20.);
|
||||
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
|
||||
// Add a parallel tensor with different values on each device to the variable.
|
||||
{
|
||||
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
|
||||
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
|
||||
std::array<TFE_TensorHandle*, 2> components{value_one.get(),
|
||||
value_two.get()};
|
||||
TensorHandlePtr combined_value =
|
||||
CreatePerDeviceValues(context, components, device_name, status.get());
|
||||
variable->AssignAdd(context, combined_value.get(), status.get());
|
||||
}
|
||||
|
||||
// Read the variable and verify that each component has the right modified
|
||||
// value.
|
||||
{
|
||||
TensorHandlePtr read = variable->Read(context, status.get());
|
||||
std::array<TensorHandlePtr, 2> components;
|
||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
ExpectScalarEq<float>(components[0].get(), 23.);
|
||||
ExpectScalarEq<float>(components[1].get(), 18.);
|
||||
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
// Compute the device ID twice and verify the result
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetDevice(op.get(), device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::array<TensorHandlePtr, 2> components;
|
||||
ExtractPerDeviceValues(context, result_handle, &components, status.get());
|
||||
TFE_DeleteTensorHandle(result_handle);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
ExpectScalarEq<int64_t>(components[0].get(), 0);
|
||||
ExpectScalarEq<int64_t>(components[1].get(), 1);
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
}
|
174
tensorflow/c/eager/parallel_device/parallel_device_testlib.h
Normal file
174
tensorflow/c/eager/parallel_device/parallel_device_testlib.h
Normal file
@ -0,0 +1,174 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
|
||||
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
|
||||
|
||||
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
|
||||
// Functor for making unique_ptr to TFE_TensorHandle slightly more
|
||||
// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
|
||||
// template argument requires passing a function pointer to
|
||||
// TFE_DeleteTensorHandle when constructing the unique_ptr.
|
||||
class TensorHandleDeleter {
|
||||
public:
|
||||
void operator()(TFE_TensorHandle* to_delete) {
|
||||
TFE_DeleteTensorHandle(to_delete);
|
||||
}
|
||||
};
|
||||
|
||||
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
|
||||
|
||||
// A helper for performing common operations on variables. A much more
|
||||
// restricted stand-in for tf.Variable in Python.
|
||||
class Variable {
|
||||
public:
|
||||
// Construct a Variable from a resource-dtype TFE_TensorHandle and an
|
||||
// indication of the dtype of the variable's value.
|
||||
//
|
||||
// Note that creating this resource-dtype handle can fail, so `Create` is a
|
||||
// separate static method which returns a status.
|
||||
Variable(TFE_TensorHandle* handle, TF_DataType type)
|
||||
: handle_(handle), type_(type) {}
|
||||
|
||||
// Helper for constructing a resource handle and wrapping it in a `Variable`
|
||||
// object.
|
||||
static Variable* Create(TFE_Context* context, TF_DataType type,
|
||||
const int64_t* dims, const int num_dims,
|
||||
const char* device, TF_Status* status);
|
||||
// Dereferences the backing buffer for the variable. Note that since this can
|
||||
// fail (it runs operations), it must be called explicitly and the resulting
|
||||
// `status` checked.
|
||||
void Destroy(TFE_Context* context, TF_Status* status);
|
||||
|
||||
// Reads from the variable.
|
||||
TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
|
||||
// Assigns a new value to the variable.
|
||||
void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
|
||||
// Adds `value` to the existing value of the variable.
|
||||
void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
|
||||
TF_Status* status);
|
||||
|
||||
private:
|
||||
// Helper for running any single-argument assignment ops (Assign, AssignAdd,
|
||||
// AssignSub, ...).
|
||||
void GeneralAssignment(const char* op_name, TFE_Context* context,
|
||||
TFE_TensorHandle* value, TF_Status* status);
|
||||
|
||||
// The a handle for the resource-dtype tensor pointing to the variable's
|
||||
// buffer.
|
||||
TFE_TensorHandle* handle_;
|
||||
// The dtype of the variable's buffer (input dtype for assignments, output
|
||||
// dtype of read operations).
|
||||
TF_DataType type_;
|
||||
};
|
||||
|
||||
// Creates a TFE_TensorHandle with value `v`.
|
||||
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status);
|
||||
|
||||
// Creates a rank-one TFE_TensorHandle with value `v`.
|
||||
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
|
||||
TF_Status* status);
|
||||
|
||||
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
|
||||
template <std::size_t num_replicas>
|
||||
void ExtractPerDeviceValues(
|
||||
TFE_Context* context, TFE_TensorHandle* input,
|
||||
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status);
|
||||
|
||||
// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
|
||||
template <std::size_t num_replicas>
|
||||
TensorHandlePtr CreatePerDeviceValues(
|
||||
TFE_Context* context,
|
||||
const std::array<TFE_TensorHandle*, num_replicas>& components,
|
||||
const char* device, TF_Status* status);
|
||||
|
||||
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
|
||||
TFE_TensorHandle* second, TF_Status* status);
|
||||
|
||||
// Assert that `handle` is equal to `expected_value`.
|
||||
template <typename value_type>
|
||||
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value);
|
||||
|
||||
template <std::size_t num_devices>
|
||||
void RegisterParallelDevice(
|
||||
TFE_Context* context, const char* device_name,
|
||||
const std::array<const char*, num_devices>& underlying_devices,
|
||||
TF_Status* status);
|
||||
|
||||
// Create and modify a variable placed on a parallel device which composes
|
||||
// `first_device` and `second_device`.
|
||||
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
const char* second_device);
|
||||
|
||||
// Implementations of templated functions ******************************
|
||||
|
||||
template <std::size_t num_replicas>
|
||||
TensorHandlePtr CreatePerDeviceValues(
|
||||
TFE_Context* context,
|
||||
const std::array<TFE_TensorHandle*, num_replicas>& components,
|
||||
const char* device, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrInt(op.get(), "N", num_replicas);
|
||||
for (int i = 0; i < num_replicas; ++i) {
|
||||
TFE_OpAddInput(op.get(), components[i], status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
TFE_TensorHandle* result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return TensorHandlePtr(result_handle);
|
||||
}
|
||||
|
||||
template <typename value_type>
|
||||
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
|
||||
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
EXPECT_EQ(expected_value,
|
||||
*static_cast<value_type*>(TF_TensorData(value_zero.get())));
|
||||
}
|
||||
|
||||
template <std::size_t num_devices>
|
||||
void RegisterParallelDevice(
|
||||
TFE_Context* context, const char* device_name,
|
||||
const std::array<const char*, num_devices>& underlying_devices,
|
||||
TF_Status* status) {
|
||||
TFE_CustomDevice device;
|
||||
void* device_info;
|
||||
tensorflow::parallel_device::AllocateParallelDevice(
|
||||
device_name, underlying_devices.data(), underlying_devices.size(),
|
||||
&device, &device_info);
|
||||
TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
|
||||
}
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
|
30
tensorflow/c/experimental/filesystem/plugins/gcs/BUILD
Normal file
30
tensorflow/c/experimental/filesystem/plugins/gcs/BUILD
Normal file
@ -0,0 +1,30 @@
|
||||
# Experimental gcs filesystem plugin.
|
||||
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
# Filesystem implementation for GCS environments
|
||||
tf_cc_shared_object(
|
||||
name = "gcs_filesystem",
|
||||
framework_so = [],
|
||||
linkstatic = False,
|
||||
per_os_targets = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":gcs_filesystem_impl"],
|
||||
)
|
||||
|
||||
# The real implementation of the filesystem.
|
||||
cc_library(
|
||||
name = "gcs_filesystem_impl",
|
||||
srcs = ["gcs_filesystem.cc"],
|
||||
copts = select({
|
||||
"//conditions:default": [],
|
||||
"//tensorflow:windows": get_win_copts(),
|
||||
}),
|
||||
deps = [
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||
],
|
||||
)
|
@ -0,0 +1,72 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
// Implementation of a filesystem for GCS environments.
|
||||
// This filesystem will support `gs://` URI schemes.
|
||||
|
||||
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
|
||||
static void plugin_memory_free(void* ptr) { free(ptr); }
|
||||
|
||||
// SECTION 1. Implementation for `TF_RandomAccessFile`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_random_access_file {
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
|
||||
} // namespace tf_random_access_file
|
||||
|
||||
// SECTION 2. Implementation for `TF_WritableFile`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_writable_file {
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
|
||||
} // namespace tf_writable_file
|
||||
|
||||
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_read_only_memory_region {
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
|
||||
} // namespace tf_read_only_memory_region
|
||||
|
||||
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_gcs_filesystem {
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
|
||||
} // namespace tf_gcs_filesystem
|
||||
|
||||
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
|
||||
const char* uri) {
|
||||
TF_SetFilesystemVersionMetadata(ops);
|
||||
ops->scheme = strdup(uri);
|
||||
}
|
||||
|
||||
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
|
||||
info->plugin_memory_allocate = plugin_memory_allocate;
|
||||
info->plugin_memory_free = plugin_memory_free;
|
||||
info->num_schemes = 1;
|
||||
info->ops = static_cast<TF_FilesystemPluginOps*>(
|
||||
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
|
||||
ProvideFilesystemSupportFor(&info->ops[0], "gs");
|
||||
}
|
@ -108,7 +108,7 @@ class CServerFactory : public ServerFactory {
|
||||
delete_function_(delete_function),
|
||||
rendezvous_builder_(rendezvous_builder) {}
|
||||
|
||||
Status NewServer(const ServerDef& server_def,
|
||||
Status NewServer(const ServerDef& server_def, const Options& options,
|
||||
std::unique_ptr<ServerInterface>* out_server) override {
|
||||
TF_RETURN_IF_ERROR(CGrpcServer::Create(
|
||||
server_def, init_function_, start_function_, stop_function_,
|
||||
|
@ -31,9 +31,6 @@ cc_library(
|
||||
"//tensorflow/c/experimental/saved_model/public:concrete_function.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
# TODO(bmzhao): Remove this as we refactor C API to granular targets,
|
||||
# so that we can depend on c/eager/c_api_unified_experimental.h.
|
||||
features = ["-layering_check"],
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
@ -41,6 +38,8 @@ cc_library(
|
||||
":concrete_function_type",
|
||||
":function_metadata",
|
||||
":function_metadata_type",
|
||||
":tensorhandle_list",
|
||||
":tensorhandle_list_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_internal",
|
||||
@ -156,10 +155,43 @@ cc_library(
|
||||
"saved_model_api_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorhandle_list",
|
||||
srcs = [
|
||||
"tensorhandle_list.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//tensorflow/c/experimental/saved_model/public:tensorhandle_list.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":tensorhandle_list_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorhandle_list_type",
|
||||
hdrs = [
|
||||
"tensorhandle_list_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "saved_model_api_test",
|
||||
size = "small",
|
||||
|
@ -15,12 +15,12 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
@ -29,10 +29,9 @@ TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
|
||||
&tensorflow::unwrap(func)->GetFunctionMetadata()));
|
||||
}
|
||||
|
||||
TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) {
|
||||
// TODO(bmzhao): Refactor TF_OutputList struct definition into a separate
|
||||
// internal header, and implement this function.
|
||||
return nullptr;
|
||||
const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures(
|
||||
TF_ConcreteFunction* func) {
|
||||
return tensorflow::wrap(&tensorflow::unwrap(func)->GetCaptures());
|
||||
}
|
||||
|
||||
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) {
|
||||
|
@ -41,7 +41,7 @@ TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TF_SavedModel{std::move(result)};
|
||||
return tensorflow::wrap(result.release());
|
||||
}
|
||||
|
||||
TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
|
||||
@ -60,17 +60,19 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TF_SavedModel{std::move(result)};
|
||||
return tensorflow::wrap(result.release());
|
||||
}
|
||||
|
||||
void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; }
|
||||
void TF_DeleteSavedModel(TF_SavedModel* model) {
|
||||
delete tensorflow::unwrap(model);
|
||||
}
|
||||
|
||||
TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
|
||||
const char* function_path,
|
||||
TF_Status* status) {
|
||||
tensorflow::ConcreteFunction* result = nullptr;
|
||||
tensorflow::Status get_function_status =
|
||||
model->saved_model->GetFunction(function_path, &result);
|
||||
tensorflow::unwrap(model)->GetFunction(function_path, &result);
|
||||
status->status.Update(get_function_status);
|
||||
if (!get_function_status.ok()) {
|
||||
return nullptr;
|
||||
@ -82,7 +84,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
|
||||
TF_SavedModel* model, const char* signature_def_key, TF_Status* status) {
|
||||
tensorflow::ConcreteFunction* result = nullptr;
|
||||
tensorflow::Status get_function_status =
|
||||
model->saved_model->GetSignatureDefFunction(signature_def_key, &result);
|
||||
tensorflow::unwrap(model)->GetSignatureDefFunction(signature_def_key,
|
||||
&result);
|
||||
status->status.Update(get_function_status);
|
||||
if (!get_function_status.ok()) {
|
||||
return nullptr;
|
||||
@ -91,7 +94,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
|
||||
}
|
||||
|
||||
TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
|
||||
return new TF_ConcreteFunctionList{model->saved_model->ListFunctions()};
|
||||
return new TF_ConcreteFunctionList{
|
||||
tensorflow::unwrap(model)->ListFunctions()};
|
||||
}
|
||||
|
||||
} // end extern "C"
|
||||
|
@ -18,13 +18,18 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
|
||||
// Internal structures used by the SavedModel C API. These are likely to change
|
||||
// and should not be depended on.
|
||||
|
||||
struct TF_SavedModel {
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> saved_model;
|
||||
};
|
||||
typedef struct TF_SavedModel TF_SavedModel;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::SavedModelAPI, TF_SavedModel)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
|
||||
|
@ -0,0 +1,36 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
size_t TF_TensorHandleListSize(const TF_TensorHandleList* list) {
|
||||
return tensorflow::unwrap(list)->size();
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TF_TensorHandleListGet(const TF_TensorHandleList* list,
|
||||
int i) {
|
||||
return tensorflow::wrap((*tensorflow::unwrap(list))[i]);
|
||||
}
|
||||
|
||||
|
||||
} // end extern "C"
|
@ -0,0 +1,37 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
|
||||
// Internal structures used by the SavedModel C API. These are likely to
|
||||
// change and should not be depended on.
|
||||
|
||||
typedef struct TF_TensorHandleList TF_TensorHandleList;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(
|
||||
std::vector<tensorflow::AbstractTensorHandleInterface*>,
|
||||
TF_TensorHandleList)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
|
@ -24,6 +24,7 @@ exports_files(
|
||||
"concrete_function_list.h",
|
||||
"function_metadata.h",
|
||||
"saved_model_api.h",
|
||||
"tensorhandle_list.h",
|
||||
],
|
||||
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
|
||||
)
|
||||
@ -39,6 +40,7 @@ cc_library(
|
||||
":concrete_function_list",
|
||||
":function_metadata",
|
||||
":saved_model_api",
|
||||
":tensorhandle_list",
|
||||
],
|
||||
)
|
||||
|
||||
@ -61,3 +63,8 @@ alias(
|
||||
name = "saved_model_api",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "tensorhandle_list",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:tensorhandle_list",
|
||||
)
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
|
||||
|
@ -17,9 +17,9 @@ limitations under the License.
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
@ -36,7 +36,7 @@ TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
|
||||
TF_ConcreteFunction* func);
|
||||
|
||||
// Returns a list of TensorHandles implicitly captured by this function.
|
||||
TF_CAPI_EXPORT extern TF_OutputList* TF_ConcreteFunctionGetCaptures(
|
||||
TF_CAPI_EXPORT extern const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures(
|
||||
TF_ConcreteFunction* func);
|
||||
|
||||
// Returns a TFE_Op suitable for executing this function.
|
||||
|
@ -21,19 +21,27 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type that is acts like a list of TF_ConcreteFunction pointers.
|
||||
typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList;
|
||||
|
||||
// Returns the size of `list`.
|
||||
TF_CAPI_EXPORT size_t
|
||||
TF_ConcreteFunctionListSize(TF_ConcreteFunctionList* list);
|
||||
TF_CAPI_EXPORT extern size_t TF_ConcreteFunctionListSize(
|
||||
TF_ConcreteFunctionList* list);
|
||||
|
||||
// Returns the `i`th TF_ConcreteFunction in the list.
|
||||
TF_CAPI_EXPORT TF_ConcreteFunction* TF_ConcreteFunctionListGet(
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_ConcreteFunctionListGet(
|
||||
TF_ConcreteFunctionList* list, int i);
|
||||
|
||||
// Deletes `list`.
|
||||
TF_CAPI_EXPORT void TF_DeleteConcreteFunctionList(
|
||||
TF_CAPI_EXPORT extern void TF_DeleteConcreteFunctionList(
|
||||
TF_ConcreteFunctionList* list);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
||||
|
@ -0,0 +1,43 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type that is acts like a list of TF_ConcreteFunction pointers.
|
||||
typedef struct TF_TensorHandleList TF_TensorHandleList;
|
||||
|
||||
// Returns the size of `list`.
|
||||
TF_CAPI_EXPORT extern size_t TF_TensorHandleListSize(
|
||||
const TF_TensorHandleList* list);
|
||||
|
||||
// Returns the `i`th TFE_TensorHandle in the list.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TF_TensorHandleListGet(
|
||||
const TF_TensorHandleList* list, int i);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
|
@ -178,7 +178,7 @@ cc_library_with_android_deps(
|
||||
name = "ops",
|
||||
srcs = ["framework/ops.cc"],
|
||||
hdrs = ["framework/ops.h"],
|
||||
android_deps = ["//tensorflow/core:android_tensorflow_lib"],
|
||||
android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
@ -197,7 +197,7 @@ cc_library_with_android_deps(
|
||||
"framework/scope_internal.h",
|
||||
],
|
||||
hdrs = ["framework/scope.h"],
|
||||
android_deps = ["//tensorflow/core:android_tensorflow_lib"],
|
||||
android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
|
||||
common_deps = [
|
||||
":ops",
|
||||
],
|
||||
@ -237,7 +237,7 @@ cc_library_with_android_deps(
|
||||
name = "client_session",
|
||||
srcs = ["client/client_session.cc"],
|
||||
hdrs = ["client/client_session.h"],
|
||||
android_deps = ["//tensorflow/core:android_tensorflow_lib"],
|
||||
android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
|
||||
common_deps = [
|
||||
":ops",
|
||||
":scope",
|
||||
@ -275,7 +275,7 @@ cc_library_with_android_deps(
|
||||
srcs = ["ops/const_op.cc"],
|
||||
hdrs = ["ops/const_op.h"],
|
||||
android_deps = [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
"//tensorflow/core:portable_tensorflow_lib",
|
||||
],
|
||||
common_deps = [
|
||||
":ops",
|
||||
@ -304,7 +304,7 @@ cc_library_with_android_deps(
|
||||
srcs = ["ops/while_loop.cc"],
|
||||
hdrs = ["ops/while_loop.h"],
|
||||
android_deps = [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
"//tensorflow/core:portable_tensorflow_lib",
|
||||
],
|
||||
common_deps = [
|
||||
":cc_ops",
|
||||
|
@ -62,3 +62,17 @@ cc_library(
|
||||
"//tensorflow/c:tf_tensor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorhandle",
|
||||
hdrs = [
|
||||
"tensorhandle.h",
|
||||
],
|
||||
deps = [
|
||||
":runtime",
|
||||
":status",
|
||||
":tensor",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
],
|
||||
)
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// Runtime represents an opaque instance of a Tensorflow runtime, with its own
|
||||
@ -40,6 +41,7 @@ class Runtime {
|
||||
private:
|
||||
friend class RuntimeBuilder;
|
||||
friend class SavedModelAPI;
|
||||
friend class TensorHandle;
|
||||
|
||||
// Wraps a TFE_Context. Takes ownership of ctx.
|
||||
explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {}
|
||||
@ -63,6 +65,7 @@ class Runtime {
|
||||
};
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
|
||||
@ -79,6 +80,7 @@ inline std::unique_ptr<Runtime> RuntimeBuilder::Build(Status* status) {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// Status is a wrapper around an error code and an optional error message.
|
||||
@ -57,6 +58,7 @@ class Status {
|
||||
friend class RuntimeBuilder;
|
||||
friend class Runtime;
|
||||
friend class SavedModelAPI;
|
||||
friend class TensorHandle;
|
||||
|
||||
// Wraps a TF_Status*, and takes ownership of it.
|
||||
explicit Status(TF_Status* status) : status_(status) {}
|
||||
@ -88,6 +90,7 @@ inline void Status::SetStatus(TF_Code code, const std::string& msg) {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// Tensor represents an n-dimensional array of values.
|
||||
@ -168,6 +169,7 @@ inline Tensor Tensor::FromBuffer(TF_DataType dtype,
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
|
||||
|
98
tensorflow/cc/experimental/base/public/tensorhandle.h
Normal file
98
tensorflow/cc/experimental/base/public/tensorhandle.h
Normal file
@ -0,0 +1,98 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime.h"
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// An opaque representation of a tensor computed/managed by the Tensorflow
|
||||
// runtime (tensorflow:cc::Runtime). Unlike a tensor, a Tensorhandle may refer
|
||||
// to tensors placed in memory of different devices or remote address spaces.
|
||||
// Note that tensorflow::cc::Runtime MUST outlive all TensorHandles created
|
||||
// from it.
|
||||
class TensorHandle {
|
||||
public:
|
||||
// Unwraps a Tensor from the given TensorHandle. If an error occurred,
|
||||
// status->ok() will be false, and the returned Tensor must not be used.
|
||||
Tensor Resolve(Status* status);
|
||||
|
||||
// Constructs a TensorHandle from a Tensor. If an error occurred,
|
||||
// status->ok() will be false, and the returned TensorHandle must not be used.
|
||||
static TensorHandle FromTensor(const Tensor& tensor, const Runtime& runtime,
|
||||
Status* status);
|
||||
|
||||
// TensorHandle is movable, and not copyable
|
||||
TensorHandle(TensorHandle&&) = default;
|
||||
TensorHandle& operator=(TensorHandle&&) = default;
|
||||
|
||||
private:
|
||||
// Wraps a TFE_TensorHandle. Takes ownership of handle.
|
||||
explicit TensorHandle(TFE_TensorHandle* handle) : handle_(handle) {}
|
||||
|
||||
// TensorHandle is not copyable
|
||||
TensorHandle(const TensorHandle&) = delete;
|
||||
TensorHandle& operator=(const TensorHandle&) = delete;
|
||||
|
||||
// Returns the underlying TFE_TensorHandle that this object wraps.
|
||||
// This object retains ownership of the pointer.
|
||||
TFE_TensorHandle* GetTFETensorHandle() const { return handle_.get(); }
|
||||
|
||||
// Deletes the currently wrapped TFE_TensorHandle, and swaps it with handle,
|
||||
// and takes ownership of handle.
|
||||
void Reset(TFE_TensorHandle* handle) { handle_.reset(handle); }
|
||||
|
||||
struct TFETensorHandleDeleter {
|
||||
void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); }
|
||||
};
|
||||
std::unique_ptr<TFE_TensorHandle, TFETensorHandleDeleter> handle_;
|
||||
};
|
||||
|
||||
inline Tensor TensorHandle::Resolve(Status* status) {
|
||||
TF_Tensor* tensor =
|
||||
TFE_TensorHandleResolve(handle_.get(), status->GetTFStatus());
|
||||
if (!status->ok()) {
|
||||
return Tensor(nullptr);
|
||||
}
|
||||
return Tensor(tensor);
|
||||
}
|
||||
|
||||
inline TensorHandle TensorHandle::FromTensor(const Tensor& tensor,
|
||||
const Runtime& runtime,
|
||||
Status* status) {
|
||||
TFE_TensorHandle* tensor_handle = TFE_NewTensorHandleFromTensor(
|
||||
runtime.GetTFEContext(), tensor.GetTFTensor(), status->GetTFStatus());
|
||||
if (!status->ok()) {
|
||||
return TensorHandle(nullptr);
|
||||
}
|
||||
return TensorHandle(tensor_handle);
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
|
@ -5,12 +5,22 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_types_test_util",
|
||||
testonly = True,
|
||||
hdrs = ["tensor_types_test_util.h"],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_datatype",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tensor_test",
|
||||
srcs = [
|
||||
"tensor_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":tensor_types_test_util",
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/cc/experimental/base/public:status",
|
||||
"//tensorflow/cc/experimental/base/public:tensor",
|
||||
@ -19,3 +29,22 @@ tf_cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tensorhandle_test",
|
||||
srcs = [
|
||||
"tensorhandle_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":tensor_types_test_util",
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/cc/experimental/base/public:runtime",
|
||||
"//tensorflow/cc/experimental/base/public:runtime_builder",
|
||||
"//tensorflow/cc/experimental/base/public:status",
|
||||
"//tensorflow/cc/experimental/base/public:tensor",
|
||||
"//tensorflow/cc/experimental/base/public:tensorhandle",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
@ -16,69 +16,22 @@ limitations under the License.
|
||||
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Each of the following struct types have two members: a kDType that
|
||||
// corresponds to a TF_Datatype enum value, and a typedef "type"
|
||||
// of its corresponding C++ type. These types allow us to write Dtype-agnostic
|
||||
// tests via GoogleTest's TypedTests:
|
||||
// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
|
||||
struct FloatType {
|
||||
using type = float;
|
||||
static constexpr TF_DataType kDType = TF_FLOAT;
|
||||
};
|
||||
using tensorflow::experimental::cc::Status;
|
||||
using tensorflow::experimental::cc::Tensor;
|
||||
|
||||
struct DoubleType {
|
||||
using type = double;
|
||||
static constexpr TF_DataType kDType = TF_DOUBLE;
|
||||
};
|
||||
|
||||
struct Int32Type {
|
||||
using type = int32_t;
|
||||
static constexpr TF_DataType kDType = TF_INT32;
|
||||
};
|
||||
|
||||
struct UINT8Type {
|
||||
using type = uint8_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT8;
|
||||
};
|
||||
|
||||
struct INT8Type {
|
||||
using type = int8_t;
|
||||
static constexpr TF_DataType kDType = TF_INT8;
|
||||
};
|
||||
|
||||
struct INT64Type {
|
||||
using type = int64_t;
|
||||
static constexpr TF_DataType kDType = TF_INT64;
|
||||
};
|
||||
|
||||
struct UINT16Type {
|
||||
using type = uint16_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT16;
|
||||
};
|
||||
|
||||
struct UINT32Type {
|
||||
using type = uint32_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT32;
|
||||
};
|
||||
|
||||
struct UINT64Type {
|
||||
using type = uint64_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT64;
|
||||
};
|
||||
|
||||
using SimpleTypes =
|
||||
::testing::Types<FloatType, DoubleType, Int32Type, UINT8Type, INT8Type,
|
||||
INT64Type, UINT16Type, UINT32Type, UINT64Type>;
|
||||
using SimpleTypes = ::testing::Types<
|
||||
tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
|
||||
tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
|
||||
tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
|
||||
|
||||
template <typename T>
|
||||
class ConstructScalarTensorTest : public ::testing::Test {};
|
||||
@ -88,11 +41,10 @@ TYPED_TEST_SUITE(ConstructScalarTensorTest, SimpleTypes);
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
cc::Status status;
|
||||
Status status;
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
typename TypeParam::type value = 42;
|
||||
cc::Tensor tensor =
|
||||
cc::Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
|
||||
Tensor tensor = Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
|
||||
/*data=*/&value,
|
||||
/*len=*/sizeof(value),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
@ -113,7 +65,7 @@ TYPED_TEST_SUITE(Construct1DTensorTest, SimpleTypes);
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
cc::Status status;
|
||||
Status status;
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
@ -121,7 +73,7 @@ TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
std::vector<int64_t> shape;
|
||||
shape.push_back(value.size());
|
||||
|
||||
cc::Tensor tensor = cc::Tensor::FromBuffer(
|
||||
Tensor tensor = Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
@ -130,7 +82,7 @@ TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 1);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
@ -152,14 +104,14 @@ TYPED_TEST_SUITE(Construct2DTensorTest, SimpleTypes);
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
cc::Status status;
|
||||
Status status;
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
// Shape is Rank 2 vector with shape 2 x 3.
|
||||
std::vector<int64_t> shape({2, 3});
|
||||
|
||||
cc::Tensor tensor = cc::Tensor::FromBuffer(
|
||||
Tensor tensor = Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
@ -169,7 +121,7 @@ TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 2);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
@ -185,19 +137,19 @@ TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
|
||||
TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
|
||||
bool done = false;
|
||||
cc::Status status;
|
||||
Status status;
|
||||
std::vector<int32_t> data_vector({12, 14, 20, 18, 39, 42, 100});
|
||||
{
|
||||
// data_vector is a rank 1 tensor.
|
||||
std::vector<int64_t> shape;
|
||||
shape.push_back(data_vector.size());
|
||||
|
||||
cc::Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
|
||||
Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
|
||||
done = true;
|
||||
};
|
||||
|
||||
cc::Tensor tensor =
|
||||
cc::Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
|
||||
Tensor tensor =
|
||||
Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
|
||||
/*data=*/data_vector.data(),
|
||||
/*len=*/data_vector.size() * sizeof(int32_t),
|
||||
/*deleter=*/callback, &status);
|
||||
@ -209,4 +161,3 @@ TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -0,0 +1,76 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Each of the following struct types have two members: a kDType that
|
||||
// corresponds to a TF_Datatype enum value, and a typedef "type"
|
||||
// of its corresponding C++ type. These types allow us to write Dtype-agnostic
|
||||
// tests via GoogleTest's TypedTests:
|
||||
// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
|
||||
struct FloatType {
|
||||
using type = float;
|
||||
static constexpr TF_DataType kDType = TF_FLOAT;
|
||||
};
|
||||
|
||||
struct DoubleType {
|
||||
using type = double;
|
||||
static constexpr TF_DataType kDType = TF_DOUBLE;
|
||||
};
|
||||
|
||||
struct Int32Type {
|
||||
using type = int32_t;
|
||||
static constexpr TF_DataType kDType = TF_INT32;
|
||||
};
|
||||
|
||||
struct UINT8Type {
|
||||
using type = uint8_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT8;
|
||||
};
|
||||
|
||||
struct INT8Type {
|
||||
using type = int8_t;
|
||||
static constexpr TF_DataType kDType = TF_INT8;
|
||||
};
|
||||
|
||||
struct INT64Type {
|
||||
using type = int64_t;
|
||||
static constexpr TF_DataType kDType = TF_INT64;
|
||||
};
|
||||
|
||||
struct UINT16Type {
|
||||
using type = uint16_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT16;
|
||||
};
|
||||
|
||||
struct UINT32Type {
|
||||
using type = uint32_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT32;
|
||||
};
|
||||
|
||||
struct UINT64Type {
|
||||
using type = uint64_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT64;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
|
184
tensorflow/cc/experimental/base/tests/tensorhandle_test.cc
Normal file
184
tensorflow/cc/experimental/base/tests/tensorhandle_test.cc
Normal file
@ -0,0 +1,184 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/experimental/base/public/tensorhandle.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
|
||||
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
||||
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
using tensorflow::experimental::cc::Runtime;
|
||||
using tensorflow::experimental::cc::RuntimeBuilder;
|
||||
using tensorflow::experimental::cc::Status;
|
||||
using tensorflow::experimental::cc::Tensor;
|
||||
using tensorflow::experimental::cc::TensorHandle;
|
||||
|
||||
using SimpleTypes = ::testing::Types<
|
||||
tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
|
||||
tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
|
||||
tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
|
||||
|
||||
template <typename T>
|
||||
class ConstructScalarTensorHandleTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes);
|
||||
|
||||
// This test constructs a scalar tensor for each of the types in "SimpleTypes",
|
||||
// then wraps it in a TensorHandle. We then unwrap it back into a Tensor, and
|
||||
// verify the expected dims, dtype, value, num bytes, and num elements.
|
||||
TYPED_TEST(ConstructScalarTensorHandleTest,
|
||||
ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
RuntimeBuilder runtime_builder;
|
||||
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
typename TypeParam::type value = 42;
|
||||
Tensor original_tensor =
|
||||
Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
|
||||
/*data=*/&value,
|
||||
/*len=*/sizeof(value),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TensorHandle handle =
|
||||
TensorHandle::FromTensor(original_tensor, *runtime, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
Tensor tensor = handle.Resolve(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 0);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
|
||||
EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), 1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Construct1DTensorHandleTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes);
|
||||
|
||||
// This test constructs a 1D tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct1DTensorHandleTest,
|
||||
ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
RuntimeBuilder runtime_builder;
|
||||
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
// Shape is Rank 1 vector.
|
||||
std::vector<int64_t> shape;
|
||||
shape.push_back(value.size());
|
||||
|
||||
Tensor original_tensor = Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TensorHandle handle =
|
||||
TensorHandle::FromTensor(original_tensor, *runtime, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
Tensor tensor = handle.Resolve(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 1);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
EXPECT_EQ(tensor_view[2], 0);
|
||||
EXPECT_EQ(tensor_view[3], 1);
|
||||
EXPECT_EQ(tensor_view[4], 4);
|
||||
EXPECT_EQ(tensor_view[5], 29);
|
||||
|
||||
EXPECT_EQ(tensor.num_bytes(),
|
||||
value.size() * sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), value.size());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Construct2DTensorHandleTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes);
|
||||
|
||||
// This test constructs a 2D tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct2DTensorHandleTest,
|
||||
ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
RuntimeBuilder runtime_builder;
|
||||
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
// Shape is Rank 2 vector with shape 2 x 3.
|
||||
std::vector<int64_t> shape({2, 3});
|
||||
|
||||
Tensor original_tensor = Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TensorHandle handle =
|
||||
TensorHandle::FromTensor(original_tensor, *runtime, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
Tensor tensor = handle.Resolve(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 2);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
EXPECT_EQ(tensor_view[2], 0);
|
||||
EXPECT_EQ(tensor_view[3], 1);
|
||||
EXPECT_EQ(tensor_view[4], 4);
|
||||
EXPECT_EQ(tensor_view[5], 29);
|
||||
|
||||
EXPECT_EQ(tensor.num_bytes(),
|
||||
value.size() * sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), value.size());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -84,7 +84,7 @@ cc_library(
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
]) + if_android([
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
"//tensorflow/core:portable_tensorflow_lib",
|
||||
]),
|
||||
)
|
||||
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/saved_model/experimental/public/function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// ConcreteFunction is an executable "function" loaded from a SavedModelAPI.
|
||||
@ -54,6 +55,7 @@ inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// ConcreteFunctionList helps convert an opaque pointer to an array of
|
||||
@ -56,6 +57,7 @@ inline std::vector<ConcreteFunction*> ConcreteFunctionList::ToVector() {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// FunctionMetadata stores additional function information, including
|
||||
@ -40,6 +41,7 @@ class FunctionMetadata final {
|
||||
};
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// SavedModelAPI offers a way to load Tensorflow Saved Models
|
||||
@ -155,6 +156,7 @@ inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
|
||||
|
@ -26,10 +26,14 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
using tensorflow::experimental::cc::Runtime;
|
||||
using tensorflow::experimental::cc::RuntimeBuilder;
|
||||
using tensorflow::experimental::cc::SavedModelAPI;
|
||||
using tensorflow::experimental::cc::Status;
|
||||
|
||||
constexpr char kTestData[] = "cc/saved_model/testdata";
|
||||
|
||||
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
|
||||
@ -43,21 +47,21 @@ std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
|
||||
class CPPSavedModelAPITest : public ::testing::TestWithParam<bool> {};
|
||||
|
||||
TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
|
||||
cc::Status status;
|
||||
cc::RuntimeBuilder builder;
|
||||
Status status;
|
||||
RuntimeBuilder builder;
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
builder.SetUseTFRT(use_tfrt);
|
||||
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
|
||||
std::unique_ptr<Runtime> runtime = builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||
std::unordered_set<std::string> tags = {"serve"};
|
||||
std::unique_ptr<cc::SavedModelAPI> model =
|
||||
cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
|
||||
std::unique_ptr<SavedModelAPI> model =
|
||||
SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
|
||||
|
||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||
@ -67,20 +71,20 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
|
||||
}
|
||||
|
||||
TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
|
||||
cc::Status status;
|
||||
cc::RuntimeBuilder builder;
|
||||
Status status;
|
||||
RuntimeBuilder builder;
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
builder.SetUseTFRT(use_tfrt);
|
||||
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
|
||||
std::unique_ptr<Runtime> runtime = builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||
std::unique_ptr<cc::SavedModelAPI> model =
|
||||
cc::SavedModelAPI::Load(model_dir, *runtime, &status);
|
||||
std::unique_ptr<SavedModelAPI> model =
|
||||
SavedModelAPI::Load(model_dir, *runtime, &status);
|
||||
|
||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||
@ -94,4 +98,3 @@ INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -67,13 +67,13 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:target",
|
||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:arm_target", # fixdeps: keep
|
||||
"@llvm-project//llvm:powerpc_target", # fixdeps: keep
|
||||
"@llvm-project//llvm:target_base",
|
||||
"@llvm-project//llvm:x86_target", # fixdeps: keep
|
||||
"//tensorflow/core:regexp_internal",
|
||||
] + if_llvm_aarch64_available([
|
||||
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
|
||||
"@llvm-project//llvm:aarch64_target", # fixdeps: keep
|
||||
]),
|
||||
)
|
||||
|
||||
@ -95,7 +95,7 @@ tf_cc_test(
|
||||
"//tensorflow/core/platform:resource_loader",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support", # fixdeps: keep
|
||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:x86_target", # fixdeps: keep
|
||||
],
|
||||
)
|
||||
|
||||
@ -109,12 +109,12 @@ cc_library(
|
||||
name = "llvm_targets",
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
deps = [
|
||||
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:target",
|
||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:arm_target", # fixdeps: keep
|
||||
"@llvm-project//llvm:powerpc_target", # fixdeps: keep
|
||||
"@llvm-project//llvm:target_base",
|
||||
"@llvm-project//llvm:x86_target", # fixdeps: keep
|
||||
] + if_llvm_aarch64_available([
|
||||
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
|
||||
"@llvm-project//llvm:aarch64_target", # fixdeps: keep
|
||||
]),
|
||||
)
|
||||
|
||||
@ -286,9 +286,9 @@ cc_library(
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:core",
|
||||
"@llvm-project//llvm:ir",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:target",
|
||||
"@llvm-project//llvm:target_base",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -20,7 +20,7 @@ load(
|
||||
"tf_cc_test",
|
||||
"tf_copts",
|
||||
)
|
||||
load("//tensorflow:tensorflow.bzl", "tfcompile_extra_flags")
|
||||
load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu")
|
||||
|
||||
def tf_library(
|
||||
name,
|
||||
@ -42,7 +42,8 @@ def tf_library(
|
||||
mlir_components = "None",
|
||||
deps = None,
|
||||
tags = []):
|
||||
"""Runs tfcompile to compile a TensorFlow graph into executable code.
|
||||
"""Runs tfcompile to compile a TensorFlow graph into executable code with fast
|
||||
math enabled on cpu.
|
||||
|
||||
Given an invocation of tf_library(name="foo", ...), generates the following
|
||||
build targets:
|
||||
@ -187,7 +188,9 @@ def tf_library(
|
||||
# `find` on such an object.
|
||||
need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1
|
||||
|
||||
flags = tfcompile_extra_flags() + flags
|
||||
target_cpu = tfcompile_target_cpu()
|
||||
extra_flags = "--target_cpu=" + target_cpu + " " if target_cpu else " "
|
||||
flags = extra_flags + flags
|
||||
|
||||
if enable_xla_hlo_profiling:
|
||||
profiling_flag = "--xla_hlo_profile"
|
||||
@ -207,6 +210,15 @@ def tf_library(
|
||||
srcs.append(debug_info)
|
||||
debug_info_flag = " --debug_info=$(location " + debug_info + ")"
|
||||
|
||||
default_fast_math_xla_flags = ("XLA_FLAGS='" +
|
||||
"--xla_cpu_enable_fast_math=true " +
|
||||
"--xla_cpu_fast_math_honor_nans=false " +
|
||||
"--xla_cpu_fast_math_honor_infs=false " +
|
||||
"--xla_cpu_fast_math_honor_functions=false " +
|
||||
"--xla_cpu_fast_math_honor_division=false " +
|
||||
"--xla_cpu_enable_fast_min_max=true " +
|
||||
"$${XLA_FLAGS:-}' ")
|
||||
|
||||
native.genrule(
|
||||
name = ("gen_" + name),
|
||||
srcs = srcs,
|
||||
@ -216,6 +228,7 @@ def tf_library(
|
||||
function_object_file,
|
||||
],
|
||||
cmd = (
|
||||
default_fast_math_xla_flags +
|
||||
"CUDA_VISIBLE_DEVICES='' " +
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
@ -256,6 +269,7 @@ def tf_library(
|
||||
session_module_pb,
|
||||
],
|
||||
cmd = (
|
||||
default_fast_math_xla_flags +
|
||||
"CUDA_VISIBLE_DEVICES='' " +
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
|
@ -67,6 +67,8 @@ int main(int argc, char** argv) {
|
||||
flags.entry_point = "entry";
|
||||
flags.debug_info_path_begin_marker = "";
|
||||
|
||||
// Note that tfcompile.bzl's tf_library macro sets fast math flags as that is
|
||||
// generally the preferred case.
|
||||
std::vector<tensorflow::Flag> flag_list;
|
||||
AppendMainFlags(&flag_list, &flags);
|
||||
xla::AppendDebugOptionsFlags(&flag_list);
|
||||
|
@ -251,7 +251,7 @@ cc_library(
|
||||
visibility = [":friends"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
"//tensorflow/core:portable_tensorflow_lib",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:graph",
|
||||
@ -505,6 +505,7 @@ cc_library(
|
||||
name = "shape_inference",
|
||||
srcs = ["shape_inference.cc"],
|
||||
hdrs = ["shape_inference.h"],
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
":shape_inference_helpers",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
|
@ -2034,6 +2034,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"TensorArraySplitV3",
|
||||
"TensorArrayV3",
|
||||
"TensorArrayWriteV3",
|
||||
"TensorListConcatV2",
|
||||
"TensorListElementShape",
|
||||
"TensorListFromTensor",
|
||||
"TensorListGather",
|
||||
@ -2043,6 +2044,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"TensorListPushBack",
|
||||
"TensorListReserve",
|
||||
"TensorListSetItem",
|
||||
"TensorListSplit",
|
||||
"TensorListStack",
|
||||
"TensorScatterAdd",
|
||||
"TensorScatterSub",
|
||||
|
@ -395,8 +395,7 @@ static void ShowXlaDeviceDeprecationWarning(
|
||||
if (absl::StrContains(compilation_device_name, "CPU") ||
|
||||
absl::StrContains(compilation_device_name, "GPU")) {
|
||||
absl::call_once(once, [] {
|
||||
LOG(WARNING)
|
||||
<< "XLA_GPU and XLA_CPU devices are deprecated and will be "
|
||||
LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be "
|
||||
"removed in subsequent releases. Instead, use either "
|
||||
"@tf.function(experimental_compile=True) for must-compile "
|
||||
"semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
|
||||
|
@ -91,7 +91,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||
}
|
||||
string message = absl::StrCat(
|
||||
"Function invoked by the following node is not compilable: ",
|
||||
SummarizeNodeDef(node_def), ".\n");
|
||||
SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n");
|
||||
absl::StrAppend(&message, "Uncompilable nodes:");
|
||||
for (const auto& node_info : uncompilable_node_info) {
|
||||
string node_message =
|
||||
|
@ -201,9 +201,7 @@ void XlaComputationLaunchContext::PopulateInputs(
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
// Build ShapedBuffers that point directly to the Tensor buffers.
|
||||
arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1);
|
||||
arg_buffers_.resize(kernel->xla_input_shapes.size());
|
||||
arg_ptrs_ = std::vector<ShapedBuffer*>(arg_buffers_.size());
|
||||
arg_ptrs_ = std::vector<ShapedBuffer*>(kernel->xla_input_shapes.size());
|
||||
|
||||
// Pass remaining parameters.
|
||||
const Tensor* t;
|
||||
@ -239,11 +237,11 @@ void XlaComputationLaunchContext::PopulateInputs(
|
||||
<< " not the same as on-host shape "
|
||||
<< xla::ShapeUtil::HumanStringWithLayout(shape);
|
||||
se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
|
||||
arg_buffers_[i] = absl::make_unique<ShapedBuffer>(
|
||||
arg_buffers_.emplace_back(
|
||||
/*on_host_shape=*/shape, /*on_device_shape=*/shape,
|
||||
client_->platform(), client_->default_device_ordinal());
|
||||
arg_buffers_[i]->set_buffer(dmem, /*index=*/{});
|
||||
arg_ptrs_[i] = arg_buffers_[i].get();
|
||||
arg_buffers_.back().set_buffer(dmem, /*index=*/{});
|
||||
arg_ptrs_[i] = &arg_buffers_.back();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -470,10 +468,6 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
<< "Invalid input for outputs " << i << ": " << input_index;
|
||||
ctx->set_output(i, ctx->input(input_index));
|
||||
} else {
|
||||
if (MustAliasOutput(input_output_alias, output_num)) {
|
||||
DCHECK(output.buffer({output_num}).is_null())
|
||||
<< "Expected output buffer to be aliased, but it is not nil.";
|
||||
}
|
||||
if (allocate_xla_tensors_) {
|
||||
TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
|
||||
input_output_alias, output_num, ctx, i, shape, &output,
|
||||
|
@ -165,7 +165,7 @@ class XlaComputationLaunchContext {
|
||||
se::DeviceMemoryAllocator* xla_allocator_;
|
||||
bool allocate_xla_tensors_;
|
||||
bool use_multiple_streams_;
|
||||
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_;
|
||||
std::deque<xla::ShapedBuffer> arg_buffers_;
|
||||
std::vector<xla::ShapedBuffer*> arg_ptrs_;
|
||||
};
|
||||
|
||||
|
@ -77,10 +77,6 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
|
||||
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes",
|
||||
"//tensorflow/compiler/mlir/tfrt:lower_tf_to_tfd_alwayslink",
|
||||
"//tensorflow/compiler/mlir/tfrt:runtime_fallback_opdefs_alwayslink",
|
||||
"//tensorflow/compiler/mlir/tfrt:tf_legalize_to_tfrt",
|
||||
"//tensorflow/compiler/mlir/tfrt:tf_to_corert",
|
||||
],
|
||||
)
|
||||
|
||||
@ -108,6 +104,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Shape",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -152,7 +149,6 @@ tf_cc_binary(
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op",
|
||||
"//tensorflow/compiler/mlir/tfrt:compatibility_analysis",
|
||||
"//tensorflow/compiler/mlir/xla:xla_mlir_translate",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -31,7 +31,7 @@ filegroup(
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
|
||||
],
|
||||
)
|
||||
|
||||
@ -216,13 +216,13 @@ cc_library(
|
||||
"ir/tfl_ops.h",
|
||||
"transforms/passes.h",
|
||||
"utils/attribute_utils.h",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
|
||||
"@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
|
||||
],
|
||||
deps = [
|
||||
":tensorflow_lite_ops_inc_gen",
|
||||
":validators",
|
||||
"//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@llvm-project//llvm:support",
|
||||
@ -260,6 +260,25 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tftext_utils",
|
||||
srcs = [
|
||||
"utils/tftext_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"utils/tftext_utils.h",
|
||||
],
|
||||
copts = ["-std=c++14"],
|
||||
deps = [
|
||||
":tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "stateful_ops_utils",
|
||||
srcs = [
|
||||
@ -320,6 +339,7 @@ cc_library(
|
||||
":lstm_utils",
|
||||
":stateful_ops_utils",
|
||||
":tensorflow_lite",
|
||||
":tftext_utils",
|
||||
":validators",
|
||||
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
|
@ -32,7 +32,6 @@ struct PassConfig {
|
||||
lower_tensor_list_ops(false),
|
||||
trim_functions_whitelist({}),
|
||||
quant_specs(std::move(specs)),
|
||||
skip_control_dialect(false),
|
||||
form_clusters(false),
|
||||
unfold_batch_matmul(true),
|
||||
legalize_tf_while(true),
|
||||
@ -49,13 +48,8 @@ struct PassConfig {
|
||||
llvm::ArrayRef<std::string> trim_functions_whitelist;
|
||||
// All information about quantization.
|
||||
QuantizationSpecs quant_specs;
|
||||
// If `skip_control_dialect` is true, TF executor dialect is not converted to
|
||||
// TF control dialect prior to legalization to TF Lite.
|
||||
// TODO(b/142911013): Remove flag once control dialect is removed.
|
||||
bool skip_control_dialect;
|
||||
// If `form_clusters` is true (and `skip_control_dialect` is true), clusters
|
||||
// are formed by grouping consecutive ops of the same device, under a
|
||||
// `tf_device.launch` op.
|
||||
// If `form_clusters` is true , clusters are formed by grouping consecutive
|
||||
// ops of the same device, under a `tf_device.launch` op.
|
||||
bool form_clusters;
|
||||
// if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
|
||||
// of tfl.fully_connected ops.
|
||||
|
@ -799,11 +799,6 @@ Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
|
||||
|
||||
Optional<CustomOptionsOffset> Translator::CreateCustomOpCustomOptions(
|
||||
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
|
||||
std::string node_def_str;
|
||||
if (!node_def.SerializeToString(&node_def_str)) {
|
||||
return emitError(loc, "failed to serialize tensorflow node_def"),
|
||||
llvm::None;
|
||||
}
|
||||
auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc);
|
||||
return builder_.CreateVector(flex_builder->GetBuffer());
|
||||
}
|
||||
@ -813,9 +808,13 @@ Translator::CreateFlexBuilderWithNodeAttrs(
|
||||
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
|
||||
auto flex_builder = absl::make_unique<flexbuffers::Builder>();
|
||||
size_t map_start = flex_builder->StartMap();
|
||||
for (const auto& pair : node_def.attr()) {
|
||||
using Item = std::pair<std::string, ::tensorflow::AttrValue>;
|
||||
std::vector<Item> attrs(node_def.attr().begin(), node_def.attr().end());
|
||||
std::sort(attrs.begin(), attrs.end(),
|
||||
[](Item& p1, Item& p2) -> bool { return p1.first < p2.first; });
|
||||
for (const Item& pair : attrs) {
|
||||
const char* key = pair.first.c_str();
|
||||
const auto& attr = pair.second;
|
||||
const ::tensorflow::AttrValue& attr = pair.second;
|
||||
switch (attr.value_case()) {
|
||||
case ::tensorflow::AttrValue::kS:
|
||||
flex_builder->String(key, attr.s());
|
||||
|
@ -424,6 +424,10 @@ StatusOr<Operation*> BuildExternalConstOp(const tflite::TensorT& tensor,
|
||||
StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
|
||||
const std::vector<uint8_t>& buffer,
|
||||
OpBuilder builder, Location loc) {
|
||||
if (buffer.empty()) {
|
||||
return errors::InvalidArgument("Constant's buffer may not be empty");
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
|
||||
/*shapeless_are_scalars=*/true,
|
||||
/*is_constant=*/true));
|
||||
@ -695,8 +699,6 @@ StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
|
||||
for (int32_t output : output_indices) {
|
||||
if (auto& op = defining_op[output]) {
|
||||
queue.push_back(op);
|
||||
} else {
|
||||
return errors::InvalidArgument("Output tensor doesn't have defining op");
|
||||
}
|
||||
}
|
||||
|
||||
@ -801,9 +803,17 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
}
|
||||
|
||||
for (auto output : func_outputs) {
|
||||
bool is_constant = !is_op_output[output];
|
||||
const bool is_func_input = input_index_set.contains(output);
|
||||
bool is_constant = !is_op_output[output] && !is_func_input;
|
||||
// There are 2 cases tensor is scalar when it doesn't have a shape in
|
||||
// flatbuffer:
|
||||
// 1. `is_constant` = true, means this tensor is created from a constant op.
|
||||
// 2. `is_func_input` = true and `is_entry_point` = true, which means this
|
||||
// tensor is function input and function input type is a scalar tensor.
|
||||
const bool shapeless_is_scalar =
|
||||
is_constant || (is_func_input && is_entry_point);
|
||||
auto type_or_err = GetTensorType(*subgraph.tensors.at(output), builder,
|
||||
/*shapeless_are_scalars=*/is_constant,
|
||||
shapeless_is_scalar,
|
||||
/*is_constant=*/is_constant);
|
||||
if (!type_or_err.ok()) {
|
||||
emitError(func_loc, "error reading return types")
|
||||
|
@ -46,28 +46,68 @@ namespace mlir {
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
|
||||
namespace TFL {
|
||||
|
||||
// Returns true when the given two types have the same shape or broadcastable
|
||||
// shape within the given rank. If any given shapes are non-static, this method
|
||||
// returns true.
|
||||
bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs,
|
||||
// Returns true when the given operand arguments have the same shape or
|
||||
// broadcastable shape within the given rank. If any given shapes are
|
||||
// non-static and maximum rank is within the given rank, this method returns
|
||||
// true.
|
||||
bool IsOperandsHaveSameShapesOrBroadcastableShape(Operation *op,
|
||||
ArrayRef<unsigned> indices,
|
||||
int max_bcast_rank) {
|
||||
// Ignore shape checking on the non-static shapes for model compatibility.
|
||||
auto lhs_shaped_type = lhs.dyn_cast<ShapedType>();
|
||||
if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true;
|
||||
auto rhs_shaped_type = rhs.dyn_cast<ShapedType>();
|
||||
if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true;
|
||||
if (indices.empty()) return true;
|
||||
|
||||
if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape()))
|
||||
return true;
|
||||
// First, it checks there are any inputs that has unknown rank.
|
||||
bool has_unknown_shape_input = false;
|
||||
bool has_same_shape = true;
|
||||
bool reach_first_known_shape = false;
|
||||
int64_t max_rank = -1;
|
||||
|
||||
ArrayRef<int64_t> pivot_shape;
|
||||
SmallVector<int64_t, 4> current_shape;
|
||||
SmallVector<int64_t, 4> result_shape;
|
||||
if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(),
|
||||
rhs_shaped_type.getShape(),
|
||||
|
||||
for (unsigned index : indices) {
|
||||
ShapedType shaped_type =
|
||||
op->getOperand(index).getType().dyn_cast<ShapedType>();
|
||||
if (!shaped_type || !shaped_type.hasRank()) {
|
||||
// Marks that we have an unknown rank input.
|
||||
has_unknown_shape_input = true;
|
||||
continue;
|
||||
}
|
||||
max_rank = std::max(max_rank, shaped_type.getRank());
|
||||
if (!shaped_type.hasStaticShape()) {
|
||||
// Marks that we have an unknown shape input.
|
||||
has_unknown_shape_input = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> shape = shaped_type.getShape();
|
||||
if (!reach_first_known_shape) {
|
||||
pivot_shape = shape;
|
||||
current_shape.assign(shape.begin(), shape.end());
|
||||
reach_first_known_shape = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!pivot_shape.equals(shape)) {
|
||||
has_same_shape = false;
|
||||
}
|
||||
// Checks if all the inputs are broadcastable since they have not all the
|
||||
// same shapes.
|
||||
if (!OpTrait::util::getBroadcastedShape(current_shape, shape,
|
||||
result_shape)) {
|
||||
return false;
|
||||
}
|
||||
return lhs_shaped_type.getRank() <= max_bcast_rank &&
|
||||
rhs_shaped_type.getRank() <= max_bcast_rank;
|
||||
current_shape = result_shape;
|
||||
}
|
||||
|
||||
// It will treat the unknown shape inputs as acceptable inputs for model
|
||||
// compatibility unless there is an known rank that is bigger than the allowed
|
||||
// broadcast maximum rank.
|
||||
if (has_unknown_shape_input) return max_rank <= max_bcast_rank;
|
||||
|
||||
// If all the shape is known and same, CPU kernels are able to handle inputs
|
||||
// regardless of dimension size.
|
||||
return has_same_shape || max_rank <= max_bcast_rank;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1882,7 +1922,7 @@ static LogicalResult Verify(TransposeConvOp op) {
|
||||
|
||||
auto expected_output_type =
|
||||
RankedTensorType::get(output_shape, output_type.getElementType());
|
||||
if (output_type != expected_output_type) {
|
||||
if (failed(mlir::verifyCompatibleShape(output_type, expected_output_type))) {
|
||||
return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
|
||||
expected_output_type, output_type));
|
||||
}
|
||||
@ -1966,9 +2006,9 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
|
||||
}
|
||||
|
||||
static LogicalResult Verify(TransposeOp op) {
|
||||
auto input_type = op.x().getType().cast<ShapedType>();
|
||||
auto input_type = op.input().getType().cast<ShapedType>();
|
||||
auto perm_type = op.perm().getType().cast<ShapedType>();
|
||||
auto output_type = op.y().getType().cast<ShapedType>();
|
||||
auto output_type = op.output().getType().cast<ShapedType>();
|
||||
if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
|
||||
if (perm_type.getNumElements() != input_type.getRank()) {
|
||||
return op.emitOpError(
|
||||
@ -2004,7 +2044,8 @@ static LogicalResult Verify(TransposeOp op) {
|
||||
}
|
||||
auto expected_output_type =
|
||||
RankedTensorType::get(transposed_shape, input_type.getElementType());
|
||||
if (output_type != expected_output_type) {
|
||||
if (failed(
|
||||
mlir::verifyCompatibleShape(output_type, expected_output_type))) {
|
||||
return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
|
||||
expected_output_type, output_type));
|
||||
}
|
||||
|
@ -27,7 +27,7 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -55,8 +55,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
std::vector<string> node_names;
|
||||
std::vector<string> node_dtypes;
|
||||
std::vector<std::vector<int>> node_shapes;
|
||||
std::vector<double> node_mins;
|
||||
std::vector<double> node_maxs;
|
||||
std::vector<llvm::Optional<double>> node_mins;
|
||||
std::vector<llvm::Optional<double>> node_maxs;
|
||||
|
||||
// Populate quantization specs.
|
||||
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
|
||||
|
@ -125,8 +125,8 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||
std::vector<string> node_names;
|
||||
std::vector<string> node_dtypes;
|
||||
std::vector<std::vector<int>> node_shapes;
|
||||
std::vector<double> node_mins;
|
||||
std::vector<double> node_maxs;
|
||||
std::vector<llvm::Optional<double>> node_mins;
|
||||
std::vector<llvm::Optional<double>> node_maxs;
|
||||
|
||||
// Populate quantization specs.
|
||||
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
|
||||
|
@ -121,6 +121,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
|
||||
return DT_STRING;
|
||||
case toco::IODataType::BOOL:
|
||||
return DT_BOOL;
|
||||
case toco::IODataType::COMPLEX64:
|
||||
return DT_COMPLEX64;
|
||||
default:
|
||||
return DT_INVALID;
|
||||
}
|
||||
@ -175,14 +177,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
|
||||
return RegisterCustomBuiltinOps(extra_tf_opdefs);
|
||||
}
|
||||
|
||||
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
|
||||
const toco::TocoFlags& toco_flags,
|
||||
mlir::TFL::QuantizationSpecs* quant_specs,
|
||||
std::vector<string>* node_names,
|
||||
Status PopulateQuantizationSpecs(
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
|
||||
std::vector<string>* node_dtypes,
|
||||
std::vector<std::vector<int>>* node_shapes,
|
||||
std::vector<double>* node_mins,
|
||||
std::vector<double>* node_maxs) {
|
||||
std::vector<llvm::Optional<double>>* node_mins,
|
||||
std::vector<llvm::Optional<double>>* node_maxs) {
|
||||
quant_specs->inference_input_type =
|
||||
ConvertIODataTypeToDataType(toco_flags.inference_input_type());
|
||||
tensorflow::DataType inference_type =
|
||||
@ -209,11 +210,16 @@ Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
|
||||
flag.shape().dims().end()));
|
||||
// Currently, only UINT8 and INT8 require inputs stats
|
||||
if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
|
||||
if (flag.has_mean_value() && flag.has_std_value()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(),
|
||||
inference_type));
|
||||
auto min_max, InputStatsToMinMax(flag.mean_value(),
|
||||
flag.std_value(), inference_type));
|
||||
node_mins->push_back(min_max.first);
|
||||
node_maxs->push_back(min_max.second);
|
||||
} else {
|
||||
node_mins->push_back(llvm::None);
|
||||
node_maxs->push_back(llvm::None);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -252,7 +258,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
|
||||
std::string error_message;
|
||||
auto output = mlir::openOutputFile(filename, &error_message);
|
||||
if (!error_message.empty()) {
|
||||
return errors::InvalidArgument("Failed to open file in %s.", filename);
|
||||
return errors::InvalidArgument("Failed to open file in ", filename);
|
||||
}
|
||||
mlir::PassManager pm(module.getContext());
|
||||
pm.addPass(mlir::createPrintOpGraphPass(output->os()));
|
||||
|
@ -34,14 +34,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags);
|
||||
|
||||
// Populate quantization specs (or not) given user specified ranges for each
|
||||
// input arrays.
|
||||
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
|
||||
const toco::TocoFlags& toco_flags,
|
||||
mlir::TFL::QuantizationSpecs* quant_specs,
|
||||
std::vector<string>* node_names,
|
||||
Status PopulateQuantizationSpecs(
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
|
||||
std::vector<string>* node_dtypes,
|
||||
std::vector<std::vector<int>>* node_shapes,
|
||||
std::vector<double>* node_mins,
|
||||
std::vector<double>* node_maxs);
|
||||
std::vector<llvm::Optional<double>>* node_mins,
|
||||
std::vector<llvm::Optional<double>>* node_maxs);
|
||||
|
||||
// Convert imported MLIR file to TfLite flatbuffer.
|
||||
// This will also run relevant passes as well.
|
||||
|
@ -3,6 +3,10 @@ load(
|
||||
"//tensorflow/core/platform:build_config.bzl",
|
||||
"tf_proto_library",
|
||||
)
|
||||
load(
|
||||
"//third_party/mlir:tblgen.bzl",
|
||||
"gentbl",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -23,6 +27,7 @@ package_group(
|
||||
exports_files([
|
||||
"quantization_traits.h",
|
||||
"quantization_config.h",
|
||||
"quantization_utils.h",
|
||||
])
|
||||
|
||||
filegroup(
|
||||
@ -34,6 +39,25 @@ filegroup(
|
||||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "quantization_interfaces_inc_gen",
|
||||
tbl_outs = [
|
||||
(
|
||||
"-gen-op-interface-decls",
|
||||
"quantization_interface.h.inc",
|
||||
),
|
||||
(
|
||||
"-gen-op-interface-defs",
|
||||
"quantization_interface.cc.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "quantization.td",
|
||||
td_srcs = [
|
||||
":quantization_td_files",
|
||||
],
|
||||
)
|
||||
|
||||
tf_proto_library(
|
||||
name = "quantization_info_proto",
|
||||
srcs = [
|
||||
@ -71,9 +95,11 @@ cc_library(
|
||||
name = "quantization_lib",
|
||||
srcs = [
|
||||
"quantization_driver.cc",
|
||||
"quantization_interface.cc.inc",
|
||||
"quantization_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"quantization_interface.h.inc",
|
||||
"quantization_traits.h",
|
||||
"quantization_utils.h",
|
||||
],
|
||||
|
@ -49,14 +49,16 @@ cc_library(
|
||||
],
|
||||
hdrs = [
|
||||
"tfl_to_std.h",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_utils.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
||||
"@com_google_absl//absl/strings",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace lite {
|
||||
@ -38,6 +39,7 @@ namespace lite {
|
||||
TfLiteStatus QuantizeModel(
|
||||
const tflite::ModelT& input_model, const tflite::TensorType& input_type,
|
||||
const tflite::TensorType& output_type,
|
||||
const tflite::TensorType& inference_type,
|
||||
const std::unordered_set<std::string>& operator_names,
|
||||
bool disable_per_channel, bool fully_quantize,
|
||||
flatbuffers::FlatBufferBuilder* builder,
|
||||
@ -73,7 +75,7 @@ TfLiteStatus QuantizeModel(
|
||||
// Apply quantization passes
|
||||
PassManager pm(module->getContext());
|
||||
TFL::QuantizationSpecs quant_specs;
|
||||
quant_specs.inference_type = tensorflow::DT_QINT8;
|
||||
quant_specs.inference_type = tflite::TflTypeToTfType(inference_type);
|
||||
quant_specs.post_training_quantization = true;
|
||||
quant_specs.disable_per_channel = disable_per_channel;
|
||||
|
||||
@ -81,8 +83,10 @@ TfLiteStatus QuantizeModel(
|
||||
auto input_tf_type = tflite::TflTypeToTfType(input_type);
|
||||
if (input_tf_type == tensorflow::DT_FLOAT) {
|
||||
emit_adaptor = true;
|
||||
} else if (input_tf_type == tensorflow::DT_UINT8) {
|
||||
quant_specs.inference_type = tensorflow::DT_QUINT8;
|
||||
} else if (input_tf_type == tensorflow::DT_UINT8 ||
|
||||
input_tf_type == tensorflow::DT_INT8 ||
|
||||
input_tf_type == tensorflow::DT_INT16) {
|
||||
quant_specs.inference_type = input_tf_type;
|
||||
}
|
||||
|
||||
pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs));
|
||||
|
@ -26,11 +26,13 @@ namespace mlir {
|
||||
namespace lite {
|
||||
|
||||
// Quantize the `input_model` and write the result to a flatbuffer `builder`.
|
||||
// The `input_type` and `output_type` can be float32/qint8/int8.
|
||||
// The `input_type`, `output_type` and `inference_type` can be
|
||||
// float32/qint8/int8/int16.
|
||||
// Return partially quantized model if `fully_quantize` is false.
|
||||
TfLiteStatus QuantizeModel(
|
||||
const tflite::ModelT& input_model, const tflite::TensorType& input_type,
|
||||
const tflite::TensorType& output_type,
|
||||
const tflite::TensorType& inference_type,
|
||||
const std::unordered_set<std::string>& operator_names,
|
||||
bool disable_per_channel, bool fully_quantize,
|
||||
flatbuffers::FlatBufferBuilder* builder,
|
||||
|
@ -46,7 +46,8 @@ TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer,
|
||||
|
||||
tflite::StderrReporter error_reporter;
|
||||
return mlir::lite::QuantizeModel(
|
||||
*model, tflite::TensorType_INT8, tflite::TensorType_INT8, {},
|
||||
*model, tflite::TensorType_INT8, tflite::TensorType_INT8,
|
||||
tflite::TensorType_INT8, {},
|
||||
/*disable_per_channel=*/false,
|
||||
/*fully_quantize=*/true, builder, &error_reporter);
|
||||
}
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user