Merge branch 'master' into op_tests_16x8

This commit is contained in:
Elena Zhelezina 2020-06-15 20:07:08 +01:00 committed by GitHub
commit 2b84d3cb98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3537 changed files with 212611 additions and 72134 deletions

View File

@ -143,6 +143,11 @@ build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl --define=build_with_mkl_dnn_v1_only=true
build:mkl -c opt
# config to build OneDNN backend with a user specified threadpool.
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_threadpool --define=build_with_mkldnn_threadpool=true
build:mkl_threadpool -c opt
# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true
@ -235,10 +240,15 @@ build:c++17 --cxxopt=-std=c++1z
build:c++17 --cxxopt=-stdlib=libc++
build:c++1z --config=c++17
# Enable using platform specific build settings
# Enable using platform specific build settings, except when cross-compiling for
# mobile platforms.
build --enable_platform_specific_config
build:android --noenable_platform_specific_config
build:ios --noenable_platform_specific_config
# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
build:android --copt=-w
build:ios --copt=-w
build:linux --copt=-w
build:macos --copt=-w
build:windows --copt=/w
@ -258,6 +268,10 @@ build:macos --define=INCLUDEDIR=$(PREFIX)/include
# TF_SYSTEM_LIBS do not work on windows.
# By default, build TF in C++ 14 mode.
build:android --cxxopt=-std=c++14
build:android --host_cxxopt=-std=c++14
build:ios --cxxopt=-std=c++14
build:ios --host_cxxopt=-std=c++14
build:linux --cxxopt=-std=c++14
build:linux --host_cxxopt=-std=c++14
build:macos --cxxopt=-std=c++14
@ -372,32 +386,32 @@ build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1
test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
@ -427,8 +441,8 @@ build:rbe_linux_py3 --python_path="/usr/bin/python3"
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-manylinux2010-py3_config_python"
build:rbe_win --config=rbe
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:toolchain"
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:cc-toolchain-x64_windows"
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win/tf_win_08062020:toolchain"
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win/tf_win_08062020:cc-toolchain-x64_windows"
build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"

View File

@ -1 +1 @@
3.0.0
3.1.0

87
.github/bot_config.yml vendored Normal file
View File

@ -0,0 +1,87 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
#
# THIS IS A GENERATED DOCKERFILE.
#
# This file was assembled from multiple pieces, whose use is documented
# throughout. Please refer to the TensorFlow dockerfiles documentation
# for more information.
# A list of assignees
assignees:
- amahendrakar
- ravikyram
- Saduf2019
# A list of assignees for compiler folder
compiler_assignees:
- joker-eph
# Cuda Comment
cuda_comment: >
From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries:
* For TF-GPU - See point 1
* For TF-CPU - See point 2
-----------------------------------------------------------------------------------------------
**1. Installing **TensorFlow-GPU** (TF) prebuilt binaries**
Make sure you are using compatible TF and CUDA versions.
Please refer following TF version and CUDA version compatibility table.
| TF | CUDA |
| :-------------: | :-------------: |
| 2.1.0 - 2.2.0 | 10.1 |
| 1.13.1 - 2.0 | 10.0 |
| 1.5.0 - 1.12.0 | 9.0 |
* If you have above configuration and using _**Windows**_ platform -
* Try adding the CUDA, CUPTI, and cuDNN installation directories to the %PATH% environment variable.
* Refer [windows setup guide](https://www.tensorflow.org/install/gpu#windows_setup).
* If you have above configuration and using _**Ubuntu/Linux**_ platform -
* Try adding the CUDA, CUPTI, and cuDNN installation directories to the $LD_LIBRARY_PATH environment variable.
* Refer [linux setup guide](https://www.tensorflow.org/install/gpu#linux_setup).
* If error still persists then, apparently your CPU model does not support AVX instruction sets.
* Refer [hardware requirements](https://www.tensorflow.org/install/pip#hardware-requirements).
-----------------------------------------------------------------------------------------------
**2. Installing **TensorFlow** (TF) CPU prebuilt binaries**
*TensorFlow release binaries version 1.6 and higher are prebuilt with AVX instruction sets.*
Therefore on any CPU that does not have these instruction sets, either CPU or GPU version of TF will fail to load.
Apparently, your CPU model does not support AVX instruction sets. You can still use TensorFlow with the alternatives given below:
* Try Google Colab to use TensorFlow.
* The easiest way to use TF will be to switch to [google colab](https://colab.sandbox.google.com/notebooks/welcome.ipynb#recent=true). You get pre-installed latest stable TF version. Also you can use ```pip install``` to install any other preferred TF version.
* It has an added advantage since you can you easily switch to different hardware accelerators (cpu, gpu, tpu) as per the task.
* All you need is a good internet connection and you are all set.
* Try to build TF from sources by changing CPU optimization flags.
*Please let us know if this helps.*
windows_comment: >
From the stack trace it looks like you are hitting windows path length limit.
* Try to disable path length limit on Windows 10.
* Refer [disable path length limit instructions guide.](https://mspoweruser.com/ntfs-260-character-windows-10/)
Please let us know if this helps.

View File

@ -1,7 +1,11 @@
# TensorFlow Code of Conduct
In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation.
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and our
community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, gender identity and expression, level of
experience, nationality, personal appearance, race, religion, or sexual identity
and orientation.
## Our Standards

View File

@ -4,26 +4,31 @@ https://stackoverflow.com/questions/tagged/tensorflow
If you open a GitHub issue, here is our policy:
1. It must be a bug, a feature request, or a significant problem with documentation (for small docs fixes please send a PR instead).
2. The form below must be filled out.
3. It shouldn't be a TensorBoard issue. Those go [here](https://github.com/tensorflow/tensorboard/issues).
1. It must be a bug, a feature request, or a significant problem with the
documentation (for small docs fixes please send a PR instead).
2. The form below must be filled out.
3. It shouldn't be a TensorBoard issue. Those go
[here](https://github.com/tensorflow/tensorboard/issues).
**Here's why we have that policy**: TensorFlow developers respond to issues. We want to focus on work that benefits the whole community, e.g., fixing bugs and adding features. Support only helps individuals. GitHub also notifies thousands of people when issues are filed. We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow.
------------------------
### System information
- **Have I written custom code (as opposed to using a stock example script provided in TensorFlow)**:
- **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**:
- **Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device**:
- **TensorFlow installed from (source or binary)**:
- **TensorFlow version (use command below)**:
- **Python version**:
- **Bazel version (if compiling from source)**:
- **GCC/Compiler version (if compiling from source)**:
- **CUDA/cuDNN version**:
- **GPU model and memory**:
- **Exact command to reproduce**:
- **Have I written custom code (as opposed to using a stock example script
provided in TensorFlow)**:
- **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**:
- **Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue
happens on a mobile device**:
- **TensorFlow installed from (source or binary)**:
- **TensorFlow version (use command below)**:
- **Python version**:
- **Bazel version (if compiling from source)**:
- **GCC/Compiler version (if compiling from source)**:
- **CUDA/cuDNN version**:
- **GPU model and memory**:
- **Exact command to reproduce**:
You can collect some of this information using our environment capture script:

View File

@ -103,17 +103,17 @@ open-source software development:
### Official Builds
Build Type | Status | Artifacts
------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv6l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv7l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
Build Type | Status | Artifacts
------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
### Community Supported Builds
@ -142,6 +142,7 @@ Build Type | Status
* [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2)
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
* [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp)
* [TensorFlow Blog](https://blog.tensorflow.org)
* [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
* [TensorFlow Twitter](https://twitter.com/tensorflow)

View File

@ -1,3 +1,41 @@
# Release 2.3.0
## Breaking Changes
* `tf.image.extract_glimpse` has been updated to correctly process the case
where `centered=False` and `normalized=False`. This is a breaking change as
the output is different from (incorrect) previous versions. Note this
breaking change only impacts `tf.image.extract_glimpse` and
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
exsiting C++ kernel `ExtractGlimpse` does not change as well, so saved
models will not be impacted.
# Release 2.1.1
## Bug Fixes and Other Changes
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
* Fixes a versioning bug which causes Keras layers from TF 1.x to be used instead of those from TF 2.x
# Release 2.0.2
## Bug Fixes and Other Changes
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
# Release 1.15.3
## Bug Fixes and Other Changes
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
# Release 2.2.0
TensorFlow 2.2 discontinues support for Python 2, [previously announced](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ) as following [Python 2's EOL on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update).
@ -52,89 +90,150 @@ Coinciding with this change, new releases of [TensorFlow's Docker images](https:
* The current TensorFlow release now **requires** [gast](https://pypi.org/project/gast/) version 0.3.3.
## Bug Fixes and Other Changes
* `tf.data`:
* Removed `autotune_algorithm` from experimental optimization options.
* TF Core:
* `tf.constant` always creates CPU tensors irrespective of the current device context.
* Eager `TensorHandles` maintain a list of mirrors for any copies to local or remote devices. This avoids any redundant copies due to op execution.
* For `tf.Tensor` & `tf.Variable`, `.experimental_ref()` is no longer experimental and is available as simply `.ref()`.
* `pfor/vectorized_map`: Added support for vectorizing 56 more ops. Vectorizing `tf.cond` is also supported now.
* Set as much partial shape as we can infer statically within the gradient impl of the gather op.
* Gradient of `tf.while_loop` emits `StatelessWhile` op if `cond` and body functions are stateless. This allows multiple gradients while ops to run in parallel under distribution strategy.
* Speed up `GradientTape` in eager mode by auto-generating list of op inputs/outputs which are unused and hence not cached for gradient functions.
* Support `back_prop=False` in `while_v2` but mark it as deprecated.
* Improve error message when attempting to use `None` in data-dependent control flow.
* Add `RaggedTensor.numpy()`.
* Update `RaggedTensor.__getitem__` to preserve uniform dimensions & allow indexing into uniform dimensions.
* Update `tf.expand_dims` to always insert the new dimension as a non-ragged dimension.
* Update `tf.embedding_lookup` to use `partition_strategy` and `max_norm` when `ids` is ragged.
* Allow `batch_dims==rank(indices)` in `tf.gather`.
* Add support for bfloat16 in `tf.print`.
* `tf.distribute`:
* Support `embedding_column` with variable-length input features for `MultiWorkerMirroredStrategy`.
* `tf.keras`:
* Added `experimental_aggregate_gradients` argument to `tf.keras.optimizer.Optimizer.apply_gradients`. This allows custom gradient aggregation and processing aggregated gradients in custom training loop.
* Allow `pathlib.Path` paths for loading models via Keras API.
* `tf.function`/AutoGraph:
* AutoGraph is now available in `ReplicaContext.merge_call`, `Strategy.extended.update` and `Strategy.extended.update_non_slot`.
* Experimental support for shape invariants has been enabled in `tf.function`. See the API docs for `tf.autograph.experimental.set_loop_options` for additonal info.
* AutoGraph error messages now exclude frames corresponding to APIs internal to AutoGraph.
* Improve shape inference for `tf.function` input arguments to unlock more Grappler optimizations in TensorFlow 2.x.
* Improve automatic control dependency management of resources by allowing resource reads to occur in parallel and synchronizing only on writes.
* Fix execution order of multiple stateful calls to `experimental_run_v2` in `tf.function`.
* You can now iterate over `RaggedTensors` using a for loop inside `tf.function`.
* `tf.lite`:
* Migrated the `tf.lite` C inference API out of experimental into lite/c.
* Add an option to disallow `NNAPI` CPU / partial acceleration on Android 10
* TFLite Android AARs now include the C headers and APIs are required to use TFLite from native code.
* Refactors the delegate and delegate kernel sources to allow usage in the linter.
* Limit delegated ops to actually supported ones if a device name is specified or `NNAPI` CPU Fallback is disabled.
* TFLite now supports `tf.math.reciprocal1` op by lowering to `tf.div op`.
* TFLite's unpack op now supports boolean tensor inputs.
* Microcontroller and embedded code moved from experimental to main TensorFlow Lite folder
* Check for large TFLite tensors.
* Fix GPU delegate crash with C++17.
* Add 5D support to TFLite `strided_slice`.
* Fix error in delegation of `DEPTH_TO_SPACE` to `NNAPI` causing op not to be accelerated.
* Fix segmentation fault when running a model with LSTM nodes using `NNAPI` Delegate
* Fix `NNAPI` delegate failure when an operand for Maximum/Minimum operation is a scalar.
* Fix `NNAPI` delegate failure when Axis input for reduce operation is a scalar.
* Expose option to limit the number of partitions that will be delegated to `NNAPI`.
* If a target accelerator is specified, use its feature level to determine operations to delegate instead of SDK version.
* `tf.random`:
* Various random number generation improvements:
* Add a fast path for default `random_uniform`
* `random_seed` documentation improvement.
* `RandomBinomial` broadcasts and appends the sample shape to the left rather than the right.
* Added `tf.random.stateless_binomial`, `tf.random.stateless_gamma`, `tf.random.stateless_poisson`
* `tf.random.stateless_uniform` now supports unbounded sampling of `int` types.
* Math and Linear Algebra:
* Add `tf.linalg.LinearOperatorTridiag`.
* Add `LinearOperatorBlockLowerTriangular`
* Add broadcasting support to tf.linalg.triangular_solve[#26204](https://github.com/tensorflow/tensorflow/issues/26204), tf.math.invert_permutation.
* Add `tf.math.sobol_sample` op.
* Add `tf.math.xlog1py`.
* Add `tf.math.special.{dawsn,expi,fresnel_cos,fresnel_sin,spence}`.
* Add a Modified Discrete Cosine Transform (MDCT) and its inverse to `tf.signal`.
* TPU Enhancements:
* Refactor `TpuClusterResolver` to move shared logic to a separate pip package.
* Support configuring TPU software version from cloud tpu client.
* Allowed TPU embedding weight decay factor to be multiplied by learning rate.
* XLA Support:
* Add standalone XLA AOT runtime target + relevant .cc sources to pip package.
* Add check for memory alignment to MemoryAllocation::MemoryAllocation() on 32-bit ARM. This ensures a deterministic early exit instead of a hard to debug bus error later.
* `saved_model_cli aot_compile_cpu` allows you to compile saved models to XLA header+object files and include them in your C++ programs.
* Enable `Igamma`, `Igammac` for XLA.
* Deterministic Op Functionality:
* XLA reduction emitter is deterministic when the environment variable `TF_DETERMINISTIC_OPS` is set to "true" or "1". This extends deterministic `tf.nn.bias_add` back-prop functionality (and therefore also deterministic back-prop of bias-addition in Keras layers) to include when XLA JIT complilation is enabled.
* Fix problem, when running on a CUDA GPU and when either environment variable `TF_DETERMINSTIC_OPS` or environment variable `TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer configurations led to an exception with the message "No algorithm worked!"
* Tracing and Debugging:
* Add source, destination name to `_send` traceme to allow easier debugging.
* Add traceme event to `fastpathexecute`.
* Other:
* Fix an issue with AUC.reset_states for multi-label AUC [#35852](https://github.com/tensorflow/tensorflow/issues/35852)
* Fix the TF upgrade script to not delete files when there is a parsing error and the output mode is `in-place`.
* Move `tensorflow/core:framework/*_pyclif` rules to `tensorflow/core/framework:*_pyclif`.
* `tf.data`:
* Removed `autotune_algorithm` from experimental optimization options.
* TF Core:
* `tf.constant` always creates CPU tensors irrespective of the current
device context.
* Eager `TensorHandles` maintain a list of mirrors for any copies to local
or remote devices. This avoids any redundant copies due to op execution.
* For `tf.Tensor` & `tf.Variable`, `.experimental_ref()` is no longer
experimental and is available as simply `.ref()`.
* `pfor/vectorized_map`: Added support for vectorizing 56 more ops.
Vectorizing `tf.cond` is also supported now.
* Set as much partial shape as we can infer statically within the gradient
impl of the gather op.
* Gradient of `tf.while_loop` emits `StatelessWhile` op if `cond` and body
functions are stateless. This allows multiple gradients while ops to run
in parallel under distribution strategy.
* Speed up `GradientTape` in eager mode by auto-generating list of op
inputs/outputs which are unused and hence not cached for gradient
functions.
* Support `back_prop=False` in `while_v2` but mark it as deprecated.
* Improve error message when attempting to use `None` in data-dependent
control flow.
* Add `RaggedTensor.numpy()`.
* Update `RaggedTensor.__getitem__` to preserve uniform dimensions & allow
indexing into uniform dimensions.
* Update `tf.expand_dims` to always insert the new dimension as a
non-ragged dimension.
* Update `tf.embedding_lookup` to use `partition_strategy` and `max_norm`
when `ids` is ragged.
* Allow `batch_dims==rank(indices)` in `tf.gather`.
* Add support for bfloat16 in `tf.print`.
* `tf.distribute`:
* Support `embedding_column` with variable-length input features for
`MultiWorkerMirroredStrategy`.
* `tf.keras`:
* Added `experimental_aggregate_gradients` argument to
`tf.keras.optimizer.Optimizer.apply_gradients`. This allows custom
gradient aggregation and processing aggregated gradients in custom
training loop.
* Allow `pathlib.Path` paths for loading models via Keras API.
* `tf.function`/AutoGraph:
* AutoGraph is now available in `ReplicaContext.merge_call`,
`Strategy.extended.update` and `Strategy.extended.update_non_slot`.
* Experimental support for shape invariants has been enabled in
`tf.function`. See the API docs for
`tf.autograph.experimental.set_loop_options` for additonal info.
* AutoGraph error messages now exclude frames corresponding to APIs
internal to AutoGraph.
* Improve shape inference for `tf.function` input arguments to unlock more
Grappler optimizations in TensorFlow 2.x.
* Improve automatic control dependency management of resources by allowing
resource reads to occur in parallel and synchronizing only on writes.
* Fix execution order of multiple stateful calls to `experimental_run_v2`
in `tf.function`.
* You can now iterate over `RaggedTensors` using a for loop inside
`tf.function`.
* `tf.lite`:
* Migrated the `tf.lite` C inference API out of experimental into lite/c.
* Add an option to disallow `NNAPI` CPU / partial acceleration on Android
10
* TFLite Android AARs now include the C headers and APIs are required to
use TFLite from native code.
* Refactors the delegate and delegate kernel sources to allow usage in the
linter.
* Limit delegated ops to actually supported ones if a device name is
specified or `NNAPI` CPU Fallback is disabled.
* TFLite now supports `tf.math.reciprocal1` op by lowering to `tf.div op`.
* TFLite's unpack op now supports boolean tensor inputs.
* Microcontroller and embedded code moved from experimental to main
TensorFlow Lite folder
* Check for large TFLite tensors.
* Fix GPU delegate crash with C++17.
* Add 5D support to TFLite `strided_slice`.
* Fix error in delegation of `DEPTH_TO_SPACE` to `NNAPI` causing op not to
be accelerated.
* Fix segmentation fault when running a model with LSTM nodes using
`NNAPI` Delegate
* Fix `NNAPI` delegate failure when an operand for Maximum/Minimum
operation is a scalar.
* Fix `NNAPI` delegate failure when Axis input for reduce operation is a
scalar.
* Expose option to limit the number of partitions that will be delegated
to `NNAPI`.
* If a target accelerator is specified, use its feature level to determine
operations to delegate instead of SDK version.
* `tf.random`:
* Various random number generation improvements:
* Add a fast path for default `random_uniform`
* `random_seed` documentation improvement.
* `RandomBinomial` broadcasts and appends the sample shape to the left
rather than the right.
* Added `tf.random.stateless_binomial`, `tf.random.stateless_gamma`,
`tf.random.stateless_poisson`
* `tf.random.stateless_uniform` now supports unbounded sampling of `int`
types.
* Math and Linear Algebra:
* Add `tf.linalg.LinearOperatorTridiag`.
* Add `LinearOperatorBlockLowerTriangular`
* Add broadcasting support to
tf.linalg.triangular_solve[#26204](https://github.com/tensorflow/tensorflow/issues/26204),
tf.math.invert_permutation.
* Add `tf.math.sobol_sample` op.
* Add `tf.math.xlog1py`.
* Add `tf.math.special.{dawsn,expi,fresnel_cos,fresnel_sin,spence}`.
* Add a Modified Discrete Cosine Transform (MDCT) and its inverse to
`tf.signal`.
* TPU Enhancements:
* Refactor `TpuClusterResolver` to move shared logic to a separate pip
package.
* Support configuring TPU software version from cloud tpu client.
* Allowed TPU embedding weight decay factor to be multiplied by learning
rate.
* XLA Support:
* Add standalone XLA AOT runtime target + relevant .cc sources to pip
package.
* Add check for memory alignment to MemoryAllocation::MemoryAllocation()
on 32-bit ARM. This ensures a deterministic early exit instead of a hard
to debug bus error later.
* `saved_model_cli aot_compile_cpu` allows you to compile saved models to
XLA header+object files and include them in your C++ programs.
* Enable `Igamma`, `Igammac` for XLA.
* Deterministic Op Functionality:
* XLA reduction emitter is deterministic when the environment variable
`TF_DETERMINISTIC_OPS` is set to "true" or "1". This extends
deterministic `tf.nn.bias_add` back-prop functionality (and therefore
also deterministic back-prop of bias-addition in Keras layers) to
include when XLA JIT compilation is enabled.
* Fix problem, when running on a CUDA GPU and when either environment
variable `TF_DETERMINSTIC_OPS` or environment variable
`TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer
configurations led to an exception with the message "No algorithm
worked!"
* Tracing and Debugging:
* Add source, destination name to `_send` traceme to allow easier
debugging.
* Add traceme event to `fastpathexecute`.
* Other:
* Fix an issue with AUC.reset_states for multi-label AUC
[#35852](https://github.com/tensorflow/tensorflow/issues/35852)
* Fix the TF upgrade script to not delete files when there is a parsing
error and the output mode is `in-place`.
* Move `tensorflow/core:framework/*_pyclif` rules to
`tensorflow/core/framework:*_pyclif`.
## Thanks to our Contributors

View File

@ -64,7 +64,7 @@ your model, and we recommend you run the TensorFlow process in a sandbox.
It is possible to write models that are secure in a sense that they can safely
process untrusted inputs assuming there are no bugs. There are two main reasons
to not rely on this: first, it is easy to write models which must not be exposed
to not rely on this: First, it is easy to write models which must not be exposed
to untrusted inputs, and second, there are bugs in any software system of
sufficient complexity. Letting users control inputs could allow them to trigger
bugs either in TensorFlow or in dependent libraries.
@ -149,7 +149,7 @@ attack (or worse). Because TensorFlow behaves correctly, this is not a
vulnerability in TensorFlow (although it would be a vulnerability of this
hypothetical system).
As a general rule, it is incorrect behavior for Tensorflow to access memory it
As a general rule, it is incorrect behavior for TensorFlow to access memory it
does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to
such behaviors constitute a vulnerability.

View File

@ -114,6 +114,14 @@ http_archive(
],
)
http_archive(
name = "person_detect_data",
sha256 = "170542270da256994ce24d1e357f6e84a54fdaf7d28ff2b74725a40b70b082cf",
urls = [
"https://storage.googleapis.com/download.tensorflow.org/data/tf_lite_micro_person_data_grayscale_2020_05_24.zip",
],
)
# Required for dependency @com_github_grpc_grpc
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")

View File

@ -49,7 +49,7 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
_TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
_TF_MIN_BAZEL_VERSION = '2.0.0'
_TF_MIN_BAZEL_VERSION = '3.1.0'
_TF_MAX_BAZEL_VERSION = '3.99.0'
NCCL_LIB_PATHS = [
@ -484,8 +484,8 @@ def check_bazel_version(min_version, max_version):
stderr = open(os.devnull, 'wb')
curr_version = run_shell(['bazel', '--version'],
allow_non_zero = True,
stderr = stderr)
allow_non_zero=True,
stderr=stderr)
if curr_version.startswith('bazel '):
curr_version = curr_version.split('bazel ')[1]
@ -1011,17 +1011,15 @@ def set_tf_cuda_compute_capabilities(environ_cp):
default_cuda_compute_capabilities = native_cuda_compute_capabilities
ask_cuda_compute_capabilities = (
'Please specify a list of comma-separated '
'CUDA compute capabilities you want to '
'build with.\nYou can find the compute '
'capability of your device at: '
'https://developer.nvidia.com/cuda-gpus.\nPlease'
' note that each additional compute '
'capability significantly increases your '
'build time and binary size, and that '
'TensorFlow only supports compute '
'capabilities >= 3.5 [Default is: %s]: ' %
default_cuda_compute_capabilities)
'Please specify a list of comma-separated CUDA compute capabilities '
'you want to build with.\nYou can find the compute capability of your '
'device at: https://developer.nvidia.com/cuda-gpus. Each capability '
'can be specified as "x.y" or "compute_xy" to include both virtual and'
' binary GPU code, or as "sm_xy" to only include the binary '
'code.\nPlease note that each additional compute capability '
'significantly increases your build time and binary size, and that '
'TensorFlow only supports compute capabilities >= 3.5 [Default is: '
'%s]: ' % default_cuda_compute_capabilities)
tf_cuda_compute_capabilities = get_from_env_or_user_or_default(
environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES',
ask_cuda_compute_capabilities, default_cuda_compute_capabilities)
@ -1033,8 +1031,23 @@ def set_tf_cuda_compute_capabilities(environ_cp):
for compute_capability in tf_cuda_compute_capabilities.split(','):
m = re.match('[0-9]+.[0-9]+', compute_capability)
if not m:
print('Invalid compute capability: %s' % compute_capability)
all_valid = False
# We now support sm_35,sm_50,sm_60,compute_70.
sm_compute_match = re.match('(sm|compute)_?([0-9]+[0-9]+)',
compute_capability)
if not sm_compute_match:
print('Invalid compute capability: %s' % compute_capability)
all_valid = False
else:
ver = int(sm_compute_match.group(2))
if ver < 30:
print(
'ERROR: TensorFlow only supports small CUDA compute'
' capabilities of sm_30 and higher. Please re-specify the list'
' of compute capabilities excluding version %s.' % ver)
all_valid = False
if ver < 35:
print('WARNING: XLA does not support CUDA compute capabilities '
'lower than sm_35. Disable XLA when running on older GPUs.')
else:
ver = float(m.group(0))
if ver < 3.0:
@ -1225,7 +1238,8 @@ def is_reduced_optimize_huge_functions_available(environ_cp):
only, as of 2019-11-19). TensorFlow needs this flag to massively reduce
compile times, but until 16.4 is officially released, we can't depend on it.
See also https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
See also
https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
Because it's very annoying to check this manually (to check the MSVC installed
versions, you need to use the registry, and it's not clear if Bazel will be
@ -1368,8 +1382,13 @@ def main():
# environment variables.
environ_cp = dict(os.environ)
current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION,
_TF_MAX_BAZEL_VERSION)
try:
current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION,
_TF_MAX_BAZEL_VERSION)
except subprocess.CalledProcessError as e:
print('Error checking bazel version: ', e.output.decode('UTF-8').strip())
raise e
_TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)
reset_tf_configure_bazelrc()
@ -1387,7 +1406,6 @@ def main():
# Windows.
environ_cp['TF_DOWNLOAD_CLANG'] = '0'
environ_cp['TF_NEED_MPI'] = '0'
environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0'
if is_macos():
environ_cp['TF_NEED_TENSORRT'] = '0'

View File

@ -524,10 +524,14 @@ package_group(
],
)
package_group(name = "ndarray_tensor_allow_list")
package_group(
name = "ndarray_tensor_allow_list",
packages = ["//learning/pathways/..."],
)
# Packages that use composite tensors or dispatch.
# TODO(b/154762408) Remove this package group once it's no longer needed.
# If this is modified, then copy.bara.sky must also be modified.
package_group(name = "composite_tensor_whitelist")
# Packages that use private types symbols, until they are exported.
@ -537,6 +541,11 @@ package_group(
packages = ["//learning/deepmind/tensorflow/replicator/..."],
)
# Packages that use StructuredTensors.
# TODO(b/159007891) Remove this package once StructuredTensor is exported.
# If this is modified, then copy.bara.sky must also be modified.
package_group(name = "structured_tensor_whitelist")
filegroup(
name = "intel_binary_blob",
data = if_mkl_ml(

View File

@ -85,7 +85,7 @@ tf_cuda_library(
],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//tensorflow:chromiumos": [
":tf_attrtype",
@ -182,7 +182,7 @@ tf_cuda_library(
":tf_status_internal",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
":tf_status",
@ -216,10 +216,11 @@ tf_cuda_library(
],
visibility = [
"//tensorflow/c:__subpackages__",
"//tensorflow/compiler/mlir/tensorflow/c:__subpackages__",
],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:lib",
@ -232,12 +233,13 @@ cc_library(
srcs = ["tf_status.cc"],
hdrs = ["tf_status.h"],
visibility = ["//visibility:public"],
deps = select({
deps = [
":tf_status_internal",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
":tf_status_internal",
"//tensorflow/core:lib",
],
}),
@ -259,10 +261,15 @@ cc_library(
name = "tensor_interface",
hdrs = ["tensor_interface.h"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
}),
)
cc_library(
@ -272,7 +279,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:framework",
@ -286,16 +293,17 @@ cc_library(
srcs = ["tf_tensor.cc"],
hdrs = ["tf_tensor.h"],
visibility = ["//visibility:public"],
deps = select({
deps = [
":tensor_interface",
":tf_datatype",
":tf_status",
":tf_status_helper",
":tf_tensor_internal",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
":tensor_interface",
":tf_datatype",
":tf_status",
":tf_status_helper",
":tf_tensor_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -311,14 +319,15 @@ tf_cuda_library(
"tf_tensor_internal.h",
],
visibility = ["//tensorflow:internal"],
deps = select({
deps = [
":tensor_interface",
":tf_datatype",
":tf_status",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
":tensor_interface",
":tf_datatype",
":tf_status",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:casts",
@ -386,8 +395,14 @@ tf_cuda_library(
deps = [
":tf_status",
":tf_status_internal",
"//tensorflow/core:lib",
],
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:lib",
],
}),
)
tf_cc_test(
@ -426,7 +441,7 @@ tf_cuda_library(
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",
@ -457,7 +472,7 @@ tf_cuda_library(
] + select({
"//tensorflow:android": [
":c_api_internal",
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
":c_api_internal",
@ -484,7 +499,7 @@ tf_cuda_library(
":tf_status_helper",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",

View File

@ -589,14 +589,16 @@ void TF_DeleteDeviceList(TF_DeviceList* list) { delete list; }
TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
TF_DeviceList* response = new TF_DeviceList;
status->status = session->session->ListDevices(&response->response);
if (session && session->session)
status->status = session->session->ListDevices(&response->response);
return response;
}
TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
TF_Status* status) {
TF_DeviceList* response = new TF_DeviceList;
status->status = session->session->ListDevices(&response->response);
if (session && session->session)
status->status = session->session->ListDevices(&response->response);
return response;
}
@ -1384,6 +1386,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
cpp_type v; \
status->status = \
tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \
if (!status->status.ok()) return; \
*value = static_cast<c_type>(v); \
} \
void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
@ -2178,6 +2181,7 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
}
return new_session;
} else {
LOG(ERROR) << status->status;
DCHECK_EQ(nullptr, session);
return nullptr;
}

View File

@ -325,205 +325,6 @@ TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
return ret;
}
TFE_Context* TFE_CreateContextFromSession(TF_Session* session,
TF_Status* status) {
auto* opts = TFE_NewContextOptions();
// Reduce GPU memory allocation, and set appropriate config options for TFE
// context.
auto* config = TF_CreateConfig(
/*xla*/ false, /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
10);
TFE_ContextOptionsSetConfig(opts, config->data, config->length, status);
if (!status->status.ok()) {
CHECK(!config);
TFE_DeleteContextOptions(opts);
return nullptr;
}
auto* ctx = TFE_NewContextFromSession(opts, session, status);
TF_DeleteBuffer(config);
TFE_DeleteContextOptions(opts);
return ctx;
}
// TODO: retrieve the device string via TFE_ContextListDevices()
static const char DEFAULT_CPU_DEVICE[] =
"/job:localhost/replica:0/task:0/device:CPU:0";
static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType,
int tensor_id, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> queueOp(
TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp);
TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status);
if (!status->status.ok()) return nullptr;
// TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler.
TFE_OpSetAttrInt(queueOp.get(), "capacity", 1);
TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1);
auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id);
TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(),
shared_name.size());
TFE_OpSetAttrString(queueOp.get(), "container", "", 0);
// TODO: consider making this an unknown shape.
const int64_t* dims_ptr = nullptr;
int num_dims = 0;
TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims,
/*num_values*/ 0, status);
if (!status->status.ok()) return nullptr;
int num_retvals = 1;
TFE_TensorHandle* queue = nullptr;
TFE_Execute(queueOp.get(), &queue, &num_retvals, status);
if (!status->status.ok()) return nullptr;
CHECK_EQ(num_retvals, 1);
return queue;
}
static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType,
TFE_TensorHandle* queue, TFE_TensorHandle* tensor,
TF_Status* status) {
TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status);
if (!status->status.ok()) return;
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
if (!status->status.ok()) return;
TFE_OpAddInput(op, queue, status);
if (!status->status.ok()) return;
TFE_OpAddInput(op, tensor, status);
if (!status->status.ok()) return;
TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1);
TFE_OpSetAttrInt(op, "timeout_ms", -1);
int num_retvals = 0;
TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status);
if (!status->status.ok()) return;
CHECK_EQ(num_retvals, 0);
}
static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx,
TF_DataType inputType,
TFE_TensorHandle* queue,
TF_Status* status) {
TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status);
if (!status->status.ok()) return nullptr;
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
if (!status->status.ok()) return nullptr;
TFE_OpAddInput(op, queue, status);
if (!status->status.ok()) return nullptr;
TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1);
TFE_OpSetAttrInt(op, "timeout_ms", -1);
TFE_TensorHandle* ret;
int num_retvals = 1;
TFE_Execute(op, &ret, &num_retvals, status);
if (!status->status.ok()) return nullptr;
CHECK_EQ(num_retvals, 1);
return ret;
}
TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id,
TF_DataType inputType,
TF_Status* status) {
assert(session);
VLOG(1) << "Dequeuing data tensor with id " << tensor_id;
auto ctx = TFE_CreateContextFromSession(session, status);
if (!status->status.ok()) return nullptr;
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
ctx, TFE_DeleteContext);
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
if (!status->status.ok()) return nullptr;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
queue_deleter(queue, TFE_DeleteTensorHandle);
auto* ret = createTFEDequeue(ctx, inputType, queue, status);
return ret;
}
TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
TF_DataType inputType,
TF_Status* status) {
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
if (!status->status.ok()) return nullptr;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
queue_deleter(queue, TFE_DeleteTensorHandle);
auto* ret = createTFEDequeue(ctx, inputType, queue, status);
return ret;
}
void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id,
TFE_TensorHandle* tensor, TF_Status* status) {
assert(session);
VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
auto ctx = TFE_CreateContextFromSession(session, status);
if (!status->status.ok()) return;
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
ctx, TFE_DeleteContext);
TF_DataType inputType = TFE_TensorHandleDataType(tensor);
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
if (!status->status.ok()) return;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
queue_deleter(queue, TFE_DeleteTensorHandle);
createTFEEnqueue(ctx, inputType, queue, tensor, status);
}
void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
TFE_TensorHandle* tensor,
TF_Status* status) {
VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
TF_DataType inputType = TFE_TensorHandleDataType(tensor);
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
if (!status->status.ok()) return;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
queue_deleter(queue, TFE_DeleteTensorHandle);
createTFEEnqueue(ctx, inputType, queue, tensor, status);
}
void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id,
TFE_TensorHandle* tensor, TF_Status* status) {
VLOG(1) << "Enqueuing variant tensor with id " << tensor_id;
auto ctx = TFE_CreateContextFromSession(session, status);
if (!status->status.ok()) return;
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
ctx, TFE_DeleteContext);
TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
if (!status->status.ok()) return;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
queue_deleter(queue, TFE_DeleteTensorHandle);
createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status);
}
TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
TF_Status* status) {
VLOG(1) << "Dequeuing variant tensor with id " << tensor_id;
auto ctx = TFE_CreateContextFromSession(session, status);
if (!status->status.ok()) return nullptr;
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
ctx, TFE_DeleteContext);
TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
if (!status->status.ok()) return nullptr;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
queue_deleter(queue, TFE_DeleteTensorHandle);
return createTFEDequeue(ctx, TF_VARIANT, queue, status);
}
void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
status->status = tensorflow::errors::Internal(errMsg);
}
@ -622,10 +423,9 @@ void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
const TF_DataType* values, int num_values) {
auto iter = builder->attr_names.insert(attr_name).first;
builder->Set(
(*iter).c_str(),
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
reinterpret_cast<const tensorflow::DataType*>(values), num_values));
builder->Set(*iter, tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
reinterpret_cast<const tensorflow::DataType*>(values),
num_values));
}
void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,

View File

@ -146,48 +146,6 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
// Create a serialized tensorflow.ServerDef proto.
TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status);
// TODO: remove this API in favor of the next one.
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession(
const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status);
// Creates from `session` a new eager context to run a graph function or
// sends/recvs, so that these concurrent TFE executions can share (via
// `session` and its associated device mgr) the same set of fifo queue resource
// ops, used for host<->TF tensor transfers. This way the sends/recvs calls and
// graph function execution can access the same fifo queue resource handles
// (associated with devices managed by the device manager, which can be obtained
// from `session`).
//
// TODO: Remove this function once we migrate away from using session.
TF_CAPI_EXPORT extern TFE_Context* TFE_CreateContextFromSession(
TF_Session* session, TF_Status* status);
// TODO: Retire this API in favor of the next one.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensor(
TF_Session* session, int tensor_id, TF_DataType inputType,
TF_Status* status);
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(
TFE_Context* ctx, int tensor_id, TF_DataType inputType, TF_Status* status);
TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensor(TF_Session* session,
int tensor_id,
TFE_TensorHandle* tensor,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensorFromCtx(
TFE_Context* ctx, int tensor_id, TFE_TensorHandle* tensor,
TF_Status* status);
// TODO: consider folding the 2 APIs below into the ones above.
TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session,
int tensor_id,
TFE_TensorHandle* tensor,
TF_Status* status);
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
TF_Session* session, int tensor_id, TF_Status* status);
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
const char* errMsg);

View File

@ -35,7 +35,7 @@ tf_cuda_library(
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
":context_interface",
@ -144,6 +144,24 @@ cc_library(
],
)
cc_library(
name = "c_api_unified_internal",
hdrs = [
"c_api_unified_experimental_internal.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":c_api",
":c_api_experimental",
"//tensorflow/c:c_api_internal",
"//tensorflow/c:tf_status",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:types",
],
)
cc_library(
name = "tensor_handle_interface",
hdrs = ["tensor_handle_interface.h"],
@ -184,7 +202,6 @@ cc_library(
":operation_interface",
":tensor_handle_interface",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -319,6 +336,7 @@ tf_cuda_cc_test(
tags = [
"noguitar", # TODO(b/155445984): flaky
#"guitar",
"notap", # TODO(b/156981931): flaky
"multi_gpu",
],
deps = [
@ -349,7 +367,10 @@ tf_cuda_cc_test(
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = ["noasan"], # leaks gRPC server instances
tags = [
"no_windows",
"noasan", # leaks gRPC server instances
],
deps = [
":c_api",
":c_api_experimental",
@ -357,10 +378,46 @@ tf_cuda_cc_test(
":c_api_test_util",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_test_util",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:function_optimization_registry",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"@com_google_absl//absl/strings",
],
)
tf_cuda_cc_test(
name = "c_api_distributed_test",
size = "small",
srcs = [
"c_api_distributed_test.cc",
],
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = [
"no_windows",
"noasan", # leaks gRPC server instances
],
deps = [
":c_api",
":c_api_experimental",
":c_api_internal",
":c_api_test_util",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_test_util",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:function_optimization_registry",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"@com_google_absl//absl/strings",
@ -376,7 +433,10 @@ tf_cuda_cc_test(
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = ["noasan"], # leaks gRPC server instances
tags = [
"no_windows",
"noasan", # leaks gRPC server instances
],
deps = [
":c_api",
":c_api_experimental",
@ -412,7 +472,7 @@ tf_cuda_library(
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
":c_api",
@ -448,6 +508,8 @@ tf_cuda_library(
"//conditions:default": [],
}) + [
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/container:flat_hash_map",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
@ -505,6 +567,7 @@ tf_cuda_cc_test(
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",

View File

@ -102,6 +102,15 @@ string DeviceName(const tensorflow::Device* d) {
}
#if !defined(IS_MOBILE_PLATFORM)
bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context,
const tensorflow::ServerDef& server_def) {
if (server_def.job_name() != context->HostCPU()->parsed_name().job) {
return false;
}
return server_def.default_session_config().SerializeAsString() ==
context->session_options().config.SerializeAsString();
}
tensorflow::Status AddRemoteDevicesToMgr(
const std::vector<string>& added_remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
@ -469,10 +478,15 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::GrpcServer* grpc_server;
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
const tensorflow::DeviceMgr* device_mgr =
AreLocalDevicesCompatible(context, server_def)
? context->local_device_mgr()
: nullptr;
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
server_def, {device_mgr}, &new_server));
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
LOG_AND_RETURN_IF_ERROR(
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
} else {
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
&curr_remote_workers));
@ -727,24 +741,6 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
tensorflow::GetDefaultCustomKernelCreator()));
}
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
TF_Session* sess, TF_Status* status) {
const tensorflow::DeviceMgr* device_mgr = nullptr;
status->status = sess->session->LocalDeviceManager(&device_mgr);
if (!status->status.ok()) return nullptr;
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr);
return tensorflow::wrap(new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
/*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator()));
}
void TFE_DeleteContext(TFE_Context* ctx) {
if (ctx == nullptr) {
return;
@ -899,9 +895,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
#if defined(IS_MOBILE_PLATFORM)
status->status = tensorflow::Status::OK();
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->SyncExecutors();
status->status = tensorflow::unwrap(ctx)->AsyncWait();
#endif // !IS_MOBILE_PLATFORM
}
@ -1403,23 +1397,17 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
return;
}
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->AddFunctionDef(function_def);
status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def);
}
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->AddFunctionDef(function->fdef);
status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function->fdef);
}
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->RemoveFunction(name);
status->status = tensorflow::unwrap(ctx)->RemoveFunction(name);
}
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
@ -1485,14 +1473,10 @@ const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
}
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
tensorflow::AttrValueMap m;
tensorflow::unwrap(attrs)->FillAttrValueMap(&m);
tensorflow::EagerOperation* operation =
OperationFromInterface(tensorflow::unwrap(op));
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
for (const auto& attribute : m) {
destination->Set(attribute.first, attribute.second);
}
destination->CopyAttributes(*tensorflow::unwrap(attrs));
}
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,

View File

@ -30,26 +30,6 @@ namespace {
using ::tensorflow::string;
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost:", port)});
}
return server_def;
}
tensorflow::ServerDef GetServerDef(int num_tasks) {
return GetServerDef("localhost", num_tasks);
}
void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
int port = tensorflow::testing::PickUnusedPortOrDie();
@ -430,4 +410,70 @@ TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
TestRemoteExecuteUpdateServerDefWithFailures(true);
}
void TestConnectToCluster(bool keep_localhost_for_first_connect) {
// Fail fast on GetStatus requests so we can get errors instead of timeout
// when updating cluster with non-exsitent worker
tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
const string first_name =
keep_localhost_for_first_connect ? "localhost" : "abc";
tensorflow::ServerDef server_def = GetServerDef(first_name, 1);
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
const string dev0_name = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
EXPECT_NE(var_handle0, nullptr);
tensorflow::Status status2;
EXPECT_EQ(tensorflow::unwrap(var_handle0)->DeviceName(&status2), dev0_name);
// Rename local device
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const string dev1_name =
absl::StrCat("/job:", first_name, "/replica:0/task:0/device:CPU:0");
TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
EXPECT_NE(var_handle1, nullptr);
EXPECT_EQ(tensorflow::unwrap(var_handle1)->DeviceName(&status2), dev1_name);
// Another renaming of local device
const string second_name = "def";
server_def.set_job_name(second_name);
server_def.mutable_cluster()->mutable_job(0)->set_name(second_name);
(*server_def.mutable_cluster()->mutable_job(0)->mutable_tasks())[0] =
absl::StrCat(second_name, ":",
tensorflow::testing::PickUnusedPortOrDie());
serialized = server_def.SerializeAsString();
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const string dev2_name = "/job:def/replica:0/task:0/device:CPU:0";
TFE_TensorHandle* var_handle2 = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle2, nullptr);
EXPECT_EQ(tensorflow::unwrap(var_handle2)->DeviceName(&status2), dev2_name);
TFE_DeleteTensorHandle(var_handle0);
TFE_DeleteTensorHandle(var_handle1);
TFE_DeleteTensorHandle(var_handle2);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
tensorflow::unsetenv("GRPC_FAIL_FAST");
}
TEST(CAPI, ConnectToClusterLocalhostFirst) { TestConnectToCluster(false); }
TEST(CAPI, ConnectToClusterRenameFirst) { TestConnectToCluster(true); }
} // namespace

View File

@ -0,0 +1,506 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace {
using ::tensorflow::string;
// Add the values of three variables on three different tasks.
string AddVariablesFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'AddVariablesFunction'"
" input_arg {"
" name: 'var'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'sum'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:0/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'read1'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:1/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'read2'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:2/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add1'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read1:value:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add2'"
" op: 'Add'"
" input: 'add1:z:0'"
" input: 'read2:value:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'sum'"
" value: 'add2:z:0'"
" }",
&def));
return def.SerializeAsString();
}
void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(op, var_handle, status);
TFE_TensorHandle* is_initialized[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(op, &is_initialized[0], &num_retvals, status);
CHECK_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
bool initialized = false;
memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t));
EXPECT_EQ(initialized, true);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(is_initialized[0]);
TFE_DeleteOp(op);
delete status;
}
void TestFunctionWithPackedInput(const bool remote) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
// Create one variable per task.
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
// Add a sync point in order to make sure that variables have been initialized
// before the function execution starts.
// TODO(b/155789951): Remove once b/155789951 is fixed.
VarIsInitialized(ctx, h1);
VarIsInitialized(ctx, h2);
// Pack 3 variable handles into one TFE_TensorHandle.
int num_replicas = 3;
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
TFE_TensorHandle* packed_handle =
TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);
const string composite_device_name =
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
composite_device_name);
EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
composite_device_name);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// Register and run a function which returns the sum of 3 variables.
const string function_def = AddVariablesFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, packed_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
if (remote) {
TFE_OpSetDevice(func, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TFE_DeleteOp(func);
TFE_DeleteTensorHandle(packed_handle);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
float sum = 0;
EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(sum, 6.0);
TFE_DeleteTensorHandle(h0);
TFE_DeleteTensorHandle(h1);
TFE_DeleteTensorHandle(h2);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}
TEST(CAPI, TestLocalFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/false);
}
TEST(CAPI, TestRemoteFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/true);
}
string VariableAddFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'VariableAddFunction'"
" input_arg {"
" name: 'var0'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'var0_value'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read0:value:0'"
" device: '/job:localhost/task:1/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'identity'"
" op: 'Identity'"
" input: 'add:z:0'"
" device: '/job:localhost/task:0/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'var0_value'"
" value: 'identity:output:0'"
" }",
&def));
return def.SerializeAsString();
}
class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
public:
FunctionErrorInjectionPass(string error_node, string error_device)
: error_node_(error_node), error_device_(error_device) {}
tensorflow::Status Run(const tensorflow::DeviceSet& device_set,
const tensorflow::ConfigProto& config_proto,
std::unique_ptr<tensorflow::Graph>* graph,
tensorflow::FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) override {
// Inject failure to function instantiation if finding a node that contains
// the given node name (error_node_) and requested device (error_device_).
for (const auto node : graph->get()->nodes()) {
if (node->name().find(error_node_) != string::npos &&
node->requested_device() == error_device_) {
return tensorflow::errors::Internal("Injected graph pass error.");
}
}
return tensorflow::Status::OK();
}
private:
const string error_node_;
const string error_device_;
};
void TestDistributedFunctionCancellation(bool inject_error) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
if (inject_error) {
// Inject a function optimization pass failure when it sees the 'read0' op
// having a requested device `dev2_name`. During execution:
// * task:0 processes the main function `VariableAddFunction` and places
// the read0 op on task:2
// * task:0 partitions the main function with a subgraph containing read0
// sent to task:2
// * task:2 graph pass reports an error when it sees read0 with dev2_name
tensorflow::function_optimization_registration::
FunctionOptimizationPassRegistration register_test_pass(
std::make_unique<FunctionErrorInjectionPass>("read0", dev2_name));
}
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle, nullptr);
const string function_def = VariableAddFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, var_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);
if (inject_error) {
ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
} else {
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
float sum = 0;
ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
ASSERT_EQ(sum, 4.0);
}
TFE_DeleteOp(func);
TFE_DeleteTensorHandle(var_handle);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}
TEST(CAPI, DistributedFunctionNoError) {
TestDistributedFunctionCancellation(false);
}
TEST(CAPI, DistributedFunctionCancelledOnError) {
TestDistributedFunctionCancellation(true);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Use large matrices so that RPCs don't return before we get a chance
// to call TFE_DeleteContext.
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
auto* h0_task1 =
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* h1_task1 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
TFE_OpSetDevice(matmul, remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h0_task1);
TFE_DeleteTensorHandle(h1_task1);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteOp(matmul);
TFE_DeleteContext(ctx);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
TestRemoteExecuteDeleteContextWithOutstandingRPC(false);
}
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
}
} // namespace

View File

@ -19,37 +19,22 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace {
using ::tensorflow::string;
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost:", port)});
}
return server_def;
}
tensorflow::ServerDef GetServerDef(int num_tasks) {
return GetServerDef("localhost", num_tasks);
}
void TestRemoteExecute(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
@ -351,294 +336,4 @@ TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
/*heavy_load_on_streaming_rpc=*/true);
}
// Add the values of three variables on three different tasks.
string AddVariablesFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'AddVariablesFunction'"
" input_arg {"
" name: 'var'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'sum'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:0/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'read1'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:1/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'read2'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:2/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add1'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read1:value:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add2'"
" op: 'Add'"
" input: 'add1:z:0'"
" input: 'read2:value:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'sum'"
" value: 'add2:z:0'"
" }",
&def));
return def.SerializeAsString();
}
void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(op, var_handle, status);
TFE_TensorHandle* is_initialized[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(op, &is_initialized[0], &num_retvals, status);
CHECK_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
bool initialized = false;
memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t));
EXPECT_EQ(initialized, true);
delete status;
}
void TestFunctionWithPackedInput(const bool remote) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
// Create one variable per task.
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
// Add a sync point in order to make sure that variables have been initialized
// before the function execution starts.
// TODO(b/155789951): Remove once b/155789951 is fixed.
VarIsInitialized(ctx, h1);
VarIsInitialized(ctx, h2);
// Pack 3 variable handles into one TFE_TensorHandle.
int num_replicas = 3;
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
TFE_TensorHandle* packed_handle =
TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);
const string composite_device_name =
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
composite_device_name);
EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
composite_device_name);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// Register and run a function which returns the sum of 3 variables.
const string function_def = AddVariablesFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, packed_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
if (remote) {
TFE_OpSetDevice(func, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TFE_DeleteOp(func);
TFE_DeleteTensorHandle(packed_handle);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
float sum = 0;
EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(sum, 6.0);
TFE_DeleteTensorHandle(h0);
TFE_DeleteTensorHandle(h1);
TFE_DeleteTensorHandle(h2);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}
TEST(CAPI, TestLocalFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/false);
}
TEST(CAPI, TestRemoteFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/true);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Use large matrices so that RPCs don't return before we get a chance
// to call TFE_DeleteContext.
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
auto* h0_task1 =
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* h1_task1 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
TFE_OpSetDevice(matmul, remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h0_task1);
TFE_DeleteTensorHandle(h1_task1);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteOp(matmul);
TFE_DeleteContext(ctx);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
TestRemoteExecuteDeleteContextWithOutstandingRPC(false);
}
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
}
} // namespace

View File

@ -18,7 +18,9 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
using tensorflow::string;
@ -296,3 +298,23 @@ bool GetDeviceName(TFE_Context* ctx, string* device_name,
TF_DeleteDeviceList(devices);
return false;
}
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost:", port)});
}
return server_def;
}
tensorflow::ServerDef GetServerDef(int num_tasks) {
return GetServerDef("localhost", num_tasks);
}

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
// Return a tensor handle containing a float scalar
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value);
@ -72,4 +73,11 @@ TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
bool GetDeviceName(TFE_Context* ctx, tensorflow::string* device_name,
const char* device_type);
// Create a ServerDef with the given `job_name` and add `num_tasks` tasks in it.
tensorflow::ServerDef GetServerDef(const tensorflow::string& job_name,
int num_tasks);
// Create a ServerDef with job name "localhost" and add `num_tasks` tasks in it.
tensorflow::ServerDef GetServerDef(int num_tasks);
#endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_

View File

@ -17,6 +17,8 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
@ -26,6 +28,51 @@ using tensorflow::string;
using tensorflow::internal::OutputList;
using tensorflow::internal::unwrap;
namespace tensorflow {
namespace internal {
typedef absl::flat_hash_map<std::string, FactoryFunction> FactoriesMap;
static FactoriesMap& GetFactories() {
static FactoriesMap* factories = new FactoriesMap;
return *factories;
}
static const char* default_factory = "<unset>";
void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
assert((!GetFactories().count(name)) ||
(GetFactories()[name] == factory) &&
"Duplicate tracing factory registration");
GetFactories()[name] = factory;
}
void SetDefaultTracingEngine(const char* name) { default_factory = name; }
static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
TF_Status* s) {
auto entry = GetFactories().find(default_factory);
if (entry != GetFactories().end()) return entry->second(fn_name, s);
string msg = absl::StrCat(
"No tracing engine factory has been registered with the key '",
default_factory, "' (available: ");
// Ensure deterministic (sorted) order in the error message
std::set<string> factories_sorted;
for (const auto& factory : GetFactories())
factories_sorted.insert(factory.first);
const char* comma = "";
for (const string& factory : factories_sorted) {
msg += comma + factory;
comma = ", ";
}
msg += ")";
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
} // end namespace internal
} // end namespace tensorflow
// =============================================================================
// Public C API entry points
//
@ -36,6 +83,28 @@ using tensorflow::internal::unwrap;
//
// =============================================================================
void TF_SetTracingImplementation(const char* name) {
tensorflow::internal::SetDefaultTracingEngine(name);
}
// Creates a new TensorFlow function, it is an execution context attached to a
// given tracing context.
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) {
return wrap(tensorflow::internal::CreateTracingExecutionContext(fn_name, s));
}
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
TF_OutputList* outputs, TF_Status* s) {
auto* func = wrap(unwrap(ctx)->Finalize(unwrap(outputs), s));
TF_DeleteExecutionContext(ctx);
return func;
}
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
TF_DataType dtype, TF_Status* s) {
return wrap(unwrap(func)->AddParameter(dtype, s));
}
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
@ -58,6 +127,10 @@ int TF_OutputListNumOutputs(TF_OutputList* o) {
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
return wrap(unwrap(o)->outputs[i]);
}
void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
TF_Status* s) {
unwrap(o)->outputs.push_back(unwrap(tensor));
}
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
TF_Status* s) {

View File

@ -49,15 +49,26 @@ typedef struct TF_AbstractOp TF_AbstractOp;
// setting functional attributes of other composite ops e.g. control flow.
typedef struct TF_AbstractFunction TF_AbstractFunction;
// Creates a context for tracing the execution of operations into a function.
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s);
// This allows the client to swap the implementation of the tracing engine.
// Any future call to TF_CreateFunction will use the implementation defined
// here.
void TF_SetTracingImplementation(const char* name);
// Creates a new TensorFlow function. A Function is an execution context, and as
// such it can trace operations through TF_ExecuteOperation. After completing
// tracing, a function can be obtained by TF_FinalizeFunction.
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* status);
// Creates a context for eager execution of operations.
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
TF_Status* s);
void TF_DeleteExecutionContext(TF_ExecutionContext*);
// Add a new parameter to a TensorFlow Function.
// TODO(aminim): what about shape?
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
TF_DataType dtype, TF_Status* s);
// Create an operation suitable to use with the provided context. The operation
// requires its type (e.g. "AddV2") to be set independently.
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
@ -77,19 +88,21 @@ void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
// TF_OutputList holds the list of TF_AbstractTensor that results from executing
// an operation.
// It just lets us not specify the number of outputs of an operation
// beforehand. This forces a memory allocation in the runtime, which is bad, but
// it allows for generic code.
// TODO(aminim): the description above isn't clear with respect to
// TF_OutputListNumOutputs and the current eager implementation which requires
// the number of outputs to be set by the client.
// an operation, or provided to create a function.
// When executing an operation in an eager context, the expected number of
// outputs must be set beforehand with `TF_OutputListSetNumOutputs`.
typedef struct TF_OutputList TF_OutputList;
TF_OutputList* TF_NewOutputList();
void TF_DeleteOutputList(TF_OutputList* o);
void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
// Prepare tracing to the expected number of output for an operation.
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, TF_Status*);
// Return the number of outputs in the list.
int TF_OutputListNumOutputs(TF_OutputList* o);
// Return the `i`th output in the list.
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
// Append a tensor at the end of the output list, growing its size by one.
void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
TF_Status*);
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
// capture some inputs and then add a node in the graph. The output tensors are
@ -100,13 +113,12 @@ void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_ExecutionContext* ctx, TF_Status* s);
// Creates a new TF_AbstractFunction from the current tracing states in the
// context. The returned TF_GraphToFunction must be deleted by the client.
// context. The provided `ctx` is consumed by this API call and deleted.
// The returned TF_AbstractFunction must be deleted by the client,
// TODO(aminim): clarify the contract on the state of the context after this
// call.
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status);
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
TF_OutputList*, TF_Status*);
void TF_DeleteAbstractFunction(TF_AbstractFunction*);

View File

@ -123,6 +123,17 @@ class EagerContext : public ExecutionContext {
}
}
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Can't add function parameter on an eager context.");
return nullptr;
}
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Can't use finalize function on an eager context.");
return nullptr;
}
void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
auto* func = afunc->GetTfFunction(s);
if (!func) {

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include <vector>
#include "absl/strings/str_cat.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
@ -114,12 +115,14 @@ struct GraphFunction : public AbstractFunction {
static constexpr AbstractFunctionKind kKind = kGraphFunc;
};
// GraphContext wraps a TF_Graph and manages the "execution" of operation, i.e.
// adding them to the graph.
// GraphContext wraps a TF_Graph modeling a single function and manages the
// "execution" of operation, i.e. adding them to the function.
class GraphContext : public ExecutionContext {
public:
GraphContext()
: ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
explicit GraphContext(const char* name)
: ExecutionContext(kKind),
graph_(new TF_Graph(), TF_DeleteGraph),
name_(name) {}
AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context.
@ -136,6 +139,10 @@ class GraphContext : public ExecutionContext {
return;
}
auto* tf_opdesc = graph_op->op_.release();
if (tf_opdesc == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT, "AbstractOp is incomplete.");
return;
}
for (int i = 0; i < num_inputs; ++i) {
auto* graph_tensor = dyncast<GraphTensor>(inputs[i]);
if (!graph_tensor) {
@ -164,24 +171,38 @@ class GraphContext : public ExecutionContext {
}
}
TF_Function* ToFunction(const char* fn_name, int num_inputs,
const GraphTensor* inputs, int num_outputs,
const GraphTensor* outputs, TF_Status* status) const {
std::vector<TF_Output> graph_inputs;
graph_inputs.resize(num_inputs);
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
TF_OperationDescription* opdesc =
TF_NewOperation(graph_.get(), "Placeholder",
absl::StrCat("_input_", inputs_.size()).c_str());
TF_SetAttrType(opdesc, "dtype", dtype);
auto* operation = TF_FinishOperation(opdesc, s);
if (!s->status.ok()) return nullptr;
inputs_.push_back(TF_Output{operation, 0});
return new GraphTensor(inputs_.back(), this);
}
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
std::unique_ptr<GraphFunction> func(new GraphFunction);
std::vector<TF_Output> graph_outputs;
graph_outputs.resize(num_outputs);
for (int i = 0; i < num_inputs; i++) {
graph_inputs[i] = inputs[i].output;
}
for (int i = 0; i < num_outputs; i++) {
graph_outputs[i] = outputs[i].output;
graph_outputs.reserve(outputs->outputs.size());
for (AbstractTensor* abstract_output : outputs->outputs) {
GraphTensor* output = dyncast<GraphTensor>(abstract_output);
if (!output) {
TF_SetStatus(s, TF_UNIMPLEMENTED,
"Returning a non-graph tensor from a function has not "
"been implemented yet.");
return nullptr;
}
graph_outputs.push_back(output->output);
}
return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr,
graph_inputs.size(), graph_inputs.data(),
graph_outputs.size(), graph_outputs.data(),
nullptr, nullptr, fn_name, status);
func->func = TF_GraphToFunction(
graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
if (TF_GetCode(s) != TF_OK) return nullptr;
return func.release();
}
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
@ -195,54 +216,20 @@ class GraphContext : public ExecutionContext {
private:
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
std::vector<TF_Output> inputs_;
const char* name_;
};
// Helper that converts the graph currently held in the context into a function.
static AbstractFunction* ExecutionContextToFunction(
const ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const AbstractTensor* inputs, int num_outputs,
const AbstractTensor* outputs, TF_Status* status) {
auto* graph_ctx = dyncast<const GraphContext>(fn_body);
if (graph_ctx == nullptr) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"fn_body is not a TF_GraphContext.");
return nullptr;
}
auto* graph_inputs = dyncast<const GraphTensor>(inputs);
if (!graph_inputs) {
TF_SetStatus(status, TF_INVALID_ARGUMENT, "inputs aren't GraphTensors.");
return nullptr;
}
auto* graph_outputs = dyncast<const GraphTensor>(outputs);
if (!graph_outputs) {
TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors.");
return nullptr;
}
GraphFunction* func = new GraphFunction;
func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs,
num_outputs, graph_outputs, status);
return func;
static ExecutionContext* GraphTracingFactory(const char* name, TF_Status* s) {
return new GraphContext(name);
}
// Register the tracing implemented in this file as the default tracing engine.
static bool register_tracing = [] {
RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
SetDefaultTracingEngine("graphdef");
return true;
}();
} // namespace internal
} // namespace tensorflow
// =============================================================================
// Public C API entry points
// These are only the entry points specific to the Graph API.
// =============================================================================
using tensorflow::internal::unwrap;
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) {
return wrap(new tensorflow::internal::GraphContext());
}
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status) {
return wrap(ExecutionContextToFunction(unwrap(fn_body), fn_name, num_inputs,
unwrap(inputs), num_outputs,
unwrap(outputs), status));
}

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace internal {
@ -57,7 +58,7 @@ T* dyncast(S source) {
// GraphContext and vice-versa).
class AbstractTensor {
protected:
enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor };
enum AbstractTensorKind { kMlirTensor, kGraphTensor, kEagerTensor };
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
public:
@ -100,7 +101,7 @@ class AbstractFunction {
// on a given context, with the same or different input tensors.
class AbstractOp {
protected:
enum AbstractOpKind { kGraphOp, kEagerOp };
enum AbstractOpKind { kMlirOp, kGraphOp, kEagerOp };
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
public:
@ -128,7 +129,7 @@ class AbstractOp {
// eager implementation or to a graph implementation.
struct ExecutionContext {
protected:
enum ExecutionContextKind { kGraphContext, kEagerContext };
enum ExecutionContextKind { kMlirContext, kGraphContext, kEagerContext };
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
public:
@ -148,6 +149,17 @@ struct ExecutionContext {
// Creates an empty AbstractOperation suitable to use with this context.
virtual AbstractOp* CreateOperation() = 0;
// Add a function parameter and return the corresponding tensor.
// This is only valid with an ExecutionContext obtained from a TracingContext,
// it'll always error out with an eager context.
virtual AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) = 0;
// Finalize this context and make a function out of it. The context is in a
// invalid state after this call and must be destroyed.
// This is only valid with an ExecutionContext obtained from a TracingContext,
// it'll always error out with an eager context.
virtual AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) = 0;
// Registers a functions with this context, after this the function is
// available to be called/referenced by its name in this context.
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
@ -156,6 +168,11 @@ struct ExecutionContext {
const ExecutionContextKind k;
};
typedef ExecutionContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
void SetDefaultTracingEngine(const char* name);
void RegisterTracingEngineFactory(const ::tensorflow::string& name,
FactoryFunction factory);
// Create utilities to wrap/unwrap: this convert from the C opaque types to the
// C++ implementation, and back.
#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \

View File

@ -29,7 +29,12 @@ using tensorflow::string;
namespace tensorflow {
namespace {
TEST(UnifiedCAPI, TestBasicEager) {
class UnifiedCAPI : public ::testing::TestWithParam<const char*> {
protected:
void SetUp() override { TF_SetTracingImplementation(GetParam()); }
};
TEST_P(UnifiedCAPI, TestBasicEager) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -81,33 +86,18 @@ TEST(UnifiedCAPI, TestBasicEager) {
TF_DeleteExecutionContext(ctx);
}
TEST(UnifiedCAPI, TestBasicGraph) {
TEST_P(UnifiedCAPI, TestBasicGraph) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
// Start a new function / execution context.
string fn_name = "double";
TF_ExecutionContext* graph_ctx =
TF_CreateFunction(fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
auto* placeholder_t =
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_OutputList* placeholder_outputs = TF_NewOutputList();
// Execute.
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
// Delete placeholder op.
TF_DeleteAbstractOp(placeholder_op);
// Build an abstract operation.
auto* add_op = TF_NewAbstractOp(graph_ctx);
@ -123,17 +113,13 @@ TEST(UnifiedCAPI, TestBasicGraph) {
// Execute.
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractTensor* output_t = TF_OutputListGet(add_outputs, 0);
// Clean up operation and inputs.
TF_DeleteAbstractOp(add_op);
string fn_name = "double";
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
graph_ctx, fn_name.c_str(), 1, placeholder_t, 1, output_t, status.get());
TF_AbstractFunction* func =
TF_FinalizeFunction(graph_ctx, add_outputs, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_DeleteAbstractTensor(placeholder_t);
TF_DeleteAbstractTensor(output_t);
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -174,17 +160,160 @@ TEST(UnifiedCAPI, TestBasicGraph) {
ASSERT_EQ(*f_value, 4.0);
TF_DeleteOutputList(add_outputs);
TF_DeleteOutputList(placeholder_outputs);
TF_DeleteAbstractOp(fn_op);
TF_DeleteAbstractTensor(input_t);
TF_DeleteAbstractTensor(final_result);
TF_DeleteTensor(f_t);
TF_DeleteAbstractFunction(func);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContext(eager_execution_ctx);
}
TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Status* s = status.get();
// Start a new function / execution context.
string fn_name = "two_adds";
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Create a first "Add" computing `arg0 + arg1`.
TF_AbstractTensor* add_output1;
{
// Build an abstract operation, inputs and output.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(add_op, "my_add1", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg0, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
// Trace the operation now (create a node in the graph).
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(add_op);
// Extract the resulting tensor.
add_output1 = TF_OutputListGet(add_outputs, 0);
TF_DeleteOutputList(add_outputs);
}
// Same with a second "Add" computing `arg1 + arg1`.
TF_AbstractTensor* add_output2;
{
// Build an abstract operation, inputs and output.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(add_op, "my_add2", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg1, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
// Trace the operation now (create a node in the graph).
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(add_op);
// Extract the resulting tensor.
add_output2 = TF_OutputListGet(add_outputs, 0);
TF_DeleteOutputList(add_outputs);
}
// Finalize the function by providing the returned values.
TF_AbstractFunction* func;
{
// We want to return the output of both add operations, create a new list
// and populate it.
TF_OutputList* func_outputs = TF_NewOutputList();
TF_OutputListPushBack(func_outputs, add_output1, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_OutputListPushBack(func_outputs, add_output2, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
func = TF_FinalizeFunction(graph_ctx, func_outputs, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteOutputList(func_outputs);
}
/**
* We traced so far this function:
*
* def two_adds(a, b):
* my_add1 = a + b
* my_add2 = b + b
* return my_add1, my_add2
*
* Now we will execute this function with an eager context:
*
* output1, output2 = two_adds(2.0, 3.0)
*
* and check that we got 5.0 and 6.0 as results.
*/
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContext* eager_execution_ctx =
TF_NewEagerExecutionContext(opts, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TFE_DeleteContextOptions(opts);
TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Build the abstract op to run the function.
TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Build two abstract input tensors as function arguments.
std::vector<TF_AbstractTensor*> func_args;
{
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(eager_execution_ctx);
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
input_eager = TestScalarTensorHandle(eager_ctx, 3.0f);
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
}
TF_OutputList* func_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(func_outputs, 2, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs,
eager_execution_ctx, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(fn_op);
for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t);
ASSERT_EQ(2, TF_OutputListNumOutputs(func_outputs));
float results[2];
for (int idx = 0; idx < 2; ++idx) {
TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
TFE_TensorHandle* handle = TF_AbstractTensorGetEagerTensor(result, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Tensor* f_t = TFE_TensorHandleResolve(handle, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
results[idx] = *static_cast<float*>(TF_TensorData(f_t));
TF_DeleteTensor(f_t);
}
ASSERT_EQ(results[0], 5.0);
ASSERT_EQ(results[1], 6.0);
for (int idx = 0; idx < 2; ++idx) {
TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
TF_DeleteAbstractTensor(result);
}
TF_DeleteOutputList(func_outputs);
TF_DeleteExecutionContext(eager_execution_ctx);
TF_DeleteAbstractFunction(func);
}
TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -193,18 +322,15 @@ TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
ctx, nullptr, 0, nullptr, 0, nullptr, status.get());
TF_AbstractFunction* func = TF_FinalizeFunction(ctx, nullptr, status.get());
ASSERT_EQ(nullptr, func);
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
TF_DeleteExecutionContext(ctx);
}
TEST(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
@ -222,10 +348,10 @@ TEST(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
TF_DeleteExecutionContext(graph_ctx);
}
TEST(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
@ -243,7 +369,7 @@ TEST(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
TF_DeleteExecutionContext(graph_ctx);
}
TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
TEST_P(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
// Build an Eager context.
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -273,7 +399,8 @@ TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build a Graph context.
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute eager op using graph context.
@ -289,10 +416,11 @@ TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
TF_DeleteExecutionContext(graph_ctx);
}
TEST(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
@ -349,5 +477,8 @@ TEST(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
TF_DeleteExecutionContext(eager_execution_ctx);
}
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
::testing::Values("graphdef", "mlir"));
} // namespace
} // namespace tensorflow

View File

@ -21,8 +21,8 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
@ -84,11 +84,10 @@ class AbstractContextInterface {
// Create an operation to perform op execution
virtual AbstractOperationInterface* CreateOperation() = 0;
// Load a SavedModelAPI object from the given directory and tags
virtual std::unique_ptr<SavedModelAPI> LoadSavedModelAPI(
const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags,
tensorflow::Status* status) = 0;
// Returns whether the runtime is backed by TFRT or the legacy TF Eager
// Runtime. This is necessary to decouple runtime-dependent
// code that is layered on top of the runtime.
virtual bool UsesTFRT() = 0;
// List attributes of available devices
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
@ -101,6 +100,17 @@ class AbstractContextInterface {
// Destroy the step resource container for a training step.
virtual void EndStep() = 0;
// Block until all pending nodes are finished.
virtual Status AsyncWait() = 0;
// Add a function (serialized FunctionDef protocol buffer) so that it can
// be executed as an op. Return error if the function with the same name
// already exists.
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
// Remove a function. 'func' argument is the name of a previously added
// FunctionDef. The name is in fdef.signature.name.
virtual Status RemoveFunction(const string& func) = 0;
protected:
virtual ~AbstractContextInterface() {}
};

View File

@ -12,39 +12,98 @@ package(
# need a second rule that omits .cc files, in
# tensorflow/python:_pywrap_parallel_device.
filegroup(
name = "headers",
name = "lib_headers",
srcs = ["parallel_device_lib.h"],
)
filegroup(
name = "lib_sources",
srcs = ["parallel_device_lib.cc"],
)
filegroup(
name = "device_headers",
srcs = ["parallel_device.h"],
)
filegroup(
name = "device_sources",
srcs = ["parallel_device.cc"],
)
filegroup(
name = "headers",
srcs = [
":device_headers",
":lib_headers",
],
visibility = ["//tensorflow/python:__pkg__"],
)
filegroup(
name = "sources",
srcs = ["parallel_device.cc"],
srcs = [
":device_sources",
":lib_sources",
],
visibility = ["//tensorflow/python:__pkg__"],
)
cc_library(
name = "parallel_device",
srcs = [":sources"],
hdrs = [":headers"],
srcs = [":device_sources"],
hdrs = [":device_headers"],
visibility = ["//tensorflow:internal"],
deps = [
":parallel_device_lib",
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:variant",
],
)
cc_library(
name = "parallel_device_lib",
srcs = [":lib_sources"],
hdrs = [":lib_headers"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:variant",
],
)
cc_library(
name = "parallel_device_testlib",
testonly = 1,
srcs = ["parallel_device_testlib.cc"],
hdrs = ["parallel_device_testlib.h"],
deps = [
":parallel_device",
":parallel_device_ops",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "parallel_device_test",
srcs = ["parallel_device_test.cc"],
deps = [
":parallel_device",
":parallel_device_ops",
":parallel_device_testlib",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
@ -55,6 +114,27 @@ tf_cc_test(
],
)
tf_cc_test(
name = "parallel_device_remote_test",
srcs = ["parallel_device_remote_test.cc"],
# TODO(b/136478427): Enable global heap checking when servers shut down
# cleanly.
args = ["--heap_check=local"],
deps = [
":parallel_device",
":parallel_device_ops",
":parallel_device_testlib",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
],
)
# Note: ParallelDevice-specific ops are experimental and not currently linked in
# to TensorFlow by default, just used in a few tests.
filegroup(

View File

@ -23,25 +23,13 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow {
namespace eager {
namespace parallel_device {
namespace {
// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) const {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
class OpDeleter {
public:
void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
@ -49,224 +37,46 @@ class OpDeleter {
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
class ExecutorDeleter {
public:
void operator()(TFE_Executor* to_delete) const {
TFE_DeleteExecutor(to_delete);
}
};
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
class ParallelTensor;
using MaybeParallelTensorOwned =
absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
using MaybeParallelTensorUnowned =
absl::variant<ParallelTensor*, TFE_TensorHandle*>;
// Creates a vector of `count` new executors (threads).
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
std::vector<ExecutorPtr> executors;
executors.reserve(count);
for (int i = 0; i < count; ++i) {
executors.emplace_back(TFE_NewExecutor(true /* is_async */));
}
return executors;
}
// A representation of the custom device passed in and out of the TFE custom
// device APIs, providing context about the parallel device to
// ParallelDeviceExecute.
class ParallelDevice {
// A ParallelDevice on its own is not registered with a TFE_Context, and so has
// no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
// name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
// placed on the parallel device.
class NamedParallelDevice {
public:
ParallelDevice(const std::string& name,
const std::vector<std::string>& devices);
// Helper to copy a tensor handle from another device once for each component
// of the ParallelDevice.
//
// Sets a bad status and returns a nullptr if `tensor` is already on the
// ParallelDevice, or if the individual copies fail.
std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status) const;
// A parallel tensor with scalar integers numbering component devices.
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
TF_Status* status) const;
// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors (or
// implicitly-mirrored tensors on other devices). Wraps the resulting
// per-device and per-output TFE_TensorHandles into one ParallelTensor per
// output of the original operation.
//
// `inputs` are either ParallelTensors, i.e. already on the ParallelDevice, or
// un-replicated TFE_TensorHandles on other devices. TPUReplicatedInput
// requires non-parallel tensors, and TPUReplicatedOutput requires a parallel
// tensor, but other operations will implicitly broadcast non-parallel input
// tensors across the ParallelDevice's component devices.
//
// Two special-cased operations, TPUReplicatedInput and TPUReplicatedOutput,
// pack and un-pack parallel tensors respectively. Only TPUReplicatedOutput
// causes `Execute` to return non-parallel tensors.
//
// Attributes are forwarded to executed operations unmodified.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK.
absl::optional<std::vector<MaybeParallelTensorOwned>> Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
// Implements the parallel case for `Execute`, where all of the outputs of the
// operation are ParallelTensors, and all inputs are either ParallelTensors or
// should be implicitly broadcast. This means the operation is not
// TPUReplicatedInput or TPUReplicatedOutput.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
// if sanity checks on dtypes/metadata fail.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ExecuteParallelOperation(TFE_Context* context,
std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name,
const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
const std::string& device_name() const { return device_name_; }
NamedParallelDevice(const std::string& name,
std::unique_ptr<ParallelDevice> parallel_device)
: device_name_(name), parallel_device_(std::move(parallel_device)) {}
const std::string& name() const { return device_name_; }
const ParallelDevice& device() const { return *parallel_device_; }
private:
// The name of the parallel device
// (e.g. "/job:localhost/replica:0/task:0/device:CUSTOM:0")
const std::string device_name_;
// A sequence of device names, indicating which devices replicated operations
// are forwarded to.
const std::vector<std::string> underlying_devices_;
// A sequence of TFE_Executors, one per device, for executing operations in
// parallel.
const std::vector<ExecutorPtr> executors_;
std::string device_name_;
std::unique_ptr<ParallelDevice> parallel_device_;
};
// The internal representation of a TFE_TensorHandle placed on a
// ParallelDevice. Contains a tuple of tensors, one on each of the
// `underlying_devices_` of the ParallelDevice.
class ParallelTensor {
public:
// Construct a ParallelTensor from TensorHandles placed on the component
// devices of a ParallelDevice.
static std::unique_ptr<ParallelTensor> FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status);
// Helper to wrap a ParallelTensor into a TFE_TensorHandle which contains it.
static TensorHandlePtr AsTensorHandle(TFE_Context* context,
std::unique_ptr<ParallelTensor> t,
TF_Status* status);
size_t num_tensors() const { return tensors_.size(); }
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
private:
ParallelTensor(const ParallelDevice& device,
std::vector<TensorHandlePtr> tensors,
std::vector<int64_t> shape, const TF_DataType dtype)
: device_(device),
tensors_(std::move(tensors)),
shape_(std::move(shape)),
dtype_(dtype) {}
const ParallelDevice& device_;
const std::vector<TensorHandlePtr> tensors_;
const std::vector<int64_t> shape_;
const TF_DataType dtype_;
};
ParallelDevice::ParallelDevice(const std::string& name,
const std::vector<std::string>& devices)
: device_name_(name),
underlying_devices_(devices),
executors_(MakeExecutors(underlying_devices_.size())) {}
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
const char* current_device = TFE_TensorHandleDeviceName(tensor, status);
if (device_name_ == current_device) {
std::string message(absl::StrCat(
"Tried to copy a TensorHandle to its existing device: ", device_name_));
TF_SetStatus(status, TF_INTERNAL, message.c_str());
return nullptr;
}
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (const std::string& underlying_device_name : underlying_devices_) {
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, underlying_device_name.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(t);
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
TFE_Context* context, TF_Status* status) const {
// TODO(allenl): We could cache DeviceIDs (keyed by context).
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
int64_t* device_id = new int64_t;
*device_id = device_index;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int64_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int64_t*>(data);
},
nullptr),
TF_DeleteTensor);
// TODO(allenl): Here and when executing regular operations, we could hold
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
// device names repeatedly.
OpPtr const_op(TFE_NewOp(context, "Const", status));
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(device_handle);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const {
absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
const ParallelDevice& parallel_device,
const std::string& parallel_device_name, TFE_Context* context,
std::vector<MaybeParallelTensorUnowned> inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int expected_max_outputs,
TF_Status* status) {
absl::optional<std::vector<MaybeParallelTensorOwned>> result;
// TODO(allenl): We should remove "TPU" from these op names at the very least,
// or consider other ways of packing/unpacking parallel tensors.
if (operation_name == std::string("TPUReplicatedInput")) {
// Special-cased operation for packing per-device tensors into one parallel
// tensor.
if (inputs.size() != underlying_devices_.size()) {
if (inputs.size() != parallel_device.num_underlying_devices()) {
std::string message(absl::StrCat(
"The parallel device ", device_name_, " expected ",
underlying_devices_.size(), " inputs to TPUReplicatedInput, but got ",
inputs.size()));
"The parallel device ", parallel_device_name, " expected ",
parallel_device.num_underlying_devices(),
" inputs to TPUReplicatedInput, but got ", inputs.size()));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
}
@ -289,7 +99,7 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
result_content.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
parallel_device, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
@ -300,10 +110,10 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
TFE_OpAddAttrs(op.get(), attributes);
int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
if (TF_GetCode(status) != TF_OK) return result;
if (expected_outputs != underlying_devices_.size()) {
if (expected_outputs != parallel_device.num_underlying_devices()) {
std::string message(absl::StrCat(
"The parallel device ", device_name_, " expected ",
underlying_devices_.size(),
"The parallel device ", parallel_device_name, " expected ",
parallel_device.num_underlying_devices(),
" outputs for TPUReplicatedOutput, but got ", expected_outputs));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
@ -329,15 +139,38 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
} else if (operation_name == std::string("DeviceID")) {
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
result_content.push_back(DeviceIDs(context, status));
result_content.push_back(parallel_device.DeviceIDs(context, status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
}
std::vector<ParallelTensor*> parallel_inputs;
std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
parallel_inputs.reserve(inputs.size());
implicitly_broadcast_tensors.reserve(inputs.size()); // not tight
for (const auto& input : inputs) {
if (absl::holds_alternative<TFE_TensorHandle*>(input)) {
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
// to each parallel operation.
//
// TODO(allenl): There may be smarter ways to do this copy in some
// cases, i.e. with a collective broadcast. We'll need to be careful
// about things that are taken as inputs on the host or on their
// existing device (for multi-device functions).
std::unique_ptr<ParallelTensor> parallel_tensor(
parallel_device.CopyToParallelDevice(
context, absl::get<TFE_TensorHandle*>(input), status));
if (TF_GetCode(status) != TF_OK) return result;
parallel_inputs.push_back(parallel_tensor.get());
implicitly_broadcast_tensors.emplace_back(std::move(parallel_tensor));
} else {
parallel_inputs.push_back(absl::get<ParallelTensor*>(input));
}
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
maybe_parallel_results(
ExecuteParallelOperation(context, std::move(inputs), operation_name,
attributes, expected_max_outputs, status));
parallel_device.Execute(context, parallel_inputs, operation_name,
attributes, expected_max_outputs, status));
if (!maybe_parallel_results.has_value()) return result;
std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
std::move(maybe_parallel_results.value()));
@ -351,144 +184,6 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
return result;
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::ExecuteParallelOperation(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const {
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
// Compute per-device per-output tensors
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
per_device_output_tensors.reserve(underlying_devices_.size());
// TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
// setting the thread-local executor like this.
TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
auto reset_executor = gtl::MakeCleanup([context, previous_executor]() {
TFE_ContextSetExecutorForThread(context, previous_executor);
TFE_DeleteExecutor(previous_executor);
});
int first_op_output_count;
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
TFE_Executor* executor = executors_[device_index].get();
// Note that the `reset_executor` cleanup sets the thread's executor back to
// the value before this function ran.
TFE_ContextSetExecutorForThread(context, executor);
OpPtr op(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return result;
TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
status);
TFE_OpAddAttrs(op.get(), attributes);
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
// to each parallel operation.
//
// TODO(allenl): There may be smarter ways to do this copy in some
// cases, i.e. with a collective broadcast. We'll need to be careful
// about things that are taken as inputs on the host or on their
// existing device (for multi-device functions).
TFE_OpAddInput(op.get(),
absl::get<TFE_TensorHandle*>(inputs[input_index]),
status);
if (TF_GetCode(status) != TF_OK) return result;
} else {
// Parallel tensors are divided between operations by device.
TFE_OpAddInput(op.get(),
absl::get<ParallelTensor*>(inputs[input_index])
->tensor(device_index),
status);
if (TF_GetCode(status) != TF_OK) return result;
}
}
std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
int real_num_outputs = expected_max_outputs;
// For nested devices, the inner device sees the async executor we've
// set. Inner parallel devices will just overwrite this with their own and
// then set it back to ours before returning. This means parallel devices
// which consist of several aliased parallel devices would hypothetically
// deadlock if the outer parallel device ran one collective with a group
// size equal to the total number of aliased physical devices. Currently
// physical devices cannot participate in a single collective reduction
// multiple times, so this would fail earlier.
//
// TODO(allenl): Keep a map from outer executor to list of inner executors
// rather than a single list of executors so aliased nested parallel devices
// don't re-use an executor.
TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
if (device_index == 0) {
first_op_output_count = real_num_outputs;
} else {
if (real_num_outputs != first_op_output_count) {
TF_SetStatus(status, TF_INTERNAL,
"Parallel ops produced different numbers of tensors.");
return result;
}
}
if (TF_GetCode(status) != TF_OK) return result;
std::vector<TensorHandlePtr> this_outputs;
this_outputs.reserve(real_num_outputs);
for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
this_outputs.emplace_back(op_outputs[output_num]);
}
per_device_output_tensors.push_back(std::move(this_outputs));
}
// For each output of the original operation, pack the per-device
// TensorHandles we've computed into a single parallel TensorHandle.
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
per_device_outputs.reserve(first_op_output_count);
for (int i = 0; i < first_op_output_count; ++i) {
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int j = 0; j < underlying_devices_.size(); ++j) {
components.push_back(std::move(per_device_output_tensors[j][i]));
}
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
}
result.emplace(std::move(per_device_outputs));
return result;
}
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status) {
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
std::vector<int64_t> shape(
TFE_TensorHandleNumDims(components[0].get(), status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
// Verify that the TensorHandle's shape and dtype match all of the component
// shapes and dtypes.
for (TensorHandlePtr& component : components) {
for (int i = 0; i < shape.size(); ++i) {
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (tensor_dim != shape[i]) {
// TODO(allenl): Allow shapes to differ.
TF_SetStatus(status, TF_UNIMPLEMENTED,
"Components of a ParallelTensor must currently all have "
"the same shape");
return nullptr;
}
if (TFE_TensorHandleDataType(component.get()) != dtype) {
TF_SetStatus(status, TF_INTERNAL,
"Components of a ParallelTensor must all have "
"the same dtype");
return nullptr;
}
}
}
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
parallel_device, std::move(components), std::move(shape), dtype));
}
// Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how
// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
// reference counts drop to zero.
@ -496,17 +191,18 @@ void ParallelTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<ParallelTensor*>(data);
}
TensorHandlePtr ParallelTensor::AsTensorHandle(
TFE_Context* context, std::unique_ptr<ParallelTensor> t,
TF_Status* status) {
TensorHandlePtr ParallelTensorToTensorHandle(
const std::string& parallel_device_name, TFE_Context* context,
std::unique_ptr<ParallelTensor> t, TF_Status* status) {
// The resulting TensorHandle owns an opaque pointer to "device memory", which
// for a ParallelDevice is really a ParallelTensor. When the TensorHandle is
// deleted, it will call ParallelTensorDeallocator to free the struct.
ParallelTensor* t_released = t.release();
const std::vector<int64_t>& shape(t_released->shape());
return TensorHandlePtr(TFE_NewTensorHandleFromDeviceMemory(
context, t_released->device_.device_name().c_str(), t_released->dtype_,
t_released->shape_.data(), t_released->shape_.size(), t_released, 1,
&ParallelTensorDeallocator, nullptr, status));
context, parallel_device_name.c_str(), t_released->dtype(), shape.data(),
shape.size(), t_released, 1, &ParallelTensorDeallocator, nullptr,
status));
}
// For TFE_CustomDevice::copy_tensor_to_device in the parallel device
@ -522,12 +218,14 @@ TensorHandlePtr ParallelTensor::AsTensorHandle(
TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
NamedParallelDevice* named_device =
reinterpret_cast<NamedParallelDevice*>(device_info);
const ParallelDevice& dev = named_device->device();
std::unique_ptr<ParallelTensor> parallel_tensor(
dev->CopyToParallelDevice(context, tensor, status));
dev.CopyToParallelDevice(context, tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
return ParallelTensor::AsTensorHandle(context, std::move(parallel_tensor),
status)
return ParallelTensorToTensorHandle(named_device->name(), context,
std::move(parallel_tensor), status)
.release();
}
@ -561,14 +259,15 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* status,
void* device_info) {
ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
NamedParallelDevice* named_device =
reinterpret_cast<NamedParallelDevice*>(device_info);
std::vector<MaybeParallelTensorUnowned> typed_inputs;
typed_inputs.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
const char* tensor_handle_device =
TFE_TensorHandleDeviceName(inputs[i], status);
if (TF_GetCode(status) != TF_OK) return;
if (dev->device_name() == tensor_handle_device) {
if (named_device->name() == tensor_handle_device) {
// We assume that any tensors already placed on this device are
// ParallelTensors.
typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
@ -580,8 +279,9 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
}
absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
dev->Execute(context, std::move(typed_inputs), operation_name, attributes,
*num_outputs, status));
ExecuteWithSpecialOps(named_device->device(), named_device->name(),
context, std::move(typed_inputs), operation_name,
attributes, *num_outputs, status));
if (TF_GetCode(status) != TF_OK) return;
if (!maybe_typed_outputs.has_value()) {
TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
@ -602,8 +302,8 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
} else {
outputs[i] = ParallelTensor::AsTensorHandle(
context,
outputs[i] = ParallelTensorToTensorHandle(
named_device->name(), context,
std::move(absl::get<std::unique_ptr<ParallelTensor>>(
typed_output)),
status)
@ -620,7 +320,7 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void DeleteParallelDevice(void* device_info) {
delete reinterpret_cast<ParallelDevice*>(device_info);
delete reinterpret_cast<NamedParallelDevice*>(device_info);
}
} // namespace
@ -639,8 +339,10 @@ void AllocateParallelDevice(const char* device_name,
++device_index) {
underlying_devices_vector.push_back(underlying_devices[device_index]);
}
*device_info = new ParallelDevice(device_name, underlying_devices_vector);
std::unique_ptr<ParallelDevice> parallel_device(
new ParallelDevice(underlying_devices_vector));
*device_info =
new NamedParallelDevice{device_name, std::move(parallel_device)};
}
} // namespace eager
} // namespace parallel_device
} // namespace tensorflow

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow {
namespace eager {
namespace parallel_device {
// Allocate a parallel device named `device_name` which forwards operations to
// `underlying_devices`, maintaining "parallel tensors" with components placed
@ -59,7 +59,7 @@ void AllocateParallelDevice(const char* device_name,
int num_underlying_devices,
TFE_CustomDevice* device, void** device_info);
} // namespace eager
} // namespace parallel_device
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_

View File

@ -0,0 +1,376 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace parallel_device {
namespace {
class OpDeleter {
public:
void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
};
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
class StatusDeleter {
public:
void operator()(TF_Status* to_delete) const { TF_DeleteStatus(to_delete); }
};
using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;
} // namespace
// Allows a single op at a time to be launched without blocking.
//
// DeviceThread itself is thread-safe, in that StartExecute will block if there
// is a pending execution. Since StartExecute is equivalent to grabbing a lock,
// multiple DeviceThreads should always be accessed in the same order to avoid
// deadlocks.
class DeviceThread {
public:
// Starts a background thread waiting for `StartExecute`.
explicit DeviceThread(const std::string& device)
: status_(TF_NewStatus()),
device_(device),
op_(nullptr),
thread_(tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "parallel_device_execute",
std::bind(&DeviceThread::Run, this))) {}
~DeviceThread();
// Requests that the worker thread execute the specified operation. Blocks
// until the previously pending operation (a StartExecute without a Join) has
// finished, if any.
void StartExecute(TFE_Context* context, const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes, int expected_max_outputs);
// Block until the previous `StartExecute` operation has executed. Forwards
// the status from `TFE_Execute` and returns outputs if the status is OK.
std::vector<TensorHandlePtr> Join(TF_Status* status);
private:
void Run();
void Execute(TFE_Context* context, const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes, int expected_max_outputs,
std::vector<TensorHandlePtr>* outputs, TF_Status* status) const
TF_EXCLUSIVE_LOCKS_REQUIRED(execution_mutex_);
enum class ExecutionState {
kReadyToExecute,
kHasResult,
kIdle,
kShuttingDown,
};
tensorflow::mutex execution_mutex_;
ExecutionState execution_state_ TF_GUARDED_BY(execution_mutex_) =
ExecutionState::kIdle;
// Tells the worker thread that there is new work.
tensorflow::condition_variable start_execute_;
// The worker thread notifies that work has finished.
tensorflow::condition_variable finished_execute_;
// Notifies a StartExecute that the previous Join has finished.
tensorflow::condition_variable finished_join_;
// Temporary state between `StartExecute` and `Join`.
// Inputs
TFE_Context* context_ TF_GUARDED_BY(execution_mutex_);
const char* operation_name_ TF_GUARDED_BY(execution_mutex_);
std::vector<TFE_TensorHandle*> op_inputs_ TF_GUARDED_BY(execution_mutex_);
const TFE_OpAttrs* attributes_ TF_GUARDED_BY(execution_mutex_);
int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_);
// Outputs
std::vector<TensorHandlePtr> op_outputs_ TF_GUARDED_BY(execution_mutex_);
StatusPtr status_ TF_GUARDED_BY(execution_mutex_);
const std::string device_;
mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
std::unique_ptr<Thread> thread_;
};
DeviceThread::~DeviceThread() {
{
tensorflow::mutex_lock l(execution_mutex_);
execution_state_ = ExecutionState::kShuttingDown;
}
start_execute_.notify_one();
}
void DeviceThread::Run() {
while (true) {
{
tensorflow::mutex_lock l(execution_mutex_);
while (execution_state_ == ExecutionState::kIdle ||
execution_state_ == ExecutionState::kHasResult) {
start_execute_.wait(l);
}
if (execution_state_ == ExecutionState::kShuttingDown) {
return;
} else if (execution_state_ == ExecutionState::kReadyToExecute) {
// op_outputs_ may have been std::moved
op_outputs_ = std::vector<TensorHandlePtr>();
Execute(context_, operation_name_, std::move(op_inputs_), attributes_,
expected_max_outputs_, &op_outputs_, status_.get());
execution_state_ = ExecutionState::kHasResult;
}
}
finished_execute_.notify_one();
}
}
void DeviceThread::StartExecute(TFE_Context* context,
const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes,
int expected_max_outputs) {
{
tensorflow::mutex_lock l(execution_mutex_);
while (execution_state_ != ExecutionState::kIdle) {
// If there's already a pending execution, wait until Join finishes before
// starting on the next operation.
finished_join_.wait(l);
}
context_ = context;
operation_name_ = operation_name;
op_inputs_ = inputs;
attributes_ = attributes;
expected_max_outputs_ = expected_max_outputs;
execution_state_ = ExecutionState::kReadyToExecute;
}
start_execute_.notify_one();
}
std::vector<TensorHandlePtr> DeviceThread::Join(TF_Status* status) {
std::vector<TensorHandlePtr> result;
{
tensorflow::mutex_lock l(execution_mutex_);
while (execution_state_ != ExecutionState::kHasResult) {
finished_execute_.wait(l);
}
if (TF_GetCode(status_.get()) != TF_OK) {
TF_SetStatus(status, TF_GetCode(status_.get()),
TF_Message(status_.get()));
}
execution_state_ = ExecutionState::kIdle;
result = std::move(op_outputs_);
}
finished_join_.notify_one();
return result;
}
void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes,
int expected_max_outputs,
std::vector<TensorHandlePtr>* outputs,
TF_Status* status) const {
if (op_ == nullptr) {
op_.reset(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op_.get(), device_.c_str(), status);
if (TF_GetCode(status) != TF_OK) return;
} else {
TFE_OpReset(op_.get(), operation_name, device_.c_str(), status);
if (TF_GetCode(status) != TF_OK) return;
}
TFE_OpAddAttrs(op_.get(), attributes);
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
TFE_OpAddInput(op_.get(), inputs[input_index], status);
if (TF_GetCode(status) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> unwrapped_results(expected_max_outputs);
int real_num_outputs = expected_max_outputs;
if (TF_GetCode(status) != TF_OK) return;
TFE_Execute(op_.get(), unwrapped_results.data(), &real_num_outputs, status);
if (TF_GetCode(status) != TF_OK) return;
unwrapped_results.resize(real_num_outputs);
outputs->reserve(real_num_outputs);
for (TFE_TensorHandle* unwrapped_result : unwrapped_results) {
outputs->emplace_back(unwrapped_result);
}
}
ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
: underlying_devices_(devices) {
device_threads_.reserve(devices.size());
for (int device_index = 0; device_index < devices.size(); ++device_index) {
device_threads_.emplace_back(
new DeviceThread(devices[device_index].c_str()));
}
}
// Necessary for a unique_ptr to a forward-declared type.
ParallelDevice::~ParallelDevice() = default;
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (const std::string& underlying_device_name : underlying_devices_) {
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, underlying_device_name.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(t);
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
TFE_Context* context, TF_Status* status) const {
// TODO(allenl): We could cache DeviceIDs (keyed by context).
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
int64_t* device_id = new int64_t;
*device_id = device_index;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int64_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int64_t*>(data);
},
nullptr),
TF_DeleteTensor);
// TODO(allenl): Here and when executing regular operations, we could hold
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
// device names repeatedly.
OpPtr const_op(TFE_NewOp(context, "Const", status));
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(device_handle);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(TFE_Context* context,
const std::vector<ParallelTensor*>& inputs,
const char* operation_name,
const TFE_OpAttrs* attributes, int expected_max_outputs,
TF_Status* status) const {
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
// Compute per-device per-output tensors
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
per_device_output_tensors.reserve(underlying_devices_.size());
int first_op_output_count = 0;
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
DeviceThread* device_thread = device_threads_[device_index].get();
std::vector<TFE_TensorHandle*> device_inputs;
device_inputs.reserve(device_inputs.size());
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
// Parallel tensors are divided between operations by device.
device_inputs.push_back(inputs[input_index]->tensor(device_index));
}
device_thread->StartExecute(context, operation_name,
std::move(device_inputs), attributes,
expected_max_outputs);
}
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
DeviceThread* device_thread = device_threads_[device_index].get();
per_device_output_tensors.push_back(device_thread->Join(status));
if (TF_GetCode(status) != TF_OK) return result;
if (device_index == 0) {
first_op_output_count = per_device_output_tensors.rbegin()->size();
} else {
if (per_device_output_tensors.rbegin()->size() != first_op_output_count) {
TF_SetStatus(status, TF_INTERNAL,
"Parallel ops produced different numbers of tensors.");
return result;
}
}
}
// For each output of the original operation, pack the per-device
// TensorHandles we've computed into a single parallel TensorHandle.
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
per_device_outputs.reserve(first_op_output_count);
for (int i = 0; i < first_op_output_count; ++i) {
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int j = 0; j < underlying_devices_.size(); ++j) {
components.push_back(std::move(per_device_output_tensors[j][i]));
}
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
}
result.emplace(std::move(per_device_outputs));
return result;
}
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status) {
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
std::vector<int64_t> shape(
TFE_TensorHandleNumDims(components[0].get(), status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
// Verify that the TensorHandle's shape and dtype match all of the component
// shapes and dtypes.
for (TensorHandlePtr& component : components) {
for (int i = 0; i < shape.size(); ++i) {
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (tensor_dim != shape[i]) {
// TODO(allenl): Allow shapes to differ.
TF_SetStatus(status, TF_UNIMPLEMENTED,
"Components of a ParallelTensor must currently all have "
"the same shape");
return nullptr;
}
if (TFE_TensorHandleDataType(component.get()) != dtype) {
TF_SetStatus(status, TF_INTERNAL,
"Components of a ParallelTensor must all have "
"the same dtype");
return nullptr;
}
}
}
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
parallel_device, std::move(components), std::move(shape), dtype));
}
} // namespace parallel_device
} // namespace tensorflow

View File

@ -0,0 +1,141 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
#include <memory>
#include <string>
#include <vector>
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow {
namespace parallel_device {
// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) const {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
class ParallelTensor;
class DeviceThread;
// Forwards operations to `devices`, maintaining ParallelTensor with components
// placed on each underlying device.
class ParallelDevice {
public:
explicit ParallelDevice(const std::vector<std::string>& devices);
~ParallelDevice();
// Helper to copy a tensor handle from another device once for each component
// of the ParallelDevice.
//
// Sets a bad status and returns a nullptr if `tensor` is already on the
// ParallelDevice, or if the individual copies fail.
std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status) const;
// A parallel tensor with scalar integers numbering component devices.
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
TF_Status* status) const;
// The number of devices operations run on.
size_t num_underlying_devices() const { return underlying_devices_.size(); }
// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors. Wraps the
// resulting per-device and per-output TFE_TensorHandles into one
// ParallelTensor per output of the original operation.
//
// Attributes are forwarded to executed operations unmodified.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
// if sanity checks on dtypes/metadata fail.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
private:
// A sequence of device names, indicating which devices replicated operations
// are forwarded to.
const std::vector<std::string> underlying_devices_;
// A sequence of thread wrappers, one per device, for executing operations in
// parallel.
//
// Conceptually this is a thread pool with one thread per device. It requires
// less synchronization than a thread pool would for this task, since Execute
// acquires each thread in order (and so only one Execute will schedule
// blocking collective operations at a time), and avoids some dynamic
// allocation/scheduling.
//
// TODO(allenl): Keep a map from outer thread to list of inner threads rather
// than a single list of threads so aliased nested parallel devices don't
// re-use a thread.
std::vector<std::unique_ptr<DeviceThread>> device_threads_;
};
// Contains a tuple of tensors, one on each of the `underlying_devices_` of the
// ParallelDevice.
class ParallelTensor {
public:
// Construct a ParallelTensor from TensorHandles placed on the component
// devices of a ParallelDevice.
static std::unique_ptr<ParallelTensor> FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status);
size_t num_tensors() const { return tensors_.size(); }
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
// A generalization of the shapes of the underlying tensors.
const std::vector<int64_t>& shape() const { return shape_; }
TF_DataType dtype() const { return dtype_; }
private:
ParallelTensor(const ParallelDevice& device,
std::vector<TensorHandlePtr> tensors,
std::vector<int64_t> shape, const TF_DataType dtype)
: device_(device),
tensors_(std::move(tensors)),
shape_(std::move(shape)),
dtype_(dtype) {}
const ParallelDevice& device_;
const std::vector<TensorHandlePtr> tensors_;
const std::vector<int64_t> shape_;
const TF_DataType dtype_;
};
} // namespace parallel_device
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_

View File

@ -0,0 +1,147 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <array>
#include <string>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/test.h"
tensorflow::ServerDef GetServerDef(const std::string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost", ":", port)});
}
return server_def;
}
TEST(PARALLEL_DEVICE, TestRemoteBasic) {
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
tensorflow::ServerDef server_def = GetServerDef("worker", 3);
// This server def has the task index set to 0.
std::string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TFE_ContextSetServerDef(context.get(), 0, serialized.data(),
serialized.size(), status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
BasicTestsForTwoDevices(context.get(),
"/job:worker/replica:0/task:1/device:CPU:0",
"/job:worker/replica:0/task:2/device:CPU:0");
worker_server1.release();
worker_server2.release();
}
TEST(PARALLEL_DEVICE, TestAsyncCopyOff) {
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
tensorflow::ServerDef server_def = GetServerDef("worker", 3);
// This server def has the task index set to 0.
std::string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TFE_ContextSetServerDef(context.get(), 0, serialized.data(),
serialized.size(), status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
const char* first_device = "/job:worker/replica:0/task:1/device:CPU:0";
const char* second_device = "/job:worker/replica:0/task:2/device:CPU:0";
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{first_device, second_device};
RegisterParallelDevice(context.get(), device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TFE_TensorHandle*, 2> in_components{value_one.get(),
value_two.get()};
TensorHandlePtr combined_value = CreatePerDeviceValues(
context.get(), in_components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Loop to make synchronization failures more deterministic
for (int i = 0; i < 100; ++i) {
TensorHandlePtr multiply_result(
Multiply(context.get(), combined_value.get(), combined_value.get(),
status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> out_components;
ExtractPerDeviceValues(context.get(), multiply_result.get(),
&out_components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(out_components[0].get(), 9.);
ExpectScalarEq<float>(out_components[1].get(), 4.);
}
worker_server1.release();
worker_server2.release();
}

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
#include "tensorflow/core/platform/test.h"
// NOTE(allenl): These tests currently go through TFE_Execute and so are
@ -28,390 +29,6 @@ limitations under the License.
// correspond fairly well to the implementation, but testing the C++ directly is
// another option.
// Functor for making unique_ptr to TFE_TensorHandle slightly more
// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
// template argument requires passing a function pointer to
// TFE_DeleteTensorHandle when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
// A helper for performing common operations on variables. A much more
// restricted stand-in for tf.Variable in Python.
class Variable {
public:
// Construct a Variable from a resource-dtype TFE_TensorHandle and an
// indication of the dtype of the variable's value.
//
// Note that creating this resource-dtype handle can fail, so `Create` is a
// separate static method which returns a status.
Variable(TFE_TensorHandle* handle, TF_DataType type)
: handle_(handle), type_(type) {}
// Helper for constructing a resource handle and wrapping it in a `Variable`
// object.
static Variable* Create(TFE_Context* context, TF_DataType type,
const int64_t* dims, const int num_dims,
const char* device, TF_Status* status);
// Dereferences the backing buffer for the variable. Note that since this can
// fail (it runs operations), it must be called explicitly and the resulting
// `status` checked.
void Destroy(TFE_Context* context, TF_Status* status);
// Reads from the variable.
TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
// Assigns a new value to the variable.
void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
// Adds `value` to the existing value of the variable.
void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status);
private:
// Helper for running any single-argument assignment ops (Assign, AssignAdd,
// AssignSub, ...).
void GeneralAssignment(const char* op_name, TFE_Context* context,
TFE_TensorHandle* value, TF_Status* status);
// The a handle for the resource-dtype tensor pointing to the variable's
// buffer.
TFE_TensorHandle* handle_;
// The dtype of the variable's buffer (input dtype for assignments, output
// dtype of read operations).
TF_DataType type_;
};
Variable* Variable::Create(TFE_Context* context, TF_DataType type,
const int64_t* dims, const int num_dims,
const char* device, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "dtype", type);
TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
TFE_OpSetAttrString(op.get(), "container", "", 0);
// Use the special GUID for no buffer sharing
//
// TODO(allenl): Should we provide a better API for this? AFAIK this is the
// only reasonable way to make variables with no aliasing using the eager C
// API.
std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
no_sharing.length());
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
TFE_Execute(op.get(), &var_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return new Variable(var_handle, type);
}
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
// Free the backing buffer for the variable.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return;
int num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
// Delete the variable handle itself.
TFE_DeleteTensorHandle(handle_);
}
TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "dtype", type_);
int num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(var_value);
}
void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
TFE_TensorHandle* value, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetAttrType(op.get(), "dtype", type_);
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpAddInput(op.get(), value, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
int num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
}
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status) {
GeneralAssignment("AssignAddVariableOp", context, value, status);
}
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status) {
GeneralAssignment("AssignVariableOp", context, value, status);
}
// Passed to `TF_NewTensor` to indicate how an array of floats should be
// deleted.
static void FloatDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<float*>(data);
}
// Creates a TFE_TensorHandle with value `v`.
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
const int num_bytes = sizeof(float);
float* values = new float[1];
values[0] = v;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
nullptr),
TF_DeleteTensor);
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}
// Creates a rank-one TFE_TensorHandle with value `v`.
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
TF_Status* status) {
const int num_bytes = v.size() * sizeof(float);
float* values = new float[v.size()];
memcpy(values, v.data(), num_bytes);
int64_t dims = v.size();
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
&FloatDeallocator, nullptr),
TF_DeleteTensor);
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
template <std::size_t num_replicas>
void ExtractPerDeviceValues(
TFE_Context* context, TFE_TensorHandle* input,
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
TFE_OpAddInput(op.get(), input, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(input, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_TensorHandle* result_handles[num_replicas];
int num_retvals = num_replicas;
TFE_Execute(op.get(), result_handles, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
for (int i = 0; i < num_replicas; ++i) {
(*components)[i].reset(result_handles[i]);
}
}
// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
template <std::size_t num_replicas>
TensorHandlePtr CreatePerDeviceValues(
TFE_Context* context,
const std::array<TFE_TensorHandle*, num_replicas>& components,
const char* device, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrInt(op.get(), "N", num_replicas);
for (int i = 0; i < num_replicas; ++i) {
TFE_OpAddInput(op.get(), components[i], status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
TFE_TensorHandle* second, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), first, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), second, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* first_device = TFE_TensorHandleDeviceName(first, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), first_device, status);
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
// Assert that `handle` is equal to `expected_value`.
template <typename value_type>
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
EXPECT_EQ(expected_value,
*static_cast<value_type*>(TF_TensorData(value_zero.get())));
}
template <std::size_t num_devices>
void RegisterParallelDevice(
TFE_Context* context, const char* device_name,
const std::array<const char*, num_devices>& underlying_devices,
TF_Status* status) {
TFE_CustomDevice device;
void* device_info;
tensorflow::eager::AllocateParallelDevice(
device_name, underlying_devices.data(), underlying_devices.size(),
&device, &device_info);
TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
}
// Create and modify a variable placed on a parallel device which composes
// `first_device` and `second_device`.
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
const char* second_device) {
// Register the custom device
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{first_device, second_device};
RegisterParallelDevice(context, device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle (uninitialized to start) placed on the parallel
// device.
std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
to_delete->Destroy(context, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
delete to_delete;
};
std::unique_ptr<Variable, decltype(variable_deleter)> variable(
Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
status.get()),
variable_deleter);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Assign an initial value to the variable, implicitly mirroring it to each
// component device.
{
TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
variable->Assign(context, initial_value.get(), status.get());
}
// Read from the variable and verify that we have a parallel tensor.
{
TensorHandlePtr read = variable->Read(context, status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(components[0].get(), 20.);
ExpectScalarEq<float>(components[1].get(), 20.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Add a parallel tensor with different values on each device to the variable.
{
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
std::array<TFE_TensorHandle*, 2> components{value_one.get(),
value_two.get()};
TensorHandlePtr combined_value =
CreatePerDeviceValues(context, components, device_name, status.get());
variable->AssignAdd(context, combined_value.get(), status.get());
}
// Read the variable and verify that each component has the right modified
// value.
{
TensorHandlePtr read = variable->Read(context, status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(components[0].get(), 23.);
ExpectScalarEq<float>(components[1].get(), 18.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Compute the device ID twice and verify the result
for (int i = 0; i < 2; ++i) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetDevice(op.get(), device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, result_handle, &components, status.get());
TFE_DeleteTensorHandle(result_handle);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<int64_t>(components[0].get(), 0);
ExpectScalarEq<int64_t>(components[1].get(), 1);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
}
TEST(PARALLEL_DEVICE, TestBasicCPU) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -790,7 +407,7 @@ TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
return TensorHandlePtr(result_handle);
}
TEST(PARALLEL_DEVICE, TestCollective) {
void TestCollective(bool async) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
@ -806,6 +423,9 @@ TEST(PARALLEL_DEVICE, TestCollective) {
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_Executor, decltype(&TFE_DeleteExecutor)> executor(
TFE_NewExecutor(async), TFE_DeleteExecutor);
TFE_ContextSetExecutorForThread(context.get(), executor.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{
@ -835,8 +455,16 @@ TEST(PARALLEL_DEVICE, TestCollective) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(result_components[0].get(), 3.);
ExpectScalarEq<float>(result_components[1].get(), 3.);
// Destroying the context's default executor first isn't safe.
context.reset();
}
TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }
// Note that ops on the parallel device currently don't execute
// asynchronously. The test is just that we don't get deadlocks.
TEST(PARALLEL_DEVICE, TestCollectiveAsync) { TestCollective(/*async=*/true); }
void RegisterCollectiveMulFunction(TFE_Context* context,
const char* function_name, int group_size,
TF_Status* status) {

View File

@ -0,0 +1,308 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
#include <array>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/test.h"
// NOTE(allenl): These tests currently go through TFE_Execute and so are
// integration testing rather than purely testing the parallel device. They
// correspond fairly well to the implementation, but testing the C++ directly is
// another option.
Variable* Variable::Create(TFE_Context* context, TF_DataType type,
const int64_t* dims, const int num_dims,
const char* device, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "dtype", type);
TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
TFE_OpSetAttrString(op.get(), "container", "", 0);
// Use the special GUID for no buffer sharing
//
// TODO(allenl): Should we provide a better API for this? AFAIK this is the
// only reasonable way to make variables with no aliasing using the eager C
// API.
std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
no_sharing.length());
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
TFE_Execute(op.get(), &var_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return new Variable(var_handle, type);
}
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
// Free the backing buffer for the variable.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return;
int num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
// Delete the variable handle itself.
TFE_DeleteTensorHandle(handle_);
}
TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "dtype", type_);
int num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(var_value);
}
void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
TFE_TensorHandle* value, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetAttrType(op.get(), "dtype", type_);
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpAddInput(op.get(), value, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
int num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
}
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status) {
GeneralAssignment("AssignAddVariableOp", context, value, status);
}
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status) {
GeneralAssignment("AssignVariableOp", context, value, status);
}
// Passed to `TF_NewTensor` to indicate how an array of floats should be
// deleted.
static void FloatDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<float*>(data);
}
// Creates a TFE_TensorHandle with value `v`.
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
const int num_bytes = sizeof(float);
float* values = new float[1];
values[0] = v;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
nullptr),
TF_DeleteTensor);
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}
// Creates a rank-one TFE_TensorHandle with value `v`.
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
TF_Status* status) {
const int num_bytes = v.size() * sizeof(float);
float* values = new float[v.size()];
memcpy(values, v.data(), num_bytes);
int64_t dims = v.size();
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
&FloatDeallocator, nullptr),
TF_DeleteTensor);
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
template <std::size_t num_replicas>
void ExtractPerDeviceValues(
TFE_Context* context, TFE_TensorHandle* input,
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
TFE_OpAddInput(op.get(), input, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(input, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_TensorHandle* result_handles[num_replicas];
int num_retvals = num_replicas;
TFE_Execute(op.get(), result_handles, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
for (int i = 0; i < num_replicas; ++i) {
(*components)[i].reset(result_handles[i]);
}
}
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
TFE_TensorHandle* second, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), first, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), second, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* first_device = TFE_TensorHandleDeviceName(first, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), first_device, status);
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
// Create and modify a variable placed on a parallel device which composes
// `first_device` and `second_device`.
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
const char* second_device) {
// Register the custom device
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{first_device, second_device};
RegisterParallelDevice(context, device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle (uninitialized to start) placed on the parallel
// device.
std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
to_delete->Destroy(context, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
delete to_delete;
};
std::unique_ptr<Variable, decltype(variable_deleter)> variable(
Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
status.get()),
variable_deleter);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Assign an initial value to the variable, implicitly mirroring it to each
// component device.
{
TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
variable->Assign(context, initial_value.get(), status.get());
}
// Read from the variable and verify that we have a parallel tensor.
{
TensorHandlePtr read = variable->Read(context, status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(components[0].get(), 20.);
ExpectScalarEq<float>(components[1].get(), 20.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Add a parallel tensor with different values on each device to the variable.
{
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
std::array<TFE_TensorHandle*, 2> components{value_one.get(),
value_two.get()};
TensorHandlePtr combined_value =
CreatePerDeviceValues(context, components, device_name, status.get());
variable->AssignAdd(context, combined_value.get(), status.get());
}
// Read the variable and verify that each component has the right modified
// value.
{
TensorHandlePtr read = variable->Read(context, status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(components[0].get(), 23.);
ExpectScalarEq<float>(components[1].get(), 18.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Compute the device ID twice and verify the result
for (int i = 0; i < 2; ++i) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetDevice(op.get(), device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, result_handle, &components, status.get());
TFE_DeleteTensorHandle(result_handle);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<int64_t>(components[0].get(), 0);
ExpectScalarEq<int64_t>(components[1].get(), 1);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
}

View File

@ -0,0 +1,174 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include <array>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/test.h"
// Functor for making unique_ptr to TFE_TensorHandle slightly more
// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
// template argument requires passing a function pointer to
// TFE_DeleteTensorHandle when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
// A helper for performing common operations on variables. A much more
// restricted stand-in for tf.Variable in Python.
class Variable {
public:
// Construct a Variable from a resource-dtype TFE_TensorHandle and an
// indication of the dtype of the variable's value.
//
// Note that creating this resource-dtype handle can fail, so `Create` is a
// separate static method which returns a status.
Variable(TFE_TensorHandle* handle, TF_DataType type)
: handle_(handle), type_(type) {}
// Helper for constructing a resource handle and wrapping it in a `Variable`
// object.
static Variable* Create(TFE_Context* context, TF_DataType type,
const int64_t* dims, const int num_dims,
const char* device, TF_Status* status);
// Dereferences the backing buffer for the variable. Note that since this can
// fail (it runs operations), it must be called explicitly and the resulting
// `status` checked.
void Destroy(TFE_Context* context, TF_Status* status);
// Reads from the variable.
TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
// Assigns a new value to the variable.
void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
// Adds `value` to the existing value of the variable.
void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status);
private:
// Helper for running any single-argument assignment ops (Assign, AssignAdd,
// AssignSub, ...).
void GeneralAssignment(const char* op_name, TFE_Context* context,
TFE_TensorHandle* value, TF_Status* status);
// The a handle for the resource-dtype tensor pointing to the variable's
// buffer.
TFE_TensorHandle* handle_;
// The dtype of the variable's buffer (input dtype for assignments, output
// dtype of read operations).
TF_DataType type_;
};
// Creates a TFE_TensorHandle with value `v`.
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status);
// Creates a rank-one TFE_TensorHandle with value `v`.
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
TF_Status* status);
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
template <std::size_t num_replicas>
void ExtractPerDeviceValues(
TFE_Context* context, TFE_TensorHandle* input,
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status);
// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
template <std::size_t num_replicas>
TensorHandlePtr CreatePerDeviceValues(
TFE_Context* context,
const std::array<TFE_TensorHandle*, num_replicas>& components,
const char* device, TF_Status* status);
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
TFE_TensorHandle* second, TF_Status* status);
// Assert that `handle` is equal to `expected_value`.
template <typename value_type>
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value);
template <std::size_t num_devices>
void RegisterParallelDevice(
TFE_Context* context, const char* device_name,
const std::array<const char*, num_devices>& underlying_devices,
TF_Status* status);
// Create and modify a variable placed on a parallel device which composes
// `first_device` and `second_device`.
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
const char* second_device);
// Implementations of templated functions ******************************
template <std::size_t num_replicas>
TensorHandlePtr CreatePerDeviceValues(
TFE_Context* context,
const std::array<TFE_TensorHandle*, num_replicas>& components,
const char* device, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrInt(op.get(), "N", num_replicas);
for (int i = 0; i < num_replicas; ++i) {
TFE_OpAddInput(op.get(), components[i], status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
template <typename value_type>
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
EXPECT_EQ(expected_value,
*static_cast<value_type*>(TF_TensorData(value_zero.get())));
}
template <std::size_t num_devices>
void RegisterParallelDevice(
TFE_Context* context, const char* device_name,
const std::array<const char*, num_devices>& underlying_devices,
TF_Status* status) {
TFE_CustomDevice device;
void* device_info;
tensorflow::parallel_device::AllocateParallelDevice(
device_name, underlying_devices.data(), underlying_devices.size(),
&device, &device_info);
TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
}
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_

View File

@ -0,0 +1,31 @@
# Experimental gcs filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
package(
licenses = ["notice"], # Apache 2.0
)
# Filesystem implementation for GCS environments
tf_cc_shared_object(
name = "gcs_filesystem",
framework_so = [],
linkstatic = False,
per_os_targets = 1,
visibility = ["//visibility:public"],
deps = [":gcs_filesystem_impl"],
)
# The real implementation of the filesystem.
cc_library(
name = "gcs_filesystem_impl",
srcs = ["gcs_filesystem.cc"],
copts = select({
"//conditions:default": [],
"//tensorflow:windows": get_win_copts(),
}),
deps = [
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
],
)

View File

@ -0,0 +1,101 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdlib.h>
#include <string.h>
#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for GCS environments.
// This filesystem will support `gs://` URI schemes.
namespace gcs = google::cloud::storage;
// We can cast `google::cloud::StatusCode` to `TF_Code` because they have the
// same integer values. See
// https://github.com/googleapis/google-cloud-cpp/blob/6c09cbfa0160bc046e5509b4dd2ab4b872648b4a/google/cloud/status.h#L32-L52
static inline void TF_SetStatusFromGCSStatus(
const google::cloud::Status& gcs_status, TF_Status* status) {
TF_SetStatus(status, static_cast<TF_Code>(gcs_status.code()),
gcs_status.message().c_str());
}
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
// TODO(vnvo2409): Implement later
} // namespace tf_random_access_file
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
// TODO(vnvo2409): Implement later
} // namespace tf_writable_file
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
// ----------------------------------------------------------------------------
namespace tf_read_only_memory_region {
// TODO(vnvo2409): Implement later
} // namespace tf_read_only_memory_region
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_gcs_filesystem {
// TODO(vnvo2409): Add lazy-loading and customizing parameters.
static void Init(TF_Filesystem* filesystem, TF_Status* status) {
google::cloud::StatusOr<gcs::Client> client =
gcs::Client::CreateDefaultClient();
if (!client) {
TF_SetStatusFromGCSStatus(client.status(), status);
return;
}
filesystem->plugin_filesystem = plugin_memory_allocate(sizeof(gcs::Client));
auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
(*gcs_client) = client.value();
TF_SetStatus(status, TF_OK, "");
}
// TODO(vnvo2409): Implement later
} // namespace tf_gcs_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_gcs_filesystem::Init;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
info->plugin_memory_allocate = plugin_memory_allocate;
info->plugin_memory_free = plugin_memory_free;
info->num_schemes = 1;
info->ops = static_cast<TF_FilesystemPluginOps*>(
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
ProvideFilesystemSupportFor(&info->ops[0], "gs");
}

View File

@ -108,7 +108,7 @@ class CServerFactory : public ServerFactory {
delete_function_(delete_function),
rendezvous_builder_(rendezvous_builder) {}
Status NewServer(const ServerDef& server_def,
Status NewServer(const ServerDef& server_def, const Options& options,
std::unique_ptr<ServerInterface>* out_server) override {
TF_RETURN_IF_ERROR(CGrpcServer::Create(
server_def, init_function_, start_function_, stop_function_,

View File

@ -57,6 +57,7 @@ cc_library(
":concrete_function",
":saved_model_api",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:optional",
],
)

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
@ -51,7 +52,7 @@ std::vector<ConcreteFunction*> TFSavedModelAPIImpl::ListFunctions() {
Status TFSavedModelAPIImpl::Load(
const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags,
TFSavedModelAPIImpl* out) {
EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out) {
// TODO(bmzhao): Add support for loading a TFSavedModelImpl.
return errors::Unimplemented(
"TFSavedModelAPIImpl loading is unimplemented currently");

View File

@ -23,14 +23,13 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
class TFSavedModelAPIImpl : public SavedModelAPI {
public:
TFSavedModelAPIImpl() = default;
Status GetFunction(const std::string& function_path,
ConcreteFunction** function) override;
@ -40,13 +39,14 @@ class TFSavedModelAPIImpl : public SavedModelAPI {
static Status Load(
const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags,
TFSavedModelAPIImpl* out);
EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out);
std::vector<ConcreteFunction*> ListFunctions() override;
~TFSavedModelAPIImpl() override = default;
private:
TFSavedModelAPIImpl() = default;
std::vector<ConcreteFunction> functions_;
};

View File

@ -144,7 +144,9 @@ cc_library(
"//tensorflow/c:tf_status_internal",
"//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_impl",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:optional",
],
)
@ -155,6 +157,7 @@ cc_library(
"saved_model_api_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
],
)

View File

@ -22,11 +22,15 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
extern "C" {
@ -34,14 +38,25 @@ extern "C" {
TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
TF_Status* status) {
std::string saved_model_dir(dirname);
std::unique_ptr<tensorflow::SavedModelAPI> result;
if (tensorflow::unwrap(ctx)->UsesTFRT()) {
status->status = tensorflow::errors::Unimplemented(
"TFRT SavedModel implementation will be added in the future");
} else {
std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
status->status = tensorflow::TFSavedModelAPIImpl::Load(
dirname, absl::nullopt,
tensorflow::down_cast<tensorflow::EagerContext*>(
tensorflow::unwrap(ctx)),
&saved_model);
result = std::move(saved_model);
}
std::unique_ptr<tensorflow::SavedModelAPI> result =
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, absl::nullopt,
&status->status);
if (!status->status.ok()) {
return nullptr;
}
return new TF_SavedModel{std::move(result)};
return tensorflow::wrap(result.release());
}
TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
@ -54,23 +69,36 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
tagset.insert(std::string(tags[i]));
}
std::unique_ptr<tensorflow::SavedModelAPI> result =
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, std::move(tagset),
&status->status);
std::unique_ptr<tensorflow::SavedModelAPI> result;
if (tensorflow::unwrap(ctx)->UsesTFRT()) {
status->status = tensorflow::errors::Unimplemented(
"TFRT SavedModel implementation will be added in the future");
} else {
std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
status->status = tensorflow::TFSavedModelAPIImpl::Load(
dirname, tagset,
tensorflow::down_cast<tensorflow::EagerContext*>(
tensorflow::unwrap(ctx)),
&saved_model);
result = std::move(saved_model);
}
if (!status->status.ok()) {
return nullptr;
}
return new TF_SavedModel{std::move(result)};
return tensorflow::wrap(result.release());
}
void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; }
void TF_DeleteSavedModel(TF_SavedModel* model) {
delete tensorflow::unwrap(model);
}
TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
const char* function_path,
TF_Status* status) {
tensorflow::ConcreteFunction* result = nullptr;
tensorflow::Status get_function_status =
model->saved_model->GetFunction(function_path, &result);
tensorflow::unwrap(model)->GetFunction(function_path, &result);
status->status.Update(get_function_status);
if (!get_function_status.ok()) {
return nullptr;
@ -82,7 +110,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
TF_SavedModel* model, const char* signature_def_key, TF_Status* status) {
tensorflow::ConcreteFunction* result = nullptr;
tensorflow::Status get_function_status =
model->saved_model->GetSignatureDefFunction(signature_def_key, &result);
tensorflow::unwrap(model)->GetSignatureDefFunction(signature_def_key,
&result);
status->status.Update(get_function_status);
if (!get_function_status.ok()) {
return nullptr;
@ -91,7 +120,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
}
TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
return new TF_ConcreteFunctionList{model->saved_model->ListFunctions()};
return new TF_ConcreteFunctionList{
tensorflow::unwrap(model)->ListFunctions()};
}
} // end extern "C"

View File

@ -18,13 +18,18 @@ limitations under the License.
#include <memory>
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
// Internal structures used by the SavedModel C API. These are likely to change
// and should not be depended on.
struct TF_SavedModel {
std::unique_ptr<tensorflow::SavedModelAPI> saved_model;
};
typedef struct TF_SavedModel TF_SavedModel;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::SavedModelAPI, TF_SavedModel)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_

View File

@ -32,8 +32,5 @@ TFE_TensorHandle* TF_TensorHandleListGet(const TF_TensorHandleList* list,
return tensorflow::wrap((*tensorflow::unwrap(list))[i]);
}
void TF_DeleteTensorHandleList(const TF_TensorHandleList* list) {
delete tensorflow::unwrap(list);
}
} // end extern "C"

View File

@ -36,10 +36,6 @@ TF_CAPI_EXPORT extern size_t TF_TensorHandleListSize(
TF_CAPI_EXPORT extern TFE_TensorHandle* TF_TensorHandleListGet(
const TF_TensorHandleList* list, int i);
// Deletes `list`.
TF_CAPI_EXPORT extern void TF_DeleteTensorHandleList(
const TF_TensorHandleList* list);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus

View File

@ -84,7 +84,7 @@ cc_library(
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
]) + if_android([
"//tensorflow/core:android_tensorflow_lib",
"//tensorflow/core:portable_tensorflow_lib",
]),
)
@ -106,6 +106,7 @@ cc_library(
hdrs = ["loader.h"],
deps = [
":constants",
":loader_util",
":reader",
] + if_not_mobile([
"//tensorflow/core:core_cpu",
@ -132,6 +133,17 @@ cc_library(
],
)
cc_library(
name = "loader_util",
srcs = ["loader_util.cc"],
hdrs = ["loader_util.h"],
deps = [":constants"] + if_not_mobile([
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
]),
)
tf_cc_test(
name = "bundle_v2_test",
srcs = ["bundle_v2_test.cc"],

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -29,7 +30,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session.h"
@ -191,41 +191,6 @@ Status RunInitOp(const RunOptions& run_options, const string& export_dir,
return Status::OK();
}
// A SavedModel may store the name of the initialization op to run in the
// in the SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name) {
const auto& sig_def_map = meta_graph_def.signature_def();
const auto& init_op_sig_it =
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
if (init_op_sig_it != sig_def_map.end()) {
*init_op_name = init_op_sig_it->second.outputs()
.find(kSavedModelInitOpSignatureKey)
->second.name();
return Status::OK();
}
const auto& collection_def_map = meta_graph_def.collection_def();
string init_op_collection_key;
if (collection_def_map.find(kSavedModelMainOpKey) !=
collection_def_map.end()) {
init_op_collection_key = kSavedModelMainOpKey;
} else {
init_op_collection_key = kSavedModelLegacyInitOpKey;
}
const auto init_op_it = collection_def_map.find(init_op_collection_key);
if (init_op_it != collection_def_map.end()) {
if (init_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one main op in : ", export_dir));
}
*init_op_name = init_op_it->second.node_list().value(0);
}
return Status::OK();
}
Status RunRestore(const RunOptions& run_options, const string& export_dir,
const StringPiece restore_op_name,
const StringPiece variable_filename_const_op_name,
@ -263,32 +228,6 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
nullptr /* outputs */, &run_metadata, session);
}
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector<AssetFileDef>* asset_file_defs) {
// With SavedModel v2, we write asset file def into metagraph instead of
// collection, so read from metagraph first.
if (meta_graph_def.asset_file_def_size() > 0) {
for (const auto& asset : meta_graph_def.asset_file_def()) {
asset_file_defs->push_back(asset);
}
return Status::OK();
}
// Fall back to read from collection to be backward compatible with v1.
const auto& collection_def_map = meta_graph_def.collection_def();
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
if (assets_it == collection_def_map.end()) {
return Status::OK();
}
const auto& any_assets = assets_it->second.any_list().value();
for (const auto& any_asset : any_assets) {
AssetFileDef asset_file_def;
TF_RETURN_IF_ERROR(
ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
asset_file_defs->push_back(asset_file_def);
}
return Status::OK();
}
Status ReadSavedModelDebugInfoIfPresent(
const string& export_dir,
std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
@ -322,7 +261,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
std::vector<AssetFileDef> asset_file_defs;
TF_RETURN_IF_ERROR(
GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
internal::GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
TF_RETURN_IF_ERROR(
RunRestore(run_options, export_dir,
bundle->meta_graph_def.saver_def().restore_op_name(),
@ -336,7 +275,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
string init_op_name;
TF_RETURN_IF_ERROR(
GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
internal::GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
asset_file_defs, bundle->session.get(),
init_op_name));

View File

@ -0,0 +1,90 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/saved_model/loader_util.h"
#include <vector>
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf_internal.h"
namespace tensorflow {
namespace internal {
// A SavedModel may store the name of the initialization op to run in the
// in the SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name) {
const auto& sig_def_map = meta_graph_def.signature_def();
const auto& init_op_sig_it =
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
if (init_op_sig_it != sig_def_map.end()) {
*init_op_name = init_op_sig_it->second.outputs()
.find(kSavedModelInitOpSignatureKey)
->second.name();
return Status::OK();
}
const auto& collection_def_map = meta_graph_def.collection_def();
string init_op_collection_key;
if (collection_def_map.find(kSavedModelMainOpKey) !=
collection_def_map.end()) {
init_op_collection_key = kSavedModelMainOpKey;
} else {
init_op_collection_key = kSavedModelLegacyInitOpKey;
}
const auto init_op_it = collection_def_map.find(init_op_collection_key);
if (init_op_it != collection_def_map.end()) {
if (init_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one main op in : ", export_dir));
}
*init_op_name = init_op_it->second.node_list().value(0);
}
return Status::OK();
}
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector<AssetFileDef>* asset_file_defs) {
// With SavedModel v2, we write asset file def into metagraph instead of
// collection, so read from metagraph first.
if (meta_graph_def.asset_file_def_size() > 0) {
for (const auto& asset : meta_graph_def.asset_file_def()) {
asset_file_defs->push_back(asset);
}
return Status::OK();
}
// Fall back to read from collection to be backward compatible with v1.
const auto& collection_def_map = meta_graph_def.collection_def();
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
if (assets_it == collection_def_map.end()) {
return Status::OK();
}
const auto& any_assets = assets_it->second.any_list().value();
for (const auto& any_asset : any_assets) {
AssetFileDef asset_file_def;
TF_RETURN_IF_ERROR(
ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
asset_file_defs->push_back(asset_file_def);
}
return Status::OK();
}
} // namespace internal
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
#define TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
#include <string>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
namespace tensorflow {
namespace internal {
// A SavedModel may store the name of the initialization op to run in the
// in the SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name);
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector<AssetFileDef>* asset_file_defs);
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_

View File

@ -67,13 +67,13 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
"@llvm-project//llvm:ARMCodeGen", # fixdeps: keep
"@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep
"@llvm-project//llvm:Target",
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
"//tensorflow/core:regexp_internal",
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
"@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep
]),
)
@ -94,8 +94,8 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core/platform:resource_loader",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support", # fixdeps: keep
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
"@llvm-project//llvm:Support", # fixdeps: keep
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
],
)
@ -109,12 +109,12 @@ cc_library(
name = "llvm_targets",
visibility = ["//tensorflow/python:__pkg__"],
deps = [
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
"@llvm-project//llvm:ARMCodeGen", # fixdeps: keep
"@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep
"@llvm-project//llvm:Target",
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
"@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep
]),
)
@ -286,9 +286,9 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:core",
"@llvm-project//llvm:support",
"@llvm-project//llvm:target",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
],
)

View File

@ -1,5 +1,5 @@
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s -dump-input-on-failure
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s -dump-input-on-failure
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s
# Checks the error message produced by tfcompile with mlir_component
# Checks that source debug information is used in the output error message and

View File

@ -20,7 +20,7 @@ load(
"tf_cc_test",
"tf_copts",
)
load("//tensorflow:tensorflow.bzl", "tfcompile_extra_flags")
load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu")
def tf_library(
name,
@ -42,7 +42,8 @@ def tf_library(
mlir_components = "None",
deps = None,
tags = []):
"""Runs tfcompile to compile a TensorFlow graph into executable code.
"""Runs tfcompile to compile a TensorFlow graph into executable code with fast
math enabled on cpu.
Given an invocation of tf_library(name="foo", ...), generates the following
build targets:
@ -187,7 +188,9 @@ def tf_library(
# `find` on such an object.
need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1
flags = tfcompile_extra_flags() + flags
target_cpu = tfcompile_target_cpu()
extra_flags = "--target_cpu=" + target_cpu + " " if target_cpu else " "
flags = extra_flags + flags
if enable_xla_hlo_profiling:
profiling_flag = "--xla_hlo_profile"
@ -207,6 +210,15 @@ def tf_library(
srcs.append(debug_info)
debug_info_flag = " --debug_info=$(location " + debug_info + ")"
default_fast_math_xla_flags = ("XLA_FLAGS='" +
"--xla_cpu_enable_fast_math=true " +
"--xla_cpu_fast_math_honor_nans=false " +
"--xla_cpu_fast_math_honor_infs=false " +
"--xla_cpu_fast_math_honor_functions=false " +
"--xla_cpu_fast_math_honor_division=false " +
"--xla_cpu_enable_fast_min_max=true " +
"$${XLA_FLAGS:-}' ")
native.genrule(
name = ("gen_" + name),
srcs = srcs,
@ -216,6 +228,7 @@ def tf_library(
function_object_file,
],
cmd = (
default_fast_math_xla_flags +
"CUDA_VISIBLE_DEVICES='' " +
"$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +
@ -256,6 +269,7 @@ def tf_library(
session_module_pb,
],
cmd = (
default_fast_math_xla_flags +
"CUDA_VISIBLE_DEVICES='' " +
"$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +

View File

@ -67,6 +67,8 @@ int main(int argc, char** argv) {
flags.entry_point = "entry";
flags.debug_info_path_begin_marker = "";
// Note that tfcompile.bzl's tf_library macro sets fast math flags as that is
// generally the preferred case.
std::vector<tensorflow::Flag> flag_list;
AppendMainFlags(&flag_list, &flags);
xla::AppendDebugOptionsFlags(&flag_list);

View File

@ -251,7 +251,7 @@ cc_library(
visibility = [":friends"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib",
"//tensorflow/core:portable_tensorflow_lib",
],
"//conditions:default": [
"//tensorflow/core:graph",
@ -505,6 +505,7 @@ cc_library(
name = "shape_inference",
srcs = ["shape_inference.cc"],
hdrs = ["shape_inference.h"],
visibility = [":friends"],
deps = [
":shape_inference_helpers",
"//tensorflow/compiler/xla:statusor",

View File

@ -33,6 +33,7 @@ MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags;
MlirCommonFlags* mlir_flags;
std::vector<Flag>* flag_list;
absl::once_flag flags_init;
@ -166,6 +167,9 @@ void AllocateAndParseFlags() {
jitter_flags = new IntroduceFloatingPointJitterPassFlags;
jitter_flags->jitter_amount = 1e-5;
mlir_flags = new MlirCommonFlags;
mlir_flags->tf_mlir_enable_mlir_bridge = false;
auto setter_for_jitter_tensor_names = [](string sequence) {
jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
return true;
@ -211,7 +215,11 @@ void AllocateAndParseFlags() {
Flag("tf_introduce_floating_point_jitter_amount",
&jitter_flags->jitter_amount,
"The amount of jitter to introduce. This amount is added to each "
"element in the tensors named in `tensor_names.")});
"element in the tensors named in `tensor_names."),
Flag("tf_mlir_enable_mlir_bridge",
&mlir_flags->tf_mlir_enable_mlir_bridge,
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});
AppendMarkForCompilationPassFlagsInternal(flag_list);
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
@ -250,6 +258,11 @@ GetIntroduceFloatingPointJitterPassFlags() {
return *jitter_flags;
}
MlirCommonFlags* GetMlirCommonFlags() {
absl::call_once(flags_init, &AllocateAndParseFlags);
return mlir_flags;
}
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
absl::call_once(flags_init, &AllocateAndParseFlags);
AppendMarkForCompilationPassFlagsInternal(flag_list);

View File

@ -133,6 +133,11 @@ struct IntroduceFloatingPointJitterPassFlags {
std::vector<string> tensor_names;
};
// Flags for common MLIR configurations.
struct MlirCommonFlags {
bool tf_mlir_enable_mlir_bridge;
};
// Return a pointer to the DumpGraphFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
@ -148,6 +153,8 @@ const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags();
MlirCommonFlags* GetMlirCommonFlags();
// Appends the flag definitions associated with
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
//

View File

@ -2034,6 +2034,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"TensorArraySplitV3",
"TensorArrayV3",
"TensorArrayWriteV3",
"TensorListConcatV2",
"TensorListElementShape",
"TensorListFromTensor",
"TensorListGather",
@ -2043,6 +2044,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"TensorListPushBack",
"TensorListReserve",
"TensorListSetItem",
"TensorListSplit",
"TensorListStack",
"TensorScatterAdd",
"TensorScatterSub",

View File

@ -395,12 +395,11 @@ static void ShowXlaDeviceDeprecationWarning(
if (absl::StrContains(compilation_device_name, "CPU") ||
absl::StrContains(compilation_device_name, "GPU")) {
absl::call_once(once, [] {
LOG(WARNING)
<< "XLA_GPU and XLA_CPU devices are deprecated and will be "
"removed in subsequent releases. Instead, use either "
"@tf.function(experimental_compile=True) for must-compile "
"semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
"for auto-clustering best-effort compilation.";
LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be "
"removed in subsequent releases. Instead, use either "
"@tf.function(experimental_compile=True) for must-compile "
"semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
"for auto-clustering best-effort compilation.";
});
}
}

View File

@ -91,7 +91,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
}
string message = absl::StrCat(
"Function invoked by the following node is not compilable: ",
SummarizeNodeDef(node_def), ".\n");
SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n");
absl::StrAppend(&message, "Uncompilable nodes:");
for (const auto& node_info : uncompilable_node_info) {
string node_message =

View File

@ -201,9 +201,7 @@ void XlaComputationLaunchContext::PopulateInputs(
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
// Build ShapedBuffers that point directly to the Tensor buffers.
arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1);
arg_buffers_.resize(kernel->xla_input_shapes.size());
arg_ptrs_ = std::vector<ShapedBuffer*>(arg_buffers_.size());
arg_ptrs_ = std::vector<ShapedBuffer*>(kernel->xla_input_shapes.size());
// Pass remaining parameters.
const Tensor* t;
@ -239,11 +237,11 @@ void XlaComputationLaunchContext::PopulateInputs(
<< " not the same as on-host shape "
<< xla::ShapeUtil::HumanStringWithLayout(shape);
se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
arg_buffers_[i] = absl::make_unique<ShapedBuffer>(
arg_buffers_.emplace_back(
/*on_host_shape=*/shape, /*on_device_shape=*/shape,
client_->platform(), client_->default_device_ordinal());
arg_buffers_[i]->set_buffer(dmem, /*index=*/{});
arg_ptrs_[i] = arg_buffers_[i].get();
arg_buffers_.back().set_buffer(dmem, /*index=*/{});
arg_ptrs_[i] = &arg_buffers_.back();
}
}
}
@ -470,10 +468,6 @@ Status XlaComputationLaunchContext::PopulateOutputs(
<< "Invalid input for outputs " << i << ": " << input_index;
ctx->set_output(i, ctx->input(input_index));
} else {
if (MustAliasOutput(input_output_alias, output_num)) {
DCHECK(output.buffer({output_num}).is_null())
<< "Expected output buffer to be aliased, but it is not nil.";
}
if (allocate_xla_tensors_) {
TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
input_output_alias, output_num, ctx, i, shape, &output,

View File

@ -165,7 +165,7 @@ class XlaComputationLaunchContext {
se::DeviceMemoryAllocator* xla_allocator_;
bool allocate_xla_tensors_;
bool use_multiple_streams_;
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_;
std::deque<xla::ShapedBuffer> arg_buffers_;
std::vector<xla::ShapedBuffer*> arg_ptrs_;
};

View File

@ -27,10 +27,6 @@ namespace tensorflow {
return xla_tensor;
}
/*static*/ bool XlaTensor::RefCountIsOne(const Tensor& tensor) {
return tensor.RefCountIsOne();
}
/*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor(
const Tensor& tensor) {
const XlaTensor* xla_tensor = FromTensor(&tensor);

View File

@ -39,8 +39,6 @@ class XlaTensor {
// fails.
static XlaTensor* FromTensor(const Tensor* tensor);
static bool RefCountIsOne(const Tensor& tensor);
// Create a DeviceMemoryBase from a Tensor. The Tensor can be an XlaTensor, in
// which case the returned value is shaped_buffer()->root_buffer(), or a
// normal Tensor in which case the returned value is

View File

@ -30,7 +30,7 @@ cc_library(
hdrs = ["op_or_arg_name_mapper.h"],
deps = [
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
@ -42,7 +42,7 @@ cc_library(
":init_mlir",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirOptLib",
@ -86,7 +86,7 @@ cc_library(
hdrs = ["init_mlir.h"],
deps = [
"//tensorflow/core:lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)
@ -102,8 +102,9 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:core_cpu",
"@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
],
alwayslink = 1,
@ -154,7 +155,7 @@ tf_cc_binary(
"//tensorflow/core:tensorflow",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",

View File

@ -216,16 +216,16 @@ cc_library(
"ir/tfl_ops.h",
"transforms/passes.h",
"utils/attribute_utils.h",
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
"@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
],
deps = [
":tensorflow_lite_ops_inc_gen",
":validators",
"//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/lite/schema:schema_fbs",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:DerivedAttributeOpInterface",
"@llvm-project//mlir:Dialect",
@ -253,7 +253,26 @@ cc_library(
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
cc_library(
name = "tftext_utils",
srcs = [
"utils/tftext_utils.cc",
],
hdrs = [
"utils/tftext_utils.h",
],
copts = ["-std=c++14"],
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
@ -270,7 +289,7 @@ cc_library(
],
deps = [
":tensorflow_lite",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
],
@ -285,7 +304,7 @@ tf_cc_test(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
@ -320,6 +339,7 @@ cc_library(
":lstm_utils",
":stateful_ops_utils",
":tensorflow_lite",
":tftext_utils",
":validators",
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
@ -337,7 +357,7 @@ cc_library(
"//tensorflow/core/platform:logging",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -363,7 +383,7 @@ cc_library(
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -396,7 +416,7 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -421,7 +441,7 @@ cc_library(
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
@ -474,8 +494,8 @@ tf_native_cc_binary(
"converter_gen.cc",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:tablegen",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//mlir:TableGen",
],
)
@ -521,8 +541,8 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@flatbuffers",
"@llvm-project//llvm:analysis",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Analysis",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:TransformUtils",
],
@ -599,7 +619,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@flatbuffers",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
@ -633,7 +653,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
@ -693,7 +713,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirTranslateMain",
"@llvm-project//mlir:QuantOps",
@ -723,7 +743,7 @@ cc_library(
"tf_tfl_translate_cl.h",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
alwayslink = 1,
)
@ -735,7 +755,7 @@ cc_library(
],
deps = [
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)
@ -760,7 +780,7 @@ tf_cc_binary(
":tf_tfl_translate_cl_options",
":tf_to_tfl_flatbuffer",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
# TODO(b/155809683): Link only necessary dialects.
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
@ -785,7 +805,7 @@ tf_cc_binary(
":flatbuffer_translate_lib",
":flatbuffer_translate_registeration",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
# TODO(b/155809683): Link only necessary dialects.
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
@ -854,7 +874,7 @@ cc_library(
"//tensorflow/lite/tools/optimize:quantize_weights",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
@ -874,6 +894,6 @@ cc_library(
"//tensorflow/lite/experimental/mlir:__subpackages__",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)

View File

@ -32,7 +32,6 @@ struct PassConfig {
lower_tensor_list_ops(false),
trim_functions_whitelist({}),
quant_specs(std::move(specs)),
skip_control_dialect(false),
form_clusters(false),
unfold_batch_matmul(true),
legalize_tf_while(true),
@ -49,13 +48,8 @@ struct PassConfig {
llvm::ArrayRef<std::string> trim_functions_whitelist;
// All information about quantization.
QuantizationSpecs quant_specs;
// If `skip_control_dialect` is true, TF executor dialect is not converted to
// TF control dialect prior to legalization to TF Lite.
// TODO(b/142911013): Remove flag once control dialect is removed.
bool skip_control_dialect;
// If `form_clusters` is true (and `skip_control_dialect` is true), clusters
// are formed by grouping consecutive ops of the same device, under a
// `tf_device.launch` op.
// If `form_clusters` is true , clusters are formed by grouping consecutive
// ops of the same device, under a `tf_device.launch` op.
bool form_clusters;
// if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
// of tfl.fully_connected ops.

View File

@ -525,11 +525,16 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
auto *val = trait.getDef().getValue("tflRuntimePredicate");
if (!val) continue;
auto desc = trait.getDef().getValueAsString("tflRuntimeDescription");
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
os << tgfmt(
" if (!($0)) {\n "
" return ::mlir::LogicalResult::Failure;\n }\n",
&verify_ctx, tgfmt(pred.getCondition(), &verify_ctx));
" if (failure_on_operand_type_mismatch) {\n"
" return top.emitOpError(\"failed to verify that $1\");\n"
" } else {\n"
" return ::mlir::LogicalResult::Failure;\n }\n }\n",
&verify_ctx, tgfmt(pred.getCondition(), &verify_ctx), desc);
}
os << " return top.verify();\n}\n";
}

View File

@ -799,11 +799,6 @@ Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
Optional<CustomOptionsOffset> Translator::CreateCustomOpCustomOptions(
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
std::string node_def_str;
if (!node_def.SerializeToString(&node_def_str)) {
return emitError(loc, "failed to serialize tensorflow node_def"),
llvm::None;
}
auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc);
return builder_.CreateVector(flex_builder->GetBuffer());
}
@ -813,9 +808,13 @@ Translator::CreateFlexBuilderWithNodeAttrs(
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
auto flex_builder = absl::make_unique<flexbuffers::Builder>();
size_t map_start = flex_builder->StartMap();
for (const auto& pair : node_def.attr()) {
using Item = std::pair<std::string, ::tensorflow::AttrValue>;
std::vector<Item> attrs(node_def.attr().begin(), node_def.attr().end());
std::sort(attrs.begin(), attrs.end(),
[](Item& p1, Item& p2) -> bool { return p1.first < p2.first; });
for (const Item& pair : attrs) {
const char* key = pair.first.c_str();
const auto& attr = pair.second;
const ::tensorflow::AttrValue& attr = pair.second;
switch (attr.value_case()) {
case ::tensorflow::AttrValue::kS:
flex_builder->String(key, attr.s());

View File

@ -424,6 +424,10 @@ StatusOr<Operation*> BuildExternalConstOp(const tflite::TensorT& tensor,
StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
const std::vector<uint8_t>& buffer,
OpBuilder builder, Location loc) {
if (buffer.empty()) {
return errors::InvalidArgument("Constant's buffer may not be empty");
}
TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
/*shapeless_are_scalars=*/true,
/*is_constant=*/true));
@ -695,8 +699,6 @@ StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
for (int32_t output : output_indices) {
if (auto& op = defining_op[output]) {
queue.push_back(op);
} else {
return errors::InvalidArgument("Output tensor doesn't have defining op");
}
}
@ -801,9 +803,17 @@ StatusOr<FuncOp> ConvertSubgraph(
}
for (auto output : func_outputs) {
bool is_constant = !is_op_output[output];
const bool is_func_input = input_index_set.contains(output);
bool is_constant = !is_op_output[output] && !is_func_input;
// There are 2 cases tensor is scalar when it doesn't have a shape in
// flatbuffer:
// 1. `is_constant` = true, means this tensor is created from a constant op.
// 2. `is_func_input` = true and `is_entry_point` = true, which means this
// tensor is function input and function input type is a scalar tensor.
const bool shapeless_is_scalar =
is_constant || (is_func_input && is_entry_point);
auto type_or_err = GetTensorType(*subgraph.tensors.at(output), builder,
/*shapeless_are_scalars=*/is_constant,
shapeless_is_scalar,
/*is_constant=*/is_constant);
if (!type_or_err.ok()) {
emitError(func_loc, "error reading return types")
@ -858,6 +868,8 @@ StatusOr<FuncOp> ConvertSubgraph(
subgraph, &builder, "outputs", func_outputs));
}
func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
} else {
func.setVisibility(FuncOp::Visibility::Private);
}
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;

View File

@ -46,28 +46,183 @@ namespace mlir {
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
namespace TFL {
// Returns true when the given two types have the same shape or broadcastable
// shape within the given rank. If any given shapes are non-static, this method
// returns true.
bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs,
int max_bcast_rank) {
// Ignore shape checking on the non-static shapes for model compatibility.
auto lhs_shaped_type = lhs.dyn_cast<ShapedType>();
if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true;
auto rhs_shaped_type = rhs.dyn_cast<ShapedType>();
if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true;
// Returns true when the given operand arguments have the same shape or
// broadcastable shape within the given rank. If any given shapes are
// non-static and maximum rank is within the given rank, this method returns
// true.
bool VerifyOperandsHaveSameShapesOrBroadcastableShape(
Operation *op, ArrayRef<unsigned> indices, int max_bcast_rank) {
if (indices.empty()) return true;
if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape()))
return true;
// First, it checks there are any inputs that has unknown rank.
bool has_unknown_shape_input = false;
bool has_same_shape = true;
bool reach_first_known_shape = false;
int64_t max_rank = -1;
ArrayRef<int64_t> pivot_shape;
SmallVector<int64_t, 4> current_shape;
SmallVector<int64_t, 4> result_shape;
if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(),
rhs_shaped_type.getShape(),
result_shape)) {
return false;
for (unsigned index : indices) {
ShapedType shaped_type =
op->getOperand(index).getType().dyn_cast<ShapedType>();
if (!shaped_type || !shaped_type.hasRank()) {
// Marks that we have an unknown rank input.
has_unknown_shape_input = true;
continue;
}
max_rank = std::max(max_rank, shaped_type.getRank());
if (!shaped_type.hasStaticShape()) {
// Marks that we have an unknown shape input.
has_unknown_shape_input = true;
continue;
}
ArrayRef<int64_t> shape = shaped_type.getShape();
if (!reach_first_known_shape) {
pivot_shape = shape;
current_shape.assign(shape.begin(), shape.end());
reach_first_known_shape = true;
continue;
}
if (!pivot_shape.equals(shape)) {
has_same_shape = false;
}
// Checks if all the inputs are broadcastable since they have not all the
// same shapes.
if (!OpTrait::util::getBroadcastedShape(current_shape, shape,
result_shape)) {
return false;
}
current_shape = result_shape;
}
return lhs_shaped_type.getRank() <= max_bcast_rank &&
rhs_shaped_type.getRank() <= max_bcast_rank;
// It will treat the unknown shape inputs as acceptable inputs for model
// compatibility unless there is an known rank that is bigger than the allowed
// broadcast maximum rank.
if (has_unknown_shape_input) return max_rank <= max_bcast_rank;
// If all the shape is known and same, CPU kernels are able to handle inputs
// regardless of dimension size.
return has_same_shape || max_rank <= max_bcast_rank;
}
// Return true when the given element_type is QI8.
bool IsQI8Type(Type element_type) {
auto quantized_type = element_type.dyn_cast<QuantizedType>();
return quantized_type != nullptr &&
quantized_type.getStorageTypeIntegralWidth() == 8 &&
quantized_type.isSigned();
}
// Return true when the given element_type is QUI8.
bool IsQUI8Type(Type element_type) {
auto quantized_type = element_type.dyn_cast<QuantizedType>();
return quantized_type != nullptr &&
quantized_type.getStorageTypeIntegralWidth() == 8 &&
!quantized_type.isSigned();
}
// Return true when the given element_type is QI16.
bool IsQI16Type(Type element_type) {
auto quantized_type = element_type.dyn_cast<QuantizedType>();
return quantized_type != nullptr &&
quantized_type.getStorageTypeIntegralWidth() == 16 &&
quantized_type.isSigned();
}
// Return true when the given element_type is I32.
bool IsI32Type(Type element_type) {
return element_type.isInteger(32) && !element_type.isUnsignedInteger();
}
// Return true if the given Add operation has the CPU kernel supported shapes.
bool VerifyAddOpShapeConstraints(AddOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType());
// Allows F32, QI8, and QUI8 outputs when the operands have valid shapes,
// which are broadcastable shapes up to five dimension or have same shapes.
if (element_type.isF32() || IsQI8Type(element_type) ||
IsQUI8Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows I32 output when the operands have valid shapes, which are
// broadcastable shapes up to four dimension or have same shapes.
if (IsI32Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4);
}
// Allows QI16 output when operands have the same shape.
if (IsQI16Type(element_type)) {
return succeeded(
mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
}
return false;
}
// Return true if the given Sub operation has the CPU kernel supported shapes.
bool VerifySubOpShapeConstraints(SubOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType());
// Allows F32, QUI8, and QI16 outputs when the operands have valid shapes,
// which are broadcastable shapes up to five dimension or have same shapes.
if (element_type.isF32() || IsI32Type(element_type) ||
IsQUI8Type(element_type) || IsQI16Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows QI8 output when the operands have valid shapes, which are
// broadcastable shapes up to four dimension or have same shapes.
if (IsQI8Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4);
}
return false;
}
// Return true if the given Mul operation has the CPU kernel supported shapes.
bool VerifyMulOpShapeConstraints(MulOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType());
// Allows QI8 and QUI8 inputs up to five dimension broadcasting unless the
// output type is not QI16. If the output type is Q16, allows onlt the same
// shape operands.
if (IsQI8Type(element_type) || IsQUI8Type(element_type)) {
if (IsQI16Type(getElementTypeOrSelf(op.lhs().getType()))) {
return succeeded(
mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
}
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows F32 output when the operands have valid shapes, which are
// broadcastable shapes up to five dimension or have same shapes.
if (element_type.isF32()) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows I32 and QI16 outputs when the operands have valid shapes, which are
// broadcastable shapes up to four dimension or have same shapes.
if (IsI32Type(element_type) || IsQI16Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4);
}
return false;
}
//===----------------------------------------------------------------------===//
@ -1882,7 +2037,7 @@ static LogicalResult Verify(TransposeConvOp op) {
auto expected_output_type =
RankedTensorType::get(output_shape, output_type.getElementType());
if (output_type != expected_output_type) {
if (failed(mlir::verifyCompatibleShape(output_type, expected_output_type))) {
return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
expected_output_type, output_type));
}
@ -1966,9 +2121,9 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
}
static LogicalResult Verify(TransposeOp op) {
auto input_type = op.x().getType().cast<ShapedType>();
auto input_type = op.input().getType().cast<ShapedType>();
auto perm_type = op.perm().getType().cast<ShapedType>();
auto output_type = op.y().getType().cast<ShapedType>();
auto output_type = op.output().getType().cast<ShapedType>();
if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
if (perm_type.getNumElements() != input_type.getRank()) {
return op.emitOpError(
@ -2004,7 +2159,8 @@ static LogicalResult Verify(TransposeOp op) {
}
auto expected_output_type =
RankedTensorType::get(transposed_shape, input_type.getElementType());
if (output_type != expected_output_type) {
if (failed(
mlir::verifyCompatibleShape(output_type, expected_output_type))) {
return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
expected_output_type, output_type));
}

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"

File diff suppressed because it is too large Load Diff

View File

@ -27,7 +27,7 @@ cc_library(
"//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
@ -56,7 +56,7 @@ cc_library(
"//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
@ -85,7 +85,7 @@ cc_library(
"//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",

View File

@ -55,8 +55,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
std::vector<string> node_names;
std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes;
std::vector<double> node_mins;
std::vector<double> node_maxs;
std::vector<llvm::Optional<double>> node_mins;
std::vector<llvm::Optional<double>> node_maxs;
// Populate quantization specs.
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(

View File

@ -125,8 +125,8 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
std::vector<string> node_names;
std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes;
std::vector<double> node_mins;
std::vector<double> node_maxs;
std::vector<llvm::Optional<double>> node_mins;
std::vector<llvm::Optional<double>> node_maxs;
// Populate quantization specs.
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(

View File

@ -121,6 +121,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
return DT_STRING;
case toco::IODataType::BOOL:
return DT_BOOL;
case toco::IODataType::COMPLEX64:
return DT_COMPLEX64;
default:
return DT_INVALID;
}
@ -175,14 +177,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
return RegisterCustomBuiltinOps(extra_tf_opdefs);
}
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs,
std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<double>* node_mins,
std::vector<double>* node_maxs) {
Status PopulateQuantizationSpecs(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<llvm::Optional<double>>* node_mins,
std::vector<llvm::Optional<double>>* node_maxs) {
quant_specs->inference_input_type =
ConvertIODataTypeToDataType(toco_flags.inference_input_type());
tensorflow::DataType inference_type =
@ -209,11 +210,16 @@ Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
flag.shape().dims().end()));
// Currently, only UINT8 and INT8 require inputs stats
if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
TF_ASSIGN_OR_RETURN(
auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(),
inference_type));
node_mins->push_back(min_max.first);
node_maxs->push_back(min_max.second);
if (flag.has_mean_value() && flag.has_std_value()) {
TF_ASSIGN_OR_RETURN(
auto min_max, InputStatsToMinMax(flag.mean_value(),
flag.std_value(), inference_type));
node_mins->push_back(min_max.first);
node_maxs->push_back(min_max.second);
} else {
node_mins->push_back(llvm::None);
node_maxs->push_back(llvm::None);
}
}
}
@ -252,7 +258,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
std::string error_message;
auto output = mlir::openOutputFile(filename, &error_message);
if (!error_message.empty()) {
return errors::InvalidArgument("Failed to open file in %s.", filename);
return errors::InvalidArgument("Failed to open file in ", filename);
}
mlir::PassManager pm(module.getContext());
pm.addPass(mlir::createPrintOpGraphPass(output->os()));

View File

@ -34,14 +34,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags);
// Populate quantization specs (or not) given user specified ranges for each
// input arrays.
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs,
std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<double>* node_mins,
std::vector<double>* node_maxs);
Status PopulateQuantizationSpecs(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<llvm::Optional<double>>* node_mins,
std::vector<llvm::Optional<double>>* node_maxs);
// Convert imported MLIR file to TfLite flatbuffer.
// This will also run relevant passes as well.

View File

@ -3,6 +3,10 @@ load(
"//tensorflow/core/platform:build_config.bzl",
"tf_proto_library",
)
load(
"//third_party/mlir:tblgen.bzl",
"gentbl",
)
package(
default_visibility = [
@ -23,6 +27,7 @@ package_group(
exports_files([
"quantization_traits.h",
"quantization_config.h",
"quantization_utils.h",
])
filegroup(
@ -34,6 +39,25 @@ filegroup(
],
)
gentbl(
name = "quantization_interfaces_inc_gen",
tbl_outs = [
(
"-gen-op-interface-decls",
"quantization_interface.h.inc",
),
(
"-gen-op-interface-defs",
"quantization_interface.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "quantization.td",
td_srcs = [
":quantization_td_files",
],
)
tf_proto_library(
name = "quantization_info_proto",
srcs = [
@ -56,7 +80,7 @@ cc_library(
"//tensorflow/core:lib_proto_parsing",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -71,16 +95,18 @@ cc_library(
name = "quantization_lib",
srcs = [
"quantization_driver.cc",
"quantization_interface.cc.inc",
"quantization_utils.cc",
],
hdrs = [
"quantization_interface.h.inc",
"quantization_traits.h",
"quantization_utils.h",
],
deps = [
"//tensorflow/core:lib_proto_parsing",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
@ -99,7 +125,7 @@ cc_library(
deps = [
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)
@ -109,8 +135,8 @@ tf_native_cc_binary(
"tools/op_quant_spec_getters_gen.cc",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:tablegen",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//mlir:TableGen",
],
)
@ -131,7 +157,7 @@ cc_library(
deps = [
":numerical_utils",
"@com_google_absl//absl/types:optional",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:Support",
@ -146,7 +172,7 @@ cc_library(
":device_target",
":quantization_lib",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",

View File

@ -36,7 +36,7 @@ cc_library(
"//tensorflow/lite/core/api",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
@ -49,14 +49,16 @@ cc_library(
],
hdrs = [
"tfl_to_std.h",
"//tensorflow/compiler/mlir/lite/quantization:quantization_utils.h",
],
deps = [
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
@ -71,7 +73,7 @@ tf_cc_binary(
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
],
)

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
namespace mlir {
namespace TFL {
@ -47,12 +48,18 @@ void ConvertMlirQuantOpsToTFLQuantOps(FuncOp func) {
auto dcast = b.create<DequantizeOp>(dq.getLoc(), dq.getResult().getType(),
dq.arg());
dq.getResult().replaceAllUsesWith(dcast);
if (auto extra_attr = op->getAttr(mlir::quant::kVolatileOpAttrName)) {
dcast.setAttr(mlir::quant::kVolatileOpAttrName, extra_attr);
}
dq.erase();
} else if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(op)) {
auto out_type = q.getResult().getType();
auto qcast = b.create<QuantizeOp>(q.getLoc(), out_type, q.arg(),
TypeAttr::get(out_type));
q.getResult().replaceAllUsesWith(qcast);
if (auto extra_attr = op->getAttr(mlir::quant::kVolatileOpAttrName)) {
qcast.setAttr(mlir::quant::kVolatileOpAttrName, extra_attr);
}
q.erase();
}
});

View File

@ -63,6 +63,22 @@ def QI32 : QuantizedType<"Uniform", [32], 1>;
// https://www.tensorflow.org/lite/performance/quantization_spec
//===----------------------------------------------------------------------===//
// TODO(b/157870442): replace all FixedResultScale trait
def FixedOutputRangeInterface : OpInterface<
"FixedOutputRangeInterface"> {
let description = [{
Interface for defining the fixed output range.
}];
let methods = [
InterfaceMethod<
[{Returns the fixed output range.}],
"UniformQuantizedType", "GetFixedOutputRange",
(ins "bool":$sign, "int":$bit_width)
>,
];
}
// Specify this trait if the op has a fixed output value range.
class FixedResultScale<QuantizedType qt> : NativeOpTrait<!strconcat(
"quant::FixedResult", qt.name, "Scale<", qt.asTraitArgsStr, ">::Impl")>;

View File

@ -45,7 +45,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
absl::string_view inference_type,
QuantizationSpecs* quant_specs) {
std::vector<std::string> input_nodes = absl::StrSplit(node_names, ',');
std::vector<double> node_mins;
std::vector<llvm::Optional<double>> node_mins;
if (!min_values.empty()) {
std::vector<std::string> node_mins_str = absl::StrSplit(min_values, ',');
for (int i = 0; i < node_mins_str.size(); i++) {
@ -57,7 +57,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
}
}
std::vector<double> node_maxs;
std::vector<llvm::Optional<double>> node_maxs;
if (!max_values.empty()) {
std::vector<std::string> node_maxs_str = absl::StrSplit(max_values, ',');
for (int i = 0; i < node_maxs_str.size(); i++) {
@ -79,11 +79,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
quant_specs);
}
bool GetInputNodeQuantSpecs(const std::vector<std::string>& node_names,
const std::vector<double>& node_mins,
const std::vector<double>& node_maxs,
tensorflow::DataType inference_type,
QuantizationSpecs* quant_specs) {
bool GetInputNodeQuantSpecs(
const std::vector<std::string>& node_names,
const std::vector<llvm::Optional<double>>& node_mins,
const std::vector<llvm::Optional<double>>& node_maxs,
tensorflow::DataType inference_type, QuantizationSpecs* quant_specs) {
quant_specs->inference_type = inference_type;
// If min/max are not specified, just return;

View File

@ -69,7 +69,8 @@ struct QuantizationSpecs {
// arguments. They are only used when `weight_quantization` is set to false,
// and the model is required to have quantization parameters, either from
// quantization aware training or calibration, for the remaining tensors.
std::vector<std::pair<double, double>> input_ranges;
std::vector<std::pair<llvm::Optional<double>, llvm::Optional<double>>>
input_ranges;
// The default ranges can be used when a tensor doesn't have quantization
// parameters and couldn't be quantized. Used only for latency tests.
@ -130,11 +131,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
// Gets the quantization specification for input arrays. The array names are not
// stored in the spec, and will be matched by position. The min/max will be
// ignored if the inference_type isn't a quantized type. Returns true if failed.
bool GetInputNodeQuantSpecs(const std::vector<std::string>& node_names,
const std::vector<double>& node_mins,
const std::vector<double>& node_maxs,
tensorflow::DataType inference_type,
QuantizationSpecs* quant_specs);
bool GetInputNodeQuantSpecs(
const std::vector<std::string>& node_names,
const std::vector<llvm::Optional<double>>& node_mins,
const std::vector<llvm::Optional<double>>& node_maxs,
tensorflow::DataType inference_type, QuantizationSpecs* quant_specs);
} // namespace TFL
} // namespace mlir

View File

@ -494,6 +494,13 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
auto quantize = builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
auto dequantize = builder_.create<quant::DequantizeCastOp>(
loc, expressed_type, quantize.getResult());
// This attribute is set to distinguish the quantize ops being added by the
// quantization pass. These ops can be removed without losing original
// program accuracy.
// TODO(fengliuai): make the attribute being part of op definition.
quantize.setAttr(kVolatileOpAttrName, builder_.getUnitAttr());
// `original_result` has a use to `quantize`, so this will replace that use
// by the result of `dequantize`. Remember to reset that use afterwards
value.replaceAllUsesWith(dequantize);

View File

@ -21,13 +21,18 @@ limitations under the License.
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
namespace mlir {
namespace OpTrait {
namespace quant {
using QuantizedType = mlir::quant::QuantizedType;
using UniformQuantizedType = mlir::quant::UniformQuantizedType;
namespace mlir {
// This includes the interface class definition. It couldn't be in a namespace
// because the table gen doesn't emit the namespace when it is used.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_interface.h.inc"
namespace OpTrait {
namespace quant {
// The base class that all the quantization related OpTrait implements.
template <typename ConcreteType, template <typename> class TraitType>
struct QuantizationSpecTraitBase : public TraitBase<ConcreteType, TraitType> {

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
@ -436,6 +437,16 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
llvm::SmallVector<quant::StatisticsOp, 16> all_stats_ops;
llvm::DenseSet<Operation*> redundant_stats_ops;
// Step 0: remove the quant::StatisticsOp which are used by the tfl.quantize
// op in case it overrides the information from training FakeQuant ops.
func.walk([&](quant::QuantizeCastOp q) {
auto input_op = q.arg().getDefiningOp();
if (auto stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(input_op)) {
q.setOperand(stats.arg());
if (stats.use_empty()) stats.erase();
}
});
// Step 1: forward pass: propagate any value scales which are not produces
// by `SameOperandsAndResultsScale`. Additionally, remove the value scales
// which are produced by the `restricted_output_params`.

View File

@ -22,6 +22,8 @@ limitations under the License.
#include <unordered_map>
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
@ -35,11 +37,17 @@ limitations under the License.
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
namespace mlir {
namespace quant {
// A unit attribute can be attached to the quantize/dequantize ops which are
// added by the quantization passes. These ops can be removed erased without
// losing accuracy.
constexpr char kVolatileOpAttrName[] = "volatile";
using QuantParams = quant::QuantizedType;
using SignedInteger = std::pair<unsigned, unsigned>; // bitwidth and sign
using QuantParamsForResults = llvm::SmallVector<QuantParams, 4>;
@ -363,6 +371,55 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
}
};
// Fold Extra Requantize ops if the preceding ops has free scale requirement.
template <typename RQ>
struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
explicit FoldTrivalRequantizeOp(MLIRContext* context)
: OpRewritePattern<RQ>(context, 1) {}
LogicalResult matchAndRewrite(RQ op,
PatternRewriter& rewriter) const override {
Value pre_quantized = op.input();
auto pre_quantized_type =
quant::QuantizedType::getQuantizedElementType(pre_quantized.getType());
if (!pre_quantized_type) return failure();
Operation* def = pre_quantized.getDefiningOp();
if (!def) return failure();
if (llvm::isa<FixedOutputRangeInterface>(def) ||
def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() ||
def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
return failure();
}
op.emitWarning("Remove trivial `rescale` op. Please fix the source graph.");
llvm::SmallVector<Type, 4> new_output_types;
for (auto result : def->getResults()) {
result.getUsers().begin()->dump();
op.dump();
if (result.hasOneUse() && *result.getUsers().begin() == op) {
new_output_types.push_back(op.qtype());
} else {
new_output_types.push_back(result.getType());
}
}
// Remove this rescale op.
rewriter.replaceOp(op, {pre_quantized});
// Replace the output scale of the preceding op.
rewriter.setInsertionPointAfter(def);
OperationState new_state(def->getLoc(), def->getName().getStringRef(),
def->getOperands(), new_output_types,
def->getAttrs());
Operation* new_op = rewriter.createOperation(new_state);
rewriter.replaceOp(def, new_op->getResults());
return success();
}
};
// Given a quantized type `input`, magnifying its scales by the factor stored in
// `factor`. If `input` isn't a quantized type or the `factor` doesn't match the
// dimension size of `input` or isn't floating-point, nullptr will be returned.

View File

@ -27,7 +27,7 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -quant-import-stats --quant-test-stats='entries { name: "op" params { min_max { min: -1 max: 1 } } } entries { name: "op_0:0" params { min_max { min: -2 max: 2 } } } entries { name_regex: "op_*" params { min_max { min: -3 max: 3 } } }' | FileCheck %s --dump-input-on-failure
// RUN: tf-opt %s -quant-import-stats --quant-test-stats='entries { name: "op" params { min_max { min: -1 max: 1 } } } entries { name: "op_0:0" params { min_max { min: -2 max: 2 } } } entries { name_regex: "op_*" params { min_max { min: -3 max: 3 } } }' | FileCheck %s
// CHECK-LABEL: import_stats_skip

View File

@ -32,7 +32,7 @@ cc_library(
"//tensorflow/lite/core/api",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],

View File

@ -1,4 +1,4 @@
// RUN: tf-opt -pass-pipeline='func(canonicalize)' %s | FileCheck %s --dump-input-on-failure
// RUN: tf-opt -pass-pipeline='func(canonicalize)' %s | FileCheck %s
// Checks that tfl.reshape should be removed if its output's only user is
// another tfl.reshape
@ -11,9 +11,9 @@ func @reshape_removeAdjacent(tensor<4x4x4xf32>) -> tensor<64xf32> {
return %1 : tensor<64xf32>
// CHECK-LABEL: func @reshape_removeAdjacent
// CHECK: %cst = constant dense<64> : tensor<1xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: return
// CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32>
// CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: return %[[RESHAPE]]
}
// Checks that tfl.reshape should be removed if its output has more than one
@ -29,11 +29,11 @@ func @reshape_removeAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> tensor<64xf32>
return %3 : tensor<64xf32>
// CHECK-LABEL: func @reshape_removeAdjacentWithMultipleUse
// CHECK: %cst = constant dense<64> : tensor<1xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: %1 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: %2 = addf %0, %1
// CHECK: return %2
// CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32>
// CHECK: %[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: %[[RESULT:.*]] = addf %[[RESHAPE_1]], %[[RESHAPE_2]]
// CHECK: return %[[RESULT]]
}
// Checks that tfl.reshape should be kept if its output has more than one
@ -47,11 +47,11 @@ func @reshape_keepAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> (tensor<16x4xf32
return %0, %1 : tensor<16x4xf32>, tensor<64xf32>
// CHECK-LABEL: func @reshape_keepAdjacentWithMultipleUse
// CHECK: %cst = constant dense<[16, 4]> : tensor<2xi32>
// CHECK: %cst_0 = constant dense<64> : tensor<1xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32>
// CHECK: %1 = "tfl.reshape"(%arg0, %cst_0) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: return %0, %1
// CHECK: %[[CST:.*]] = constant dense<[16, 4]> : tensor<2xi32>
// CHECK: %[[CST_0:.*]] = constant dense<64> : tensor<1xi32>
// CHECK: %[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32>
// CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST_0]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: return %[[RESHAPE_1]], %[[RESHAPE_2]]
}
// Checks that tfl.reshape should be removed if its output type is the same

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -canonicalize | FileCheck %s --dump-input-on-failure
// RUN: tf-opt %s -canonicalize | FileCheck %s
// CHECK-LABEL: @add_float
func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) {
@ -8,13 +8,13 @@ func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>,
%2 = constant dense< 3.5> : tensor<4xf32>
%3 = constant dense<-0.5> : tensor<4xf32>
// CHECK: %cst = constant dense<3.500000e+00> : tensor<4xf32>
// CHECK: %cst_0 = constant dense<-5.000000e-01> : tensor<4xf32>
// CHECK: %cst_1 = constant dense<6.000000e+00> : tensor<f32>
// CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32>
// CHECK: %cst_3 = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %cst_4 = constant dense<3.000000e+00> : tensor<4xf32>
// CHECK: %0 = tfl.add %cst, %cst_0 {fused_activation_function = "SIGN_BIT"} : tensor<4xf32>
// CHECK: %[[CST:.*]] = constant dense<3.500000e+00> : tensor<4xf32>
// CHECK: %[[CST_0:.*]] = constant dense<-5.000000e-01> : tensor<4xf32>
// CHECK: %[[CST_1:.*]] = constant dense<6.000000e+00> : tensor<f32>
// CHECK: %[[CST_2:.*]] = constant dense<4.000000e+00> : tensor<4xf32>
// CHECK: %[[CST_3:.*]] = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %[[CST_4:.*]] = constant dense<3.000000e+00> : tensor<4xf32>
// CHECK: %0 = tfl.add %[[CST]], %[[CST_0]] {fused_activation_function = "SIGN_BIT"} : tensor<4xf32>
%5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
%6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
@ -33,10 +33,10 @@ func @add_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
%2 = constant dense< 4> : tensor<4xi32>
%3 = constant dense<-2> : tensor<4xi32>
// CHECK: %cst = constant dense<9> : tensor<i32>
// CHECK: %cst_0 = constant dense<6> : tensor<4xi32>
// CHECK: %cst_1 = constant dense<5> : tensor<4xi32>
// CHECK: %cst_2 = constant dense<2> : tensor<4xi32>
// CHECK: %[[CST:.*]] = constant dense<9> : tensor<i32>
// CHECK: %[[CST_0:.*]] = constant dense<6> : tensor<4xi32>
// CHECK: %[[CST_1:.*]] = constant dense<5> : tensor<4xi32>
// CHECK: %[[CST_2:.*]] = constant dense<2> : tensor<4xi32>
%5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
%6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
@ -54,10 +54,10 @@ func @sub_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
%2 = constant dense< 3.5> : tensor<4xf32>
%3 = constant dense<-0.5> : tensor<4xf32>
// CHECK: %cst = constant dense<3.000000e+00> : tensor<f32>
// CHECK: %cst_0 = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %cst_1 = constant dense<2.000000e+00> : tensor<4xf32>
// CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32>
// CHECK: %[[CST:.*]] = constant dense<3.000000e+00> : tensor<f32>
// CHECK: %[[CST_0:.*]] = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %[[CST_1:.*]] = constant dense<2.000000e+00> : tensor<4xf32>
// CHECK: %[[CST_2:.*]] = constant dense<4.000000e+00> : tensor<4xf32>
%5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
%6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
@ -75,10 +75,10 @@ func @sub_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
%2 = constant dense< 4> : tensor<4xi32>
%3 = constant dense<-2> : tensor<4xi32>
// CHECK: %cst = constant dense<7> : tensor<i32>
// CHECK: %cst_0 = constant dense<10> : tensor<4xi32>
// CHECK: %cst_1 = constant dense<3> : tensor<4xi32>
// CHECK: %cst_2 = constant dense<6> : tensor<4xi32>
// CHECK: %[[CST:.*]] = constant dense<7> : tensor<i32>
// CHECK: %[[CST_0:.*]] = constant dense<10> : tensor<4xi32>
// CHECK: %[[CST_1:.*]] = constant dense<3> : tensor<4xi32>
// CHECK: %[[CST_2:.*]] = constant dense<6> : tensor<4xi32>
%5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
%6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
@ -96,10 +96,10 @@ func @mul_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
%2 = constant dense< 3.5> : tensor<4xf32>
%3 = constant dense<-0.5> : tensor<4xf32>
// CHECK: %cst = constant dense<6.750000e+00> : tensor<f32>
// CHECK: %cst_0 = constant dense<-2.250000e+00> : tensor<4xf32>
// CHECK: %cst_1 = constant dense<5.250000e+00> : tensor<4xf32>
// CHECK: %cst_2 = constant dense<-1.750000e+00> : tensor<4xf32>
// CHECK: %[[CST:.*]] = constant dense<6.750000e+00> : tensor<f32>
// CHECK: %[[CST_0:.*]] = constant dense<-2.250000e+00> : tensor<4xf32>
// CHECK: %[[CST_1:.*]] = constant dense<5.250000e+00> : tensor<4xf32>
// CHECK: %[[CST_2:.*]] = constant dense<-1.750000e+00> : tensor<4xf32>
%5 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
%6 = "tfl.mul"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
@ -170,8 +170,8 @@ func @add_dense_splat_int() -> tensor<4xi32> {
return %2 : tensor<4xi32>
// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_splat_dense_int
@ -183,8 +183,8 @@ func @add_splat_dense_int() -> tensor<4xi32> {
return %2 : tensor<4xi32>
// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_int_same_shape
@ -196,8 +196,8 @@ func @add_dense_dense_int_same_shape() -> tensor<4xi32> {
return %2 : tensor<4xi32>
// CHECK: %cst = constant dense<[5, 22, -2, 98]> : tensor<4xi32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[5, 22, -2, 98]> : tensor<4xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_int_trailing_dim
@ -212,10 +212,10 @@ func @add_dense_dense_int_trailing_dim() -> (tensor<2x2xi32>, tensor<2x2x2xi32>,
return %0, %1, %2 : tensor<2x2xi32>, tensor<2x2x2xi32>, tensor<2x2x2xi32>
// CHECK: %cst = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32>
// CHECK: %cst_0 = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32>
// CHECK: %cst_1 = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32>
// CHECK: return %cst, %cst_0, %cst_1
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32>
// CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32>
// CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32>
// CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]]
}
// CHECK-LABEL: @add_dense_dense_int_mixing_1_n
@ -226,8 +226,8 @@ func @add_dense_dense_int_mixing_1_n() -> tensor<2x2xi32> {
%0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
// CHECK: %cst = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_splat_float
@ -239,8 +239,8 @@ func @add_dense_splat_float() -> tensor<4xf32> {
return %2 : tensor<4xf32>
// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_splat_dense_float
@ -252,8 +252,8 @@ func @add_splat_dense_float() -> tensor<4xf32> {
return %2 : tensor<4xf32>
// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_float_same_shape
@ -265,8 +265,8 @@ func @add_dense_dense_float_same_shape() -> (tensor<4xf32>) {
return %2 : tensor<4xf32>
// CHECK: %cst = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_float_trailing_dim
@ -281,10 +281,10 @@ func @add_dense_dense_float_trailing_dim() -> (tensor<2x2xf32>, tensor<2x2x2xf32
return %0, %1, %2 : tensor<2x2xf32>, tensor<2x2x2xf32>, tensor<2x2x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32>
// CHECK: %cst_0 = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32>
// CHECK: %cst_1 = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32>
// CHECK: return %cst, %cst_0, %cst_1
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32>
// CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32>
// CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32>
// CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]]
}
// CHECK-LABEL: @add_dense_dense_float_mixfng_1_n
@ -296,24 +296,24 @@ func @add_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
return %0 : tensor<2x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @rank
func @rank() -> tensor<1xi32> {
%cst = constant dense<[[1], [2]]> : tensor<2x1xi32>
// CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32>
// CHECK: return %[[CST]]
%0 = "tfl.rank"(%cst) : (tensor<2x1xi32>) -> tensor<1xi32>
return %0 : tensor<1xi32>
}
// CHECK-LABEL: @rank_input_known_rank
func @rank_input_known_rank(%arg0 : tensor<2x1xi32>) -> tensor<1xi32> {
// CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32>
// CHECK: return %[[CST]]
%0 = "tfl.rank"(%arg0) : (tensor<2x1xi32>) -> tensor<1xi32>
return %0 : tensor<1xi32>
}
@ -323,8 +323,8 @@ func @reshape() -> tensor<4xi32> {
%input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
%shape = constant dense<[4]> : tensor<1xi32>
// CHECK: [[cst:%.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK: return %[[CST]]
%0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
@ -334,8 +334,8 @@ func @reshape_dynamic_output() -> tensor<?xi32> {
%input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
%shape = constant dense<[4]> : tensor<1xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
%0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
@ -343,8 +343,8 @@ func @reshape_dynamic_output() -> tensor<?xi32> {
// CHECK-LABEL: @pseudo_const
func @pseudo_const() -> tensor<i32> {
// CHECK: [[cst:%.*]] = constant dense<1> : tensor<i32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<1> : tensor<i32>
// CHECK: return %[[CST]]
%0 = "tfl.pseudo_const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
return %0 : tensor<i32>
}
@ -356,8 +356,8 @@ func @range_int() -> tensor<?xi32> {
%cst_1 = constant dense<4> : tensor<i32>
%cst_2 = constant dense<1> : tensor<i32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
@ -368,8 +368,8 @@ func @range_float() -> tensor<?xf32> {
%cst_1 = constant dense<4.0> : tensor<f32>
%cst_2 = constant dense<1.0> : tensor<f32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -381,8 +381,8 @@ func @range_float_neg_delta() -> tensor<?xf32> {
%cst_1 = constant dense<-4.0> : tensor<f32>
%cst_2 = constant dense<-1.0> : tensor<f32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -393,8 +393,8 @@ func @range_float_nonzero_base() -> tensor<?xf32> {
%cst_1 = constant dense<7.0> : tensor<f32>
%cst_2 = constant dense<1.5> : tensor<f32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -414,8 +414,8 @@ func @transpose_1d() -> tensor<3xi32> {
%cst = constant dense<[1, 2, 3]> : tensor<3xi32>
%cst_perm = constant dense<0> : tensor<1xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
@ -425,8 +425,8 @@ func @transpose_dynamic() -> tensor<?xi32> {
%cst = constant dense<[1, 2, 3]> : tensor<3xi32>
%cst_perm = constant dense<0> : tensor<1xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
@ -436,8 +436,8 @@ func @transpose_2d() -> tensor<2x2xi32> {
%cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
%cst_perm = constant dense<[1, 0]> : tensor<2xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -447,8 +447,8 @@ func @transpose_2d_identity() -> tensor<2x2xi32> {
%cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
%cst_perm = constant dense<[0, 1]> : tensor<2xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -460,8 +460,8 @@ func @transpose_3d() -> tensor<4x2x3xi32> {
%cst = constant dense<[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]> : tensor<2x3x4xi32>
%cst_perm = constant dense<[2, 0, 1]> : tensor<3xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x3x4xi32>, tensor<3xi32>) -> tensor<4x2x3xi32>
return %0 : tensor<4x2x3xi32>
}
@ -473,8 +473,8 @@ func @ConstantFoldBinaryOpDynamicOutput() -> tensor<?xi32> {
%87 = "tfl.sub"(%cst_0, %cst) {fused_activation_function = "NONE"} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
return %87 : tensor<?xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_int_same_shape_dynamic
@ -486,8 +486,8 @@ func @add_dense_dense_int_same_shape_dynamic() -> tensor<?xi32> {
return %2 : tensor<?xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @concat_2_tensors_1_empty
@ -497,8 +497,8 @@ func @concat_2_tensors_1_empty() -> tensor<2xi32> {
%3 = "tfl.concatenation"(%1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<0xi32>) -> tensor<2xi32>
return %3 : tensor<2xi32>
// CHECK: [[cst:%.*]] = constant dense<1> : tensor<2xi32>
// CHECK: return [[cst]] : tensor<2xi32>
// CHECK: %[[CST:.*]] = constant dense<1> : tensor<2xi32>
// CHECK: return %[[CST]] : tensor<2xi32>
}
// CHECK-LABEL: @concat_3_tensors_1_empty
@ -509,7 +509,7 @@ func @concat_3_tensors_1_empty() -> tensor<?xi32> {
%3 = "tfl.concatenation"(%0, %1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>, tensor<0xi32>) -> tensor<?xi32>
return %3 : tensor<?xi32>
// CHECK: %0 = "tfl.concatenation"(%cst, %cst) {axis = 0 : i32, fused_activation_function = "NONE"}
// CHECK: %0 = "tfl.concatenation"(%[[CST]], %[[CST]]) {axis = 0 : i32, fused_activation_function = "NONE"}
// CHECK: return %0 : tensor<?xi32>
}
@ -520,10 +520,10 @@ func @concatConstantTensorsFirstDim() -> tensor<2x2x3xi32> {
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<2x2x3xi32>
return %0 : tensor<2x2x3xi32>
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32>
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32>
// CHECK-NOT: constant-dense
// CHECK-NOT: "tfl.concatenation"
// CHECK: return [[cst]]
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @concatConstantTensorsMiddleDim
@ -533,10 +533,10 @@ func @concatConstantTensorsMiddleDim() -> tensor<1x4x3xi32> {
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x4x3xi32>
return %0 : tensor<1x4x3xi32>
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32>
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32>
// CHECK-NOT: constant-dense
// CHECK-NOT: "tfl.concatenation"
// CHECK: return [[cst]]
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @concatConstantTensorsLastDim
@ -546,10 +546,10 @@ func @concatConstantTensorsLastDim() -> tensor<1x2x6xi32> {
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 2 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x2x6xi32>
return %0 : tensor<1x2x6xi32>
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32>
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32>
// CHECK-NOT: constant-dense
// CHECK-NOT: "tfl.concatenation"
// CHECK: return [[cst]]
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @div_dense_dense_float_mixfng_1_n
@ -561,8 +561,8 @@ func @div_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
return %0 : tensor<2x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @div_dense_different_rank
@ -574,6 +574,6 @@ func @div_dense_different_rank() -> tensor<1x2x2xf32> {
return %0 : tensor<1x2x2xf32>
// CHECK: %cst = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
// CHECK: return %[[CST]]
}

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -tfl-identify-dilated-conv | FileCheck %s --dump-input-on-failure
// RUN: tf-opt %s -tfl-identify-dilated-conv | FileCheck %s
func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>

View File

@ -1,4 +1,4 @@
# RUN: tf_tfl_translate -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=Add %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
# RUN: tf_tfl_translate -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=Add %s -o - | flatbuffer_to_string - | FileCheck %s
# Add two tensor<4xi32> inputs and return the result

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,4 @@
# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,1,1,256 -tf-input-data-types=DT_FLOAT -tf-inference-type=DT_QINT8 -tf-input-min-values='-33.614346' -tf-input-max-values='21.54917' -tf-output-arrays=output %s -o - --output-mlir 2>&1 | FileCheck --check-prefix=MLIR %s --dump-input-on-failure
# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,1,1,256 -tf-input-data-types=DT_FLOAT -tf-inference-type=DT_QINT8 -tf-input-min-values='-33.614346' -tf-input-max-values='21.54917' -tf-output-arrays=output %s -o - --output-mlir 2>&1 | FileCheck --check-prefix=MLIR %s
# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,1,1,256 -tf-input-data-types=DT_FLOAT -tf-inference-type=DT_QINT8 -tf-input-min-values='-33.614346' -tf-input-max-values='21.54917' -tf-output-arrays=output %s -o - | flatbuffer_to_string - | FileCheck %s
node {

Some files were not shown because too many files have changed in this diff Show More