Merge branch 'master' into IFX/PR_tfl_converter_QAT_fixes

# Conflicts:
#	tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
This commit is contained in:
stevensa 2020-05-27 15:26:47 +02:00
commit a0b233ac41
2684 changed files with 171436 additions and 66455 deletions

@ -143,6 +143,11 @@ build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl --define=build_with_mkl_dnn_v1_only=true
build:mkl -c opt
# config to build OneDNN backend with a user specified threadpool.
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_threadpool --define=build_with_mkldnn_threadpool=true
build:mkl_threadpool -c opt
# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true
@ -163,6 +168,8 @@ build:cuda_clang --action_env TF_CUDA_CLANG=1
build:dbg --config=opt -c dbg
# for now, disable arm_neon. see: https://github.com/tensorflow/tensorflow/issues/33360
build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
build:dbg --copt -DDEBUG_BUILD
build:tensorrt --action_env TF_NEED_TENSORRT=1
@ -233,10 +240,15 @@ build:c++17 --cxxopt=-std=c++1z
build:c++17 --cxxopt=-stdlib=libc++
build:c++1z --config=c++17
# Enable using platform specific build settings
# Enable using platform specific build settings, except when cross-compiling for
# mobile platforms.
build --enable_platform_specific_config
build:android --noenable_platform_specific_config
build:ios --noenable_platform_specific_config
# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
build:android --copt=-w
build:ios --copt=-w
build:linux --copt=-w
build:macos --copt=-w
build:windows --copt=/w
@ -256,6 +268,10 @@ build:macos --define=INCLUDEDIR=$(PREFIX)/include
# TF_SYSTEM_LIBS do not work on windows.
# By default, build TF in C++ 14 mode.
build:android --cxxopt=-std=c++14
build:android --host_cxxopt=-std=c++14
build:ios --cxxopt=-std=c++14
build:ios --host_cxxopt=-std=c++14
build:linux --cxxopt=-std=c++14
build:linux --host_cxxopt=-std=c++14
build:macos --cxxopt=-std=c++14
@ -356,9 +372,10 @@ build:rbe_linux --linkopt=-lm
build:rbe_cpu_linux --config=rbe_linux
build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
build:rbe_cpu_linux --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:cc-toolchain-k8"
build:rbe_cpu_linux --extra_execution_platforms"=@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
build:rbe_cpu_linux --host_platform="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
build:rbe_cpu_linux --platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
build:rbe_linux_cuda_base --config=rbe_linux
build:rbe_linux_cuda_base --repo_env=TF_NEED_TENSORRT=1
@ -396,17 +413,21 @@ build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_P
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
build:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang --crosstool_top="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_clang --extra_toolchains="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_clang --extra_execution_platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang --host_platform="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang --platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_clang --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_clang --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_clang --define=using_cuda_clang=true
test:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_clang_base --extra_toolchains="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_clang_base --extra_execution_platforms="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang_base --host_platform="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang_base --platforms="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_clang_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_clang_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_clang_base --define=using_cuda_clang=true
build:rbe_linux_cuda_clang_py27 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
build:rbe_linux_cuda_clang_py35 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
build:rbe_linux_cuda_clang_py36 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
build:rbe_linux_cuda_clang_py37 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda_clang_py38 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc

.github/bot_config.yml (new file)
@ -0,0 +1,87 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
#
# THIS IS A GENERATED DOCKERFILE.
#
# This file was assembled from multiple pieces, whose use is documented
# throughout. Please refer to the TensorFlow dockerfiles documentation
# for more information.
# A list of assignees
assignees:
- amahendrakar
- ravikyram
- Saduf2019
# A list of assignees for compiler folder
compiler_assignees:
- joker-eph
# Cuda Comment
cuda_comment: >
From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries:
* For TF-GPU - See point 1
* For TF-CPU - See point 2
-----------------------------------------------------------------------------------------------
**1. Installing **TensorFlow-GPU** (TF) prebuilt binaries**
Make sure you are using compatible TF and CUDA versions.
Please refer following TF version and CUDA version compatibility table.
| TF | CUDA |
| :-------------: | :-------------: |
| 2.1.0 - 2.2.0 | 10.1 |
| 1.13.1 - 2.0 | 10.0 |
| 1.5.0 - 1.12.0 | 9.0 |
* If you have above configuration and using _**Windows**_ platform -
* Try adding the CUDA, CUPTI, and cuDNN installation directories to the %PATH% environment variable.
* Refer [windows setup guide](https://www.tensorflow.org/install/gpu#windows_setup).
* If you have above configuration and using _**Ubuntu/Linux**_ platform -
* Try adding the CUDA, CUPTI, and cuDNN installation directories to the $LD_LIBRARY_PATH environment variable.
* Refer [linux setup guide](https://www.tensorflow.org/install/gpu#linux_setup).
* If error still persists then, apparently your CPU model does not support AVX instruction sets.
* Refer [hardware requirements](https://www.tensorflow.org/install/pip#hardware-requirements).
-----------------------------------------------------------------------------------------------
**2. Installing **TensorFlow** (TF) CPU prebuilt binaries**
*TensorFlow release binaries version 1.6 and higher are prebuilt with AVX instruction sets.*
Therefore on any CPU that does not have these instruction sets, either CPU or GPU version of TF will fail to load.
Apparently, your CPU model does not support AVX instruction sets. You can still use TensorFlow with the alternatives given below:
* Try Google Colab to use TensorFlow.
* The easiest way to use TF will be to switch to [google colab](https://colab.sandbox.google.com/notebooks/welcome.ipynb#recent=true). You get pre-installed latest stable TF version. Also you can use ```pip install``` to install any other preferred TF version.
 * It has an added advantage since you can easily switch to different hardware accelerators (CPU, GPU, TPU) as per the task.
* All you need is a good internet connection and you are all set.
* Try to build TF from sources by changing CPU optimization flags.
*Please let us know if this helps.*
windows_comment: >
From the stack trace it looks like you are hitting windows path length limit.
* Try to disable path length limit on Windows 10.
* Refer [disable path length limit instructions guide.](https://mspoweruser.com/ntfs-260-character-windows-10/)
Please let us know if this helps.

@ -103,17 +103,17 @@ open-source software development:
### Official Builds
Build Type | Status | Artifacts
------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv6l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv7l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
Build Type | Status | Artifacts
------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
### Community Supported Builds
@ -142,6 +142,7 @@ Build Type | Status
* [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2)
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
* [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp)
* [TensorFlow Blog](https://blog.tensorflow.org)
* [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
* [TensorFlow Twitter](https://twitter.com/tensorflow)

@ -1,3 +1,185 @@
# Release 2.3.0
## Breaking Changes
* `tf.image.extract_glimpse` has been updated to correctly process the case
where `centered=False` and `normalized=False`. This is a breaking change as
the output is different from (incorrect) previous versions. Note this
breaking change only impacts `tf.image.extract_glimpse` and
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
existing C++ kernel `ExtractGlimpse` does not change either, so saved
models will not be impacted.
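A minimal, hedged illustration of the affected call (the input tensor below is synthetic and purely illustrative; only the `centered=False, normalized=False` combination changes behavior in 2.3.0):

```python
import tensorflow as tf

# Hypothetical 1x7x7x1 batch; values are for illustration only.
images = tf.reshape(tf.range(49, dtype=tf.float32), [1, 7, 7, 1])

# With centered=False and normalized=False, offsets are interpreted as
# pixel coordinates rather than normalized, centre-relative ones; this is
# the case whose (previously incorrect) output changes in 2.3.0.
glimpse = tf.image.extract_glimpse(
    images, size=[3, 3], offsets=[[3.0, 3.0]],
    centered=False, normalized=False)
print(glimpse.shape)  # (1, 3, 3, 1)
```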
# Release 2.1.1
## Bug Fixes and Other Changes
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
* Fixes a versioning bug which causes Keras layers from TF 1.x to be used instead of those from TF 2.x
# Release 2.0.2
## Bug Fixes and Other Changes
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
# Release 1.15.3
## Bug Fixes and Other Changes
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
# Release 2.2.0
TensorFlow 2.2 discontinues support for Python 2, [previously announced](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ) as following [Python 2's EOL on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update).
Coinciding with this change, new releases of [TensorFlow's Docker images](https://hub.docker.com/r/tensorflow/tensorflow/) provide Python 3 exclusively. Because all images now use Python 3, Docker tags containing `-py3` will no longer be provided and existing `-py3` tags like `latest-py3` will not be updated.
## Major Features and Improvements
* Replaced the scalar type for string tensors from `std::string` to `tensorflow::tstring` which is now ABI stable.
* A new Profiler for TF 2 for CPU/GPU/TPU. It offers both device and host performance analysis, including input pipeline and TF Ops. Optimization advisory is provided whenever possible. Please see [this tutorial](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras) and [guide](https://www.tensorflow.org/guide/profiler) for usage guidelines.
* Export C++ functions to Python using `pybind11` as opposed to `SWIG` as a part of our [deprecation of swig efforts](https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md).
* `tf.distribute`:
* Support added for global sync `BatchNormalization` by using the newly added `tf.keras.layers.experimental.SyncBatchNormalization` layer. This layer will sync `BatchNormalization` statistics every step across all replicas taking part in sync training.
* Performance improvements for GPU multi-worker distributed training using `tf.distribute.experimental.MultiWorkerMirroredStrategy`
* Update NVIDIA `NCCL` to `2.5.7-1` for better performance and performance tuning. Please see [nccl developer guide](https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/env.html) for more information on this.
* Support gradient `allreduce` in `float16`. See this [example](https://github.com/tensorflow/models/blob/master/official/staging/training/grad_utils.py) usage.
* Experimental support of [all reduce gradient packing](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/CollectiveHints) to allow overlapping gradient aggregation with backward path computation.
* Deprecated `experimental_run_v2` method for distribution strategies and renamed the method `run` as it is no longer experimental.
* Add CompositeTensor support for DistributedIterators. This should help prevent unnecessary function retracing and memory leaks.
* `tf.keras`:
* `Model.fit` major improvements:
* You can now use custom training logic with `Model.fit` by overriding `Model.train_step` (see the sketches after this list).
* Easily write state-of-the-art training loops without worrying about all of the features `Model.fit` handles for you (distribution strategies, callbacks, data formats, looping logic, etc)
* See the default [`Model.train_step`](https://github.com/tensorflow/tensorflow/blob/1381fc8e15e22402417b98e3881dfd409998daea/tensorflow/python/keras/engine/training.py#L540) for an example of what this function should look like. Same applies for validation and inference via `Model.test_step` and `Model.predict_step`.
* SavedModel uses its own `Model._saved_model_inputs_spec` attr now instead of
relying on `Model.inputs` and `Model.input_names`, which are no longer set for subclass Models.
This attr is set in eager, `tf.function`, and graph modes. This gets rid of the need for users to
manually call `Model._set_inputs` when using Custom Training Loops(CTLs).
* Dynamic shapes are supported for generators by calling the Model on the first batch we "peek" from the generator.
This used to happen implicitly in `Model._standardize_user_data`. Long-term, a solution where the
`DataAdapter` doesn't need to call the Model is probably preferable.
* The SavedModel format now supports all Keras built-in layers (including metrics, preprocessing layers, and stateful RNN layers)
* Update Keras batch normalization layer to use the running mean and average computation in the `fused_batch_norm`. You should see significant performance improvements when using `fused_batch_norm` in Eager mode.
* `tf.lite`:
* Enable TFLite experimental new converter by default.
* XLA
* XLA now builds and works on windows. All prebuilt packages come with XLA available.
* XLA can be [enabled for a `tf.function`](https://www.tensorflow.org/xla#explicit_compilation_with_tffunction) with “compile or throw exception” semantics on CPU and GPU (see the sketches after this list).
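As a sketch of the `Model.train_step` override mentioned in the `Model.fit` notes above (the model and data here are invented for illustration; the hook itself follows the pattern documented for Keras):

```python
import tensorflow as tf

class SmallModel(tf.keras.Model):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, x):
        return self.dense(x)

    def train_step(self, data):
        # Custom training logic; Model.fit still handles callbacks,
        # distribution strategies, and the looping around this step.
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

model = SmallModel()
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
model.fit(tf.random.normal([32, 4]), tf.random.normal([32, 1]), epochs=1)
```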
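And a minimal sketch of the explicit XLA compilation path noted above; in the 2.2-era API the flag is `experimental_compile=True` (the function name below is hypothetical), which compiles the function with XLA or raises if it cannot:

```python
import tensorflow as tf

@tf.function(experimental_compile=True)  # compile with XLA or raise
def scaled_sum(x, y):
    return tf.reduce_sum(x * y)

print(scaled_sum(tf.ones([4]), tf.range(4, dtype=tf.float32)))
```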
## Breaking Changes
* `tf.keras`:
* In `tf.keras.applications` the name of the "top" layer has been standardized to "predictions". This is only a problem if your code relies on the exact name of the layer.
* Huber loss function has been updated to be consistent with other Keras losses. It now computes mean over the last axis of per-sample losses before applying the reduction function.
* AutoGraph no longer converts functions passed to `tf.py_function`, `tf.py_func` and `tf.numpy_function`.
* Deprecating `XLA_CPU` and `XLA_GPU` devices with this release.
* Increasing the minimum bazel version to build TF to 2.0.0 to use Bazel's `cc_experimental_shared_library`.
* Keras compile/fit behavior for functional and subclassed models has been unified. Model properties such as `metrics` and `metrics_names` will now be available only after **training/evaluating the model on actual data** for functional models. `metrics` will **now include** model `loss` and output losses. The `loss_functions` property has been removed from the model. This was an undocumented property that was accidentally public and has now been removed.
## Known Caveats
* The current TensorFlow release now **requires** [gast](https://pypi.org/project/gast/) version 0.3.3.
## Bug Fixes and Other Changes
* `tf.data`:
* Removed `autotune_algorithm` from experimental optimization options.
* TF Core:
* `tf.constant` always creates CPU tensors irrespective of the current device context.
* Eager `TensorHandles` maintain a list of mirrors for any copies to local or remote devices. This avoids any redundant copies due to op execution.
* For `tf.Tensor` & `tf.Variable`, `.experimental_ref()` is no longer experimental and is available as simply `.ref()`.
* `pfor/vectorized_map`: Added support for vectorizing 56 more ops. Vectorizing `tf.cond` is also supported now.
* Set as much partial shape as we can infer statically within the gradient impl of the gather op.
* Gradient of `tf.while_loop` emits `StatelessWhile` op if `cond` and body functions are stateless. This allows multiple gradients while ops to run in parallel under distribution strategy.
* Speed up `GradientTape` in eager mode by auto-generating list of op inputs/outputs which are unused and hence not cached for gradient functions.
* Support `back_prop=False` in `while_v2` but mark it as deprecated.
* Improve error message when attempting to use `None` in data-dependent control flow.
* Add `RaggedTensor.numpy()`.
* Update `RaggedTensor.__getitem__` to preserve uniform dimensions & allow indexing into uniform dimensions.
* Update `tf.expand_dims` to always insert the new dimension as a non-ragged dimension.
* Update `tf.embedding_lookup` to use `partition_strategy` and `max_norm` when `ids` is ragged.
* Allow `batch_dims==rank(indices)` in `tf.gather`.
* Add support for bfloat16 in `tf.print`.
* `tf.distribute`:
* Support `embedding_column` with variable-length input features for `MultiWorkerMirroredStrategy`.
* `tf.keras`:
* Added `experimental_aggregate_gradients` argument to `tf.keras.optimizer.Optimizer.apply_gradients`. This allows custom gradient aggregation and processing aggregated gradients in custom training loop.
* Allow `pathlib.Path` paths for loading models via Keras API.
* `tf.function`/AutoGraph:
* AutoGraph is now available in `ReplicaContext.merge_call`, `Strategy.extended.update` and `Strategy.extended.update_non_slot`.
* Experimental support for shape invariants has been enabled in `tf.function`. See the API docs for `tf.autograph.experimental.set_loop_options` for additional info.
* AutoGraph error messages now exclude frames corresponding to APIs internal to AutoGraph.
* Improve shape inference for `tf.function` input arguments to unlock more Grappler optimizations in TensorFlow 2.x.
* Improve automatic control dependency management of resources by allowing resource reads to occur in parallel and synchronizing only on writes.
* Fix execution order of multiple stateful calls to `experimental_run_v2` in `tf.function`.
* You can now iterate over `RaggedTensors` using a for loop inside `tf.function`.
* `tf.lite`:
* Migrated the `tf.lite` C inference API out of experimental into lite/c.
* Add an option to disallow `NNAPI` CPU / partial acceleration on Android 10
* TFLite Android AARs now include the C headers and APIs required to use TFLite from native code.
* Refactors the delegate and delegate kernel sources to allow usage in the linter.
* Limit delegated ops to actually supported ones if a device name is specified or `NNAPI` CPU Fallback is disabled.
* TFLite now supports the `tf.math.reciprocal` op by lowering it to the `tf.div` op.
* TFLite's unpack op now supports boolean tensor inputs.
* Microcontroller and embedded code moved from experimental to main TensorFlow Lite folder
* Check for large TFLite tensors.
* Fix GPU delegate crash with C++17.
* Add 5D support to TFLite `strided_slice`.
* Fix error in delegation of `DEPTH_TO_SPACE` to `NNAPI` causing op not to be accelerated.
* Fix segmentation fault when running a model with LSTM nodes using `NNAPI` Delegate
* Fix `NNAPI` delegate failure when an operand for Maximum/Minimum operation is a scalar.
* Fix `NNAPI` delegate failure when Axis input for reduce operation is a scalar.
* Expose option to limit the number of partitions that will be delegated to `NNAPI`.
* If a target accelerator is specified, use its feature level to determine operations to delegate instead of SDK version.
* `tf.random`:
* Various random number generation improvements:
* Add a fast path for default `random_uniform`
* `random_seed` documentation improvement.
* `RandomBinomial` broadcasts and appends the sample shape to the left rather than the right.
* Added `tf.random.stateless_binomial`, `tf.random.stateless_gamma`, `tf.random.stateless_poisson` (see the sketch after this list)
* `tf.random.stateless_uniform` now supports unbounded sampling of `int` types.
* Math and Linear Algebra:
* Add `tf.linalg.LinearOperatorTridiag`.
* Add `LinearOperatorBlockLowerTriangular`
* Add broadcasting support to `tf.linalg.triangular_solve` [#26204](https://github.com/tensorflow/tensorflow/issues/26204) and `tf.math.invert_permutation`.
* Add `tf.math.sobol_sample` op.
* Add `tf.math.xlog1py`.
* Add `tf.math.special.{dawsn,expi,fresnel_cos,fresnel_sin,spence}`.
* Add a Modified Discrete Cosine Transform (MDCT) and its inverse to `tf.signal`.
* TPU Enhancements:
* Refactor `TpuClusterResolver` to move shared logic to a separate pip package.
* Support configuring TPU software version from cloud tpu client.
* Allowed TPU embedding weight decay factor to be multiplied by learning rate.
* XLA Support:
* Add standalone XLA AOT runtime target + relevant .cc sources to pip package.
* Add check for memory alignment to MemoryAllocation::MemoryAllocation() on 32-bit ARM. This ensures a deterministic early exit instead of a hard to debug bus error later.
* `saved_model_cli aot_compile_cpu` allows you to compile saved models to XLA header+object files and include them in your C++ programs.
* Enable `Igamma`, `Igammac` for XLA.
* Deterministic Op Functionality:
* XLA reduction emitter is deterministic when the environment variable `TF_DETERMINISTIC_OPS` is set to "true" or "1". This extends deterministic `tf.nn.bias_add` back-prop functionality (and therefore also deterministic back-prop of bias-addition in Keras layers) to include when XLA JIT compilation is enabled.
* Fix a problem in which, when running on a CUDA GPU with either the environment variable `TF_DETERMINISTIC_OPS` or the environment variable `TF_CUDNN_DETERMINISTIC` set to "true" or "1", some layer configurations led to an exception with the message "No algorithm worked!"
* Tracing and Debugging:
* Add source, destination name to `_send` traceme to allow easier debugging.
* Add traceme event to `fastpathexecute`.
* Other:
* Fix an issue with AUC.reset_states for multi-label AUC [#35852](https://github.com/tensorflow/tensorflow/issues/35852)
* Fix the TF upgrade script to not delete files when there is a parsing error and the output mode is `in-place`.
* Move `tensorflow/core:framework/*_pyclif` rules to `tensorflow/core/framework:*_pyclif`.
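A small, hedged sketch of the stateless sampling ops called out in the `tf.random` notes above (shapes and parameters are arbitrary; with a fixed seed the draws are reproducible):

```python
import tensorflow as tf

seed = [1, 2]  # stateless ops take an explicit two-element integer seed

# Same seed and arguments always yield the same samples.
binom = tf.random.stateless_binomial(
    shape=[2], seed=seed, counts=[10., 20.], probs=[0.3, 0.7])
gamma = tf.random.stateless_gamma(shape=[2], seed=seed, alpha=[2.0, 3.0])
pois = tf.random.stateless_poisson(shape=[2], seed=seed, lam=[1.5, 4.0])
print(binom.numpy(), gamma.numpy(), pois.numpy())
```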
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
372046933, 8bitmp3, aaronhma, Abin Shahab, Aditya Patwardhan, Agoniii, Ahti Kitsik, Alan Yee, Albin Joy, Alex Hoffman, Alexander Grund, Alexandre E. Eichenberger, Amit Kumar Jaiswal, amoitra, Andrew Anderson, Angus-Luo, Anthony Barbier, Anton Kachatkou, Anuj Rawat, archis, Arpan-Dhatt, Arvind Sundararajan, Ashutosh Hathidara, autoih, Bairen Yi, Balint Cristian, Bas Aarts, BashirSbaiti, Basit Ayantunde, Ben Barsdell, Benjamin Gaillard, boron, Brett Koonce, Bryan Cutler, Christian Goll, Christian Sachs, Clayne Robison, comet, Daniel Falbel, Daria Zhuravleva, darsh8200, David Truby, Dayananda-V, deepakm, Denis Khalikov, Devansh Singh, Dheeraj R Reddy, Diederik Van Liere, Diego Caballero, Dominic Jack, dothinking, Douman, Drake Gens, Duncan Riach, Ehsan Toosi, ekuznetsov139, Elena Zhelezina, elzino, Ending2015a, Eric Schweitz, Erik Zettel, Ethan Saadia, Eugene Kuznetsov, Evgeniy Zheltonozhskiy, Ewout Ter Hoeven, exfalso, FAIJUL, Fangjun Kuang, Fei Hu, Frank Laub, Frederic Bastien, Fredrik Knutsson, frreiss, Frédéric Rechtenstein, fsx950223, Gaurav Singh, gbaned, George Grzegorz Pawelczak, George Sterpu, Gian Marco Iodice, Giorgio Arena, Hans Gaiser, Hans Pabst, Haoyu Wu, Harry Slatyer, hsahovic, Hugo, Hugo Sjöberg, IrinaM21, jacco, Jake Tae, Jean-Denis Lesage, Jean-Michel Gorius, Jeff Daily, Jens Elofsson, Jerry Shih, jerryyin, Jin Mingjian, Jinjing Zhou, JKIsaacLee, jojimonv, Jonathan Dekhtiar, Jose Ignacio Gomez, Joseph-Rance, Judd, Julian Gross, Kaixi Hou, Kaustubh Maske Patil, Keunwoo Choi, Kevin Hanselman, Khor Chean Wei, Kilaru Yasaswi Sri Chandra Gandhi, Koan-Sin Tan, Koki Ibukuro, Kristian Holsheimer, kurileo, Lakshay Tokas, Lee Netherton, leike666666, Leslie-Fang-Intel, Li, Guizi, LIUJIAN435, Lukas Geiger, Lyo Nguyen, madisetti, Maher Jendoubi, Mahmoud Abuzaina, Manuel Freiberger, Marcel Koester, Marco Jacopo Ferrarotti, Markus Franke, marload, Mbah-Javis, mbhuiyan, Meng Zhang, Michael Liao, MichaelKonobeev, Michal Tarnowski, Milan Straka, minoring, Mohamed Nour Abouelseoud, MoussaMM, Mrinal Jain, mrTsjolder, Måns Nilsson, Namrata Bhave, Nicholas Gao, Niels Ole Salscheider, nikochiko, Niranjan Hasabnis, Nishidha Panpaliya, nmostafa, Noah Trenaman, nuka137, Officium, Owen L - Sfe, Pallavi G, Paul Andrey, Peng Sun, Peng Wu, Phil Pearl, PhilipMay, pingsutw, Pooya Davoodi, PragmaTwice, pshiko, Qwerty71, R Gomathi, Rahul Huilgol, Richard Xiao, Rick Wierenga, Roberto Rosmaninho, ruchit2801, Rushabh Vasani, Sami, Sana Damani, Sarvesh Dubey, Sasan Jafarnejad, Sergii Khomenko, Shane Smiskol, Shaochen Shi, sharkdtu, Shawn Presser, ShengYang1, Shreyash Patodia, Shyam Sundar Dhanabalan, Siju Samuel, Somyajit Chakraborty Sam, Srihari Humbarwadi, srinivasan.narayanamoorthy, Srishti Yadav, Steph-En-M, Stephan Uphoff, Stephen Mugisha, SumanSudhir, Taehun Kim, Tamas Bela Feher, TengLu, Tetragramm, Thierry Herrmann, Tian Jin, tigertang, Tom Carchrae, Tom Forbes, Trent Lo, Victor Peng, vijayphoenix, Vincent Abriou, Vishal Bhola, Vishnuvardhan Janapati, vladbataev, VoVAllen, Wallyss Lima, Wen-Heng (Jack) Chung, wenxizhu, William D. Irons, William Zhang, Xiaoming (Jason) Cui, Xiaoquan Kong, Xinan Jiang, Yasir Modak, Yasuhiro Matsumoto, Yaxun (Sam) Liu, Yong Tang, Ytyt-Yt, yuan, Yuan Mingshuai, Yuan Tang, Yuki Ueda, Yusup, zhangshijin, zhuwenxi
# Release 2.0.1
## Bug Fixes and Other Changes

@ -64,7 +64,7 @@ your model, and we recommend you run the TensorFlow process in a sandbox.
It is possible to write models that are secure in a sense that they can safely
process untrusted inputs assuming there are no bugs. There are two main reasons
to not rely on this: first, it is easy to write models which must not be exposed
to not rely on this: First, it is easy to write models which must not be exposed
to untrusted inputs, and second, there are bugs in any software system of
sufficient complexity. Letting users control inputs could allow them to trigger
bugs either in TensorFlow or in dependent libraries.
@ -149,7 +149,7 @@ attack (or worse). Because TensorFlow behaves correctly, this is not a
vulnerability in TensorFlow (although it would be a vulnerability of this
hypothetical system).
As a general rule, it is incorrect behavior for Tensorflow to access memory it
As a general rule, it is incorrect behavior for TensorFlow to access memory it
does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to
such behaviors constitute a vulnerability.

@ -114,6 +114,14 @@ http_archive(
],
)
http_archive(
name = "person_detect_data",
sha256 = "170542270da256994ce24d1e357f6e84a54fdaf7d28ff2b74725a40b70b082cf",
urls = [
"https://storage.googleapis.com/download.tensorflow.org/data/tf_lite_micro_person_data_grayscale_2020_05_24.zip",
],
)
# Required for dependency @com_github_grpc_grpc
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")

@ -144,7 +144,7 @@ def write_to_bazelrc(line):
def write_action_env_to_bazelrc(var_name, var):
write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var)))
write_to_bazelrc('build --action_env {}="{}"'.format(var_name, str(var)))
def run_shell(cmd, allow_non_zero=False, stderr=None):
@ -205,7 +205,7 @@ def setup_python(environ_cp):
# Get PYTHON_BIN_PATH, default is the current running python.
default_python_bin_path = sys.executable
ask_python_bin_path = ('Please specify the location of python. [Default is '
'%s]: ') % default_python_bin_path
'{}]: ').format(default_python_bin_path)
while True:
python_bin_path = get_from_env_or_user_or_default(environ_cp,
'PYTHON_BIN_PATH',
@ -215,9 +215,10 @@ def setup_python(environ_cp):
if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK):
break
elif not os.path.exists(python_bin_path):
print('Invalid python path: %s cannot be found.' % python_bin_path)
print('Invalid python path: {} cannot be found.'.format(python_bin_path))
else:
print('%s is not executable. Is it the python binary?' % python_bin_path)
print('{} is not executable. Is it the python binary?'.format(
python_bin_path))
environ_cp['PYTHON_BIN_PATH'] = ''
# Convert python path to Windows style before checking lib and version
@ -236,7 +237,7 @@ def setup_python(environ_cp):
default_python_lib_path = python_lib_paths[0]
python_lib_path = get_input(
'Please input the desired Python library path to use. '
'Default is [%s]\n' % python_lib_paths[0])
'Default is [{}]\n'.format(python_lib_paths[0]))
if not python_lib_path:
python_lib_path = default_python_lib_path
environ_cp['PYTHON_LIB_PATH'] = python_lib_path
@ -252,7 +253,7 @@ def setup_python(environ_cp):
# Set-up env variables used by python_configure.bzl
write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path)
write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path)
write_to_bazelrc('build --python_path=\"%s"' % python_bin_path)
write_to_bazelrc('build --python_path=\"{}"'.format(python_bin_path))
environ_cp['PYTHON_BIN_PATH'] = python_bin_path
# If chosen python_lib_path is from a path specified in the PYTHONPATH
@ -266,7 +267,7 @@ def setup_python(environ_cp):
with open(
os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'),
'w') as f:
f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path)
f.write('export PYTHON_BIN_PATH="{}"'.format(python_bin_path))
def reset_tf_configure_bazelrc():
@ -320,11 +321,12 @@ def get_var(environ_cp,
Raise the error to avoid infinitely looping.
"""
if not question:
question = 'Do you wish to build TensorFlow with %s support?' % query_item
question = 'Do you wish to build TensorFlow with {} support?'.format(
query_item)
if not yes_reply:
yes_reply = '%s support will be enabled for TensorFlow.' % query_item
yes_reply = '{} support will be enabled for TensorFlow.'.format(query_item)
if not no_reply:
no_reply = 'No %s' % yes_reply
no_reply = 'No {}'.format(yes_reply)
yes_reply += '\n'
no_reply += '\n'
@ -368,7 +370,7 @@ def get_var(environ_cp,
print(no_reply)
var = False
else:
print('Invalid selection: %s' % user_input_origin)
print('Invalid selection: {}'.format(user_input_origin))
return var
@ -479,13 +481,13 @@ def check_bazel_version(min_version, max_version):
if which('bazel') is None:
print('Cannot find bazel. Please install bazel.')
sys.exit(0)
curr_version = run_shell(
['bazel', '--batch', '--bazelrc=/dev/null', 'version'])
for line in curr_version.split('\n'):
if 'Build label: ' in line:
curr_version = line.split('Build label: ')[1]
break
stderr = open(os.devnull, 'wb')
curr_version = run_shell(['bazel', '--version'],
allow_non_zero = True,
stderr = stderr)
if curr_version.startswith('bazel '):
curr_version = curr_version.split('bazel ')[1]
min_version_int = convert_version_to_int(min_version)
curr_version_int = convert_version_to_int(curr_version)
@ -1366,8 +1368,13 @@ def main():
# environment variables.
environ_cp = dict(os.environ)
current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION,
_TF_MAX_BAZEL_VERSION)
try:
current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION,
_TF_MAX_BAZEL_VERSION)
except subprocess.CalledProcessError as e:
print("Error checking bazel version: ", e.output.decode('UTF-8').strip())
raise e
_TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)
reset_tf_configure_bazelrc()
@ -1385,7 +1392,6 @@ def main():
# Windows.
environ_cp['TF_DOWNLOAD_CLANG'] = '0'
environ_cp['TF_NEED_MPI'] = '0'
environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0'
if is_macos():
environ_cp['TF_NEED_TENSORRT'] = '0'

@ -517,18 +517,29 @@ package_group(
"//perftools/accelerators/xprof/api/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
"//third_party/swift/tensorflow_apis/...",
"//tensorflow/...",
"//tensorflow_estimator/python/estimator/...",
"//tensorflow_models/official/...",
],
)
package_group(name = "ndarray_tensor_allow_list")
package_group(
name = "ndarray_tensor_allow_list",
packages = ["//learning/pathways/..."],
)
# Packages that use composite tensors or dispatch.
# TODO(b/154762408) Remove this package group once it's no longer needed.
package_group(name = "composite_tensor_whitelist")
# Packages that use private types symbols, until they are exported.
# TODO(b/154650521) Remove.
package_group(
name = "types_whitelist",
packages = ["//learning/deepmind/tensorflow/replicator/..."],
)
filegroup(
name = "intel_binary_blob",
data = if_mkl_ml(

@ -58,6 +58,7 @@ filegroup(
name = "pywrap_required_hdrs",
srcs = [
"c_api_internal.h",
"conversion_macros.h",
"python_api.h",
"tensor_interface.h",
"tf_status_helper.h",
@ -84,7 +85,14 @@ tf_cuda_library(
],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//tensorflow:chromiumos": [
":tf_attrtype",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/platform:platform",
],
"//conditions:default": [
":tf_attrtype",
@ -174,7 +182,7 @@ tf_cuda_library(
":tf_status_internal",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
":tf_status",
@ -208,10 +216,11 @@ tf_cuda_library(
],
visibility = [
"//tensorflow/c:__subpackages__",
"//tensorflow/compiler/mlir/tensorflow/c:__subpackages__",
],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:lib",
@ -224,12 +233,13 @@ cc_library(
srcs = ["tf_status.cc"],
hdrs = ["tf_status.h"],
visibility = ["//visibility:public"],
deps = select({
deps = [
":tf_status_internal",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
":tf_status_internal",
"//tensorflow/core:lib",
],
}),
@ -251,10 +261,15 @@ cc_library(
name = "tensor_interface",
hdrs = ["tensor_interface.h"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
}),
)
cc_library(
@ -264,7 +279,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:framework",
@ -278,16 +293,17 @@ cc_library(
srcs = ["tf_tensor.cc"],
hdrs = ["tf_tensor.h"],
visibility = ["//visibility:public"],
deps = select({
deps = [
":tensor_interface",
":tf_datatype",
":tf_status",
":tf_status_helper",
":tf_tensor_internal",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
":tensor_interface",
":tf_datatype",
":tf_status",
":tf_status_helper",
":tf_tensor_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -303,14 +319,15 @@ tf_cuda_library(
"tf_tensor_internal.h",
],
visibility = ["//tensorflow:internal"],
deps = select({
deps = [
":tensor_interface",
":tf_datatype",
":tf_status",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
":tensor_interface",
":tf_datatype",
":tf_status",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:casts",
@ -378,8 +395,14 @@ tf_cuda_library(
deps = [
":tf_status",
":tf_status_internal",
"//tensorflow/core:lib",
],
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:lib",
],
}),
)
tf_cc_test(
@ -418,7 +441,7 @@ tf_cuda_library(
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",
@ -449,7 +472,7 @@ tf_cuda_library(
] + select({
"//tensorflow:android": [
":c_api_internal",
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
":c_api_internal",
@ -476,7 +499,7 @@ tf_cuda_library(
":tf_status_helper",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",
@ -532,6 +555,7 @@ tf_cuda_cc_test(
"//conditions:default": [],
}),
tags = [
"no_windows", # TODO(b/155444728)
"noasan",
],
# We must ensure that the dependencies can be dynamically linked since

@ -325,205 +325,6 @@ TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
return ret;
}
TFE_Context* TFE_CreateContextFromSession(TF_Session* session,
TF_Status* status) {
auto* opts = TFE_NewContextOptions();
// Reduce GPU memory allocation, and set appropriate config options for TFE
// context.
auto* config = TF_CreateConfig(
/*xla*/ false, /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
10);
TFE_ContextOptionsSetConfig(opts, config->data, config->length, status);
if (!status->status.ok()) {
CHECK(!config);
TFE_DeleteContextOptions(opts);
return nullptr;
}
auto* ctx = TFE_NewContextFromSession(opts, session, status);
TF_DeleteBuffer(config);
TFE_DeleteContextOptions(opts);
return ctx;
}
// TODO: retrieve the device string via TFE_ContextListDevices()
static const char DEFAULT_CPU_DEVICE[] =
"/job:localhost/replica:0/task:0/device:CPU:0";
static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType,
int tensor_id, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> queueOp(
TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp);
TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status);
if (!status->status.ok()) return nullptr;
// TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler.
TFE_OpSetAttrInt(queueOp.get(), "capacity", 1);
TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1);
auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id);
TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(),
shared_name.size());
TFE_OpSetAttrString(queueOp.get(), "container", "", 0);
// TODO: consider making this an unknown shape.
const int64_t* dims_ptr = nullptr;
int num_dims = 0;
TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims,
/*num_values*/ 0, status);
if (!status->status.ok()) return nullptr;
int num_retvals = 1;
TFE_TensorHandle* queue = nullptr;
TFE_Execute(queueOp.get(), &queue, &num_retvals, status);
if (!status->status.ok()) return nullptr;
CHECK_EQ(num_retvals, 1);
return queue;
}
static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType,
TFE_TensorHandle* queue, TFE_TensorHandle* tensor,
TF_Status* status) {
TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status);
if (!status->status.ok()) return;
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
if (!status->status.ok()) return;
TFE_OpAddInput(op, queue, status);
if (!status->status.ok()) return;
TFE_OpAddInput(op, tensor, status);
if (!status->status.ok()) return;
TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1);
TFE_OpSetAttrInt(op, "timeout_ms", -1);
int num_retvals = 0;
TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status);
if (!status->status.ok()) return;
CHECK_EQ(num_retvals, 0);
}
static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx,
TF_DataType inputType,
TFE_TensorHandle* queue,
TF_Status* status) {
TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status);
if (!status->status.ok()) return nullptr;
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
if (!status->status.ok()) return nullptr;
TFE_OpAddInput(op, queue, status);
if (!status->status.ok()) return nullptr;
TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1);
TFE_OpSetAttrInt(op, "timeout_ms", -1);
TFE_TensorHandle* ret;
int num_retvals = 1;
TFE_Execute(op, &ret, &num_retvals, status);
if (!status->status.ok()) return nullptr;
CHECK_EQ(num_retvals, 1);
return ret;
}
TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id,
TF_DataType inputType,
TF_Status* status) {
assert(session);
VLOG(1) << "Dequeuing data tensor with id " << tensor_id;
auto ctx = TFE_CreateContextFromSession(session, status);
if (!status->status.ok()) return nullptr;
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
ctx, TFE_DeleteContext);
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
if (!status->status.ok()) return nullptr;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
queue_deleter(queue, TFE_DeleteTensorHandle);
auto* ret = createTFEDequeue(ctx, inputType, queue, status);
return ret;
}
TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
TF_DataType inputType,
TF_Status* status) {
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
if (!status->status.ok()) return nullptr;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
queue_deleter(queue, TFE_DeleteTensorHandle);
auto* ret = createTFEDequeue(ctx, inputType, queue, status);
return ret;
}
void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id,
TFE_TensorHandle* tensor, TF_Status* status) {
assert(session);
VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
auto ctx = TFE_CreateContextFromSession(session, status);
if (!status->status.ok()) return;
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
ctx, TFE_DeleteContext);
TF_DataType inputType = TFE_TensorHandleDataType(tensor);
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
if (!status->status.ok()) return;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
queue_deleter(queue, TFE_DeleteTensorHandle);
createTFEEnqueue(ctx, inputType, queue, tensor, status);
}
void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
TFE_TensorHandle* tensor,
TF_Status* status) {
VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
TF_DataType inputType = TFE_TensorHandleDataType(tensor);
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
if (!status->status.ok()) return;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
queue_deleter(queue, TFE_DeleteTensorHandle);
createTFEEnqueue(ctx, inputType, queue, tensor, status);
}
void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id,
TFE_TensorHandle* tensor, TF_Status* status) {
VLOG(1) << "Enqueuing variant tensor with id " << tensor_id;
auto ctx = TFE_CreateContextFromSession(session, status);
if (!status->status.ok()) return;
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
ctx, TFE_DeleteContext);
TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
if (!status->status.ok()) return;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
queue_deleter(queue, TFE_DeleteTensorHandle);
createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status);
}
TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
TF_Status* status) {
VLOG(1) << "Dequeuing variant tensor with id " << tensor_id;
auto ctx = TFE_CreateContextFromSession(session, status);
if (!status->status.ok()) return nullptr;
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
ctx, TFE_DeleteContext);
TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
if (!status->status.ok()) return nullptr;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
queue_deleter(queue, TFE_DeleteTensorHandle);
return createTFEDequeue(ctx, TF_VARIANT, queue, status);
}
void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
status->status = tensorflow::errors::Internal(errMsg);
}
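A minimal usage sketch of the queue-based transfer helpers above, with the producer and consumer sides shown in one place for illustration; it assumes `session` is an already-created TF_Session and `tensor_id` is the queue id agreed on by both sides (error handling shortened):

void RoundTripNamedTensor(TF_Session* session, int tensor_id,
                          TFE_TensorHandle* value, TF_Status* status) {
  // Producer side: push `value` into the fifo queue owned by `session`.
  TFE_EnqueueNamedTensor(session, tensor_id, value, status);
  if (TF_GetCode(status) != TF_OK) return;
  // Consumer side: pop a handle of the same dtype from the same queue.
  TFE_TensorHandle* received = TFE_DequeueNamedTensor(
      session, tensor_id, TFE_TensorHandleDataType(value), status);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_DeleteTensorHandle(received);
}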
@ -622,10 +423,9 @@ void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
const TF_DataType* values, int num_values) {
auto iter = builder->attr_names.insert(attr_name).first;
builder->Set(
(*iter).c_str(),
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
reinterpret_cast<const tensorflow::DataType*>(values), num_values));
builder->Set(*iter, tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
reinterpret_cast<const tensorflow::DataType*>(values),
num_values));
}
void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,

View File

@ -146,48 +146,6 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
// Create a serialized tensorflow.ServerDef proto.
TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status);
// TODO: remove this API in favor of the next one.
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession(
const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status);
// Creates from `session` a new eager context to run a graph function or
// sends/recvs, so that these concurrent TFE executions can share (via
// `session` and its associated device mgr) the same set of fifo queue resource
// ops, used for host<->TF tensor transfers. This way, the send/recv calls and
// graph function execution can access the same fifo queue resource handles
// (associated with devices managed by the device manager, which can be obtained
// from `session`).
//
// TODO: Remove this function once we migrate away from using session.
TF_CAPI_EXPORT extern TFE_Context* TFE_CreateContextFromSession(
TF_Session* session, TF_Status* status);
// TODO: Retire this API in favor of the next one.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensor(
TF_Session* session, int tensor_id, TF_DataType inputType,
TF_Status* status);
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(
TFE_Context* ctx, int tensor_id, TF_DataType inputType, TF_Status* status);
TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensor(TF_Session* session,
int tensor_id,
TFE_TensorHandle* tensor,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensorFromCtx(
TFE_Context* ctx, int tensor_id, TFE_TensorHandle* tensor,
TF_Status* status);
// TODO: consider folding the 2 APIs below into the ones above.
TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session,
int tensor_id,
TFE_TensorHandle* tensor,
TF_Status* status);
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
TF_Session* session, int tensor_id, TF_Status* status);
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
const char* errMsg);

View File

@ -16,15 +16,18 @@ limitations under the License.
#ifndef TENSORFLOW_C_CONVERSION_MACROS_H_
#define TENSORFLOW_C_CONVERSION_MACROS_H_
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
inline cpp_impl *unwrap(wrapper *w) { \
return reinterpret_cast<cpp_impl *>(w); \
} \
\
inline const cpp_impl *unwrap(const wrapper *w) { \
return reinterpret_cast<const cpp_impl *>(w); \
} \
\
inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast<wrapper *>(i); }
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
inline cpp_impl *unwrap(wrapper *w) { \
return reinterpret_cast<cpp_impl *>(w); \
} \
\
inline const cpp_impl *unwrap(const wrapper *w) { \
return reinterpret_cast<const cpp_impl *>(w); \
} \
\
inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast<wrapper *>(i); } \
inline const wrapper *wrap(const cpp_impl *i) { \
return reinterpret_cast<const wrapper *>(i); \
}
#endif // TENSORFLOW_C_CONVERSION_MACROS_H_
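A minimal usage sketch of the macro above; the type names are illustrative, not from the TensorFlow tree. A single invocation generates the matching wrap()/unwrap() overloads, now including the const wrap() overload added here:

namespace demo {
class ThingImpl {};                // hypothetical internal C++ type
}  // namespace demo
typedef struct TF_Thing TF_Thing;  // hypothetical opaque C handle

DEFINE_CONVERSION_FUNCTIONS(demo::ThingImpl, TF_Thing)

// Generated: unwrap(TF_Thing*) and unwrap(const TF_Thing*) returning
// demo::ThingImpl pointers, plus wrap(demo::ThingImpl*) and
// wrap(const demo::ThingImpl*) returning TF_Thing pointers.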

View File

@ -35,7 +35,7 @@ tf_cuda_library(
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
":context_interface",
@ -144,6 +144,24 @@ cc_library(
],
)
cc_library(
name = "c_api_unified_internal",
hdrs = [
"c_api_unified_experimental_internal.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":c_api",
":c_api_experimental",
"//tensorflow/c:c_api_internal",
"//tensorflow/c:tf_status",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:types",
],
)
cc_library(
name = "tensor_handle_interface",
hdrs = ["tensor_handle_interface.h"],
@ -246,6 +264,7 @@ cc_library(
"//tensorflow:internal",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c:tf_status",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:attr_builder",
@ -316,7 +335,9 @@ tf_cuda_cc_test(
],
extra_copts = tfe_xla_copts(),
tags = [
"guitar",
"noguitar", # TODO(b/155445984): flaky
#"guitar",
"notap", # TODO(b/156981931): flaky
"multi_gpu",
],
deps = [
@ -344,7 +365,43 @@ tf_cuda_cc_test(
srcs = [
"c_api_remote_test.cc",
],
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = [
"noasan", # leaks gRPC server instances
"notsan", # b/157098283
],
deps = [
":c_api",
":c_api_experimental",
":c_api_internal",
":c_api_test_util",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_test_util",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:function_optimization_registry",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"@com_google_absl//absl/strings",
],
)
tf_cuda_cc_test(
name = "c_api_cluster_test",
size = "small",
srcs = [
"c_api_cluster_test.cc",
],
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = ["noasan"], # leaks gRPC server instances
deps = [
":c_api",
":c_api_experimental",
@ -358,6 +415,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform:env",
"@com_google_absl//absl/strings",
],
)
@ -379,7 +437,7 @@ tf_cuda_library(
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
":c_api",
@ -415,6 +473,8 @@ tf_cuda_library(
"//conditions:default": [],
}) + [
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/container:flat_hash_map",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
@ -472,6 +532,7 @@ tf_cuda_cc_test(
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
@ -575,7 +636,6 @@ filegroup(
],
exclude = [
"c_api_experimental.cc",
"*c_api_tfrt*",
"*test*",
"*dlpack*",
],

View File

@ -38,7 +38,7 @@ limitations under the License.
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_tensor_internal.h"
#ifdef PLATFORM_GOOGLE
#include "tensorflow/c/eager/c_api_tfrt.h"
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
#endif
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/context.h"
@ -102,6 +102,15 @@ string DeviceName(const tensorflow::Device* d) {
}
#if !defined(IS_MOBILE_PLATFORM)
bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context,
const tensorflow::ServerDef& server_def) {
if (server_def.job_name() != context->HostCPU()->parsed_name().job) {
return false;
}
return server_def.default_session_config().SerializeAsString() ==
context->session_options().config.SerializeAsString();
}
tensorflow::Status AddRemoteDevicesToMgr(
const std::vector<string>& added_remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
@ -469,10 +478,15 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::GrpcServer* grpc_server;
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
const tensorflow::DeviceMgr* device_mgr =
AreLocalDevicesCompatible(context, server_def)
? context->local_device_mgr()
: nullptr;
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
server_def, {device_mgr}, &new_server));
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
LOG_AND_RETURN_IF_ERROR(
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
} else {
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
&curr_remote_workers));
@ -500,6 +514,17 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
grpc_server->master_env()->worker_cache->GetEagerClientCache(
&remote_eager_workers));
// For cluster update, use a status group to aggregate statuses from
// * adding and removing remote devices
// * creating remote contexts on newly added workers
// * updating remote contexts on existing workers
// * updating the master context
// Note that we should not return immediately on errors in the middle of these
// updates, to prevent the cluster from having inconsistent context views.
//
// Unused if `reset_context` is true.
tensorflow::StatusGroup sg;
// When updating an existing context, populate the following lists with:
// * added_workers: set(remote_workers) - set(curr_remote_workers)
// * removed_workers: set(curr_remote_workers) - set(remote_workers)
@ -535,7 +560,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
&added_workers, &removed_workers,
&existing_workers);
LOG_AND_RETURN_IF_ERROR(GetReplacedFromExistingWorkers(
sg.Update(GetReplacedFromExistingWorkers(
&existing_workers, context_id, context->GetContextViewId(), server_def,
remote_eager_workers.get(), &replaced_workers));
if (VLOG_IS_ON(1)) {
@ -559,11 +584,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
existing_workers.end());
}
}
LOG_AND_RETURN_IF_ERROR(
RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
LOG_AND_RETURN_IF_ERROR(AddRemoteDevicesToMgr(
added_workers, grpc_server->master_env()->worker_cache,
remote_device_mgr));
sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
sg.Update(AddRemoteDevicesToMgr(added_workers,
grpc_server->master_env()->worker_cache,
remote_device_mgr));
}
std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
@ -584,7 +608,6 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
}
// Initialize remote eager workers.
// TODO(b/138847548) Create remote eager contexts in async mode by default.
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
@ -596,7 +619,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
// existing workers to also have the updated context_view_id, so
// we must set their context_view_id to the existing master's
// context_view_id + 1.
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
sg.Update(CreateRemoteContexts(
ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
server_def, remote_eager_workers.get(), context->Executor().Async(),
context->LazyCopyFunctionRemoteInputs(), base_request));
@ -606,10 +629,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
VLOG(1) << "Updating cluster with existing worker " << w;
}
}
LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
ctx, existing_workers, added_workers, removed_workers, context_id,
context_view_id + 1, server_def, remote_eager_workers.get(),
base_request));
sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers,
removed_workers, context_id,
context_view_id + 1, server_def,
remote_eager_workers.get(), base_request));
}
}
@ -645,13 +668,13 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
// GrpcServer cannot be destroyed after it is started.
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
} else {
LOG_AND_RETURN_IF_ERROR(
grpc_server->worker_env()->session_mgr->UpdateSession(
session_name, server_def, base_request.cluster_device_attributes(),
/*isolate_session_state=*/true));
LOG_AND_RETURN_IF_ERROR(
context->UpdateRemoteMaster(context_id, std::move(remote_eager_workers),
added_workers, removed_workers));
sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
session_name, server_def, base_request.cluster_device_attributes(),
/*isolate_session_state=*/true));
sg.Update(context->UpdateRemoteMaster(context_id,
std::move(remote_eager_workers),
added_workers, removed_workers));
LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
}
#undef LOG_AND_RETURN_IF_ERROR
@ -685,8 +708,13 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
if (opts->use_tfrt) {
#ifdef PLATFORM_GOOGLE
status->status = tensorflow::Status::OK();
return tensorflow::wrap(new tfrt::ContextInterface());
tfrt::SmallVector<std::string, 4> op_handler_chains;
tfrt::SmallVector<tensorflow::DeviceAttributes, 4> device_attributes;
status->status = tfrt::ListOpHandlerChains(
opts->session_options.options, &op_handler_chains, &device_attributes);
if (!status->status.ok()) return nullptr;
return tensorflow::wrap(
new tfrt::ContextInterface(op_handler_chains, device_attributes));
#else
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
return nullptr;
@ -713,24 +741,6 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
tensorflow::GetDefaultCustomKernelCreator()));
}
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
TF_Session* sess, TF_Status* status) {
const tensorflow::DeviceMgr* device_mgr = nullptr;
status->status = sess->session->LocalDeviceManager(&device_mgr);
if (!status->status.ok()) return nullptr;
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr);
return tensorflow::wrap(new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
/*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator()));
}
void TFE_DeleteContext(TFE_Context* ctx) {
if (ctx == nullptr) {
return;
@ -747,9 +757,7 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
}
void TFE_ContextClearCaches(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->ClearCachesAndThreadExecutors();
tensorflow::unwrap(ctx)->ClearCachesAndThreadExecutors();
}
// Set server_def on the context, possibly updating it.
@ -887,9 +895,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
#if defined(IS_MOBILE_PLATFORM)
status->status = tensorflow::Status::OK();
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->SyncExecutors();
status->status = tensorflow::unwrap(ctx)->AsyncWait();
#endif // !IS_MOBILE_PLATFORM
}
@ -912,7 +918,7 @@ extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
context->GetDevicePlacementPolicy());
}
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
tensorflow::Tensor tensor;
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
if (!status->status.ok()) return nullptr;
@ -1460,20 +1466,21 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
} // namespace
void TFE_ContextStartStep(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->StartStep();
tensorflow::unwrap(ctx)->StartStep();
}
void TFE_ContextEndStep(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->EndStep();
tensorflow::unwrap(ctx)->EndStep();
}
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
return tensorflow::wrap(
&OperationFromInterface(tensorflow::unwrap(op))->Attrs());
}
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
tensorflow::AttrValueMap m;
attrs->attributes->FillAttrValueMap(&m);
tensorflow::unwrap(attrs)->FillAttrValueMap(&m);
tensorflow::EagerOperation* operation =
OperationFromInterface(tensorflow::unwrap(op));
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
@ -1485,8 +1492,8 @@ void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
TF_Status* status) {
tensorflow::NameAttrList name_and_attrs;
attrs->attributes->FillAttrValueMap(name_and_attrs.mutable_attr());
name_and_attrs.set_name(attrs->attributes->op_name());
tensorflow::unwrap(attrs)->FillAttrValueMap(name_and_attrs.mutable_attr());
name_and_attrs.set_name(tensorflow::unwrap(attrs)->op_name());
status->status = MessageToBuffer(name_and_attrs, buf);
}
@ -1607,9 +1614,9 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
}
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
TF_Status status;
TFE_OpAttrs attributes(&op->Attrs());
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
&attributes, num_retvals, outputs.data(), &status, info_);
wrap(&op->Attrs()), num_retvals, outputs.data(), &status,
info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = tensorflow::TensorHandleFromInterface(

View File

@ -137,7 +137,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
// placed in memory of different devices or remote address spaces.
typedef struct TFE_TensorHandle TFE_TensorHandle;
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t,
TF_Status* status);
// Indicates that the caller will not be using `h` any more.
TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);

View File

@ -0,0 +1,499 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace {
using ::tensorflow::string;
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost", ":", port)});
}
return server_def;
}
tensorflow::ServerDef GetServerDef(int num_tasks) {
return GetServerDef("localhost", num_tasks);
}
void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->at(task_index) =
tensorflow::strings::StrCat("localhost:", port);
}
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
const std::vector<float>& expected_values) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
std::unique_ptr<float[]> actual_values(new float[expected_values.size()]);
EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t));
memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
for (int i = 0; i < expected_values.size(); i++) {
EXPECT_EQ(expected_values[i], actual_values[i])
<< "Mismatch in expected values at (zero-based) index " << i;
}
}
void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
const char* remote_device_name,
const char* local_device_name) {
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
TFE_OpSetDevice(matmul, remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* retval_task0 =
TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22});
TFE_DeleteTensorHandle(retval_task0);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteOp(matmul);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
TF_DeleteStatus(status);
}
// Read the value of variable `var` and save it into `out_value`.
void ReadVariable(TFE_Context* ctx, TFE_TensorHandle* var,
TFE_TensorHandle** out_value) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpAddInput(op, var, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
int num_retvals = 1;
TFE_Execute(op, out_value, &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(op);
TF_DeleteStatus(status);
}
void TestRemoteExecuteChangeServerDef(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
const char local_device_name[] =
"/job:localhost/replica:0/task:0/device:CPU:0";
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
// Update the server def with a new set of names (worker instead of
// localhost).
tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2);
serialized = updated_server_def.SerializeAsString();
updated_server_def.set_task_index(1);
tensorflow::Status s = tensorflow::GrpcServer::Create(
updated_server_def, tensorflow::Env::Default(), &worker_server);
ASSERT_TRUE(s.ok()) << s.error_message();
ASSERT_TRUE(worker_server->Start().ok());
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Create a new tensor_handle.
TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx);
// Check that copying it to the old remote device (named localhost) fails.
TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Copying and executing on the new remote device works.
const char new_remote_device_name[] =
"/job:worker/replica:0/task:1/device:CPU:0";
const char new_local_device_name[] =
"/job:worker/replica:0/task:0/device:CPU:0";
auto* h0_task1_new = TFE_TensorHandleCopyToDevice(
h0_task0_new, ctx, new_remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(h0_task0_new);
TFE_DeleteTensorHandle(h0_task1_new);
CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
new_local_device_name);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteChangeServerDef) {
TestRemoteExecuteChangeServerDef(false);
}
TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
TestRemoteExecuteChangeServerDef(true);
}
void TestRemoteExecuteUpdateServerDef(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char local_device_name[] =
"/job:localhost/replica:0/task:0/device:CPU:0";
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteUpdateServerDef) {
TestRemoteExecuteUpdateServerDef(false);
}
TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
TestRemoteExecuteUpdateServerDef(true);
}
void TestRemoteExecuteUpdateServerDefResourceAccess(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char dev0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
const char dev1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
EXPECT_NE(var_handle0, nullptr);
TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
EXPECT_NE(var_handle1, nullptr);
TFE_TensorHandle* value_handle = nullptr;
ReadVariable(ctx, var_handle1, &value_handle);
CheckTFE_TensorHandleHasFloats(value_handle, {2});
TFE_DeleteTensorHandle(value_handle);
// Start a new worker to replace task:1
ReplaceTaskInServerDef(&server_def, 1);
server_def.set_task_index(1);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
// Update server def to replace the remote device with the device info on the
// new worker (different incarnation ID).
server_def.set_task_index(0);
string serialized_update = server_def.SerializeAsString();
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
serialized_update.size(), status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// The device of var_handle0 is the local device, which is the same before and
// after the cluster update. Removing the resource with a valid device should
// succeed.
TFE_Op* op = TFE_NewOp(ctx, "DestroyResourceOp", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, var_handle0, status);
TFE_OpSetDevice(op, dev0_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
int num_retvals = 0;
TFE_Execute(op, nullptr, &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(op);
// The device of var_handle1 is a remote device, which was replaced during the
// cluster update. Removing the resource with an invalid device should fail
// gracefully (i.e., with an error status) instead of crashing with a segfault.
op = TFE_NewOp(ctx, "DestroyResourceOp", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, var_handle1, status);
TFE_OpSetDevice(op, dev1_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
num_retvals = 0;
TFE_Execute(op, nullptr, &num_retvals, status);
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(op);
TFE_DeleteTensorHandle(var_handle0);
TFE_DeleteTensorHandle(var_handle1);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccess) {
TestRemoteExecuteUpdateServerDefResourceAccess(false);
}
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccessAsync) {
TestRemoteExecuteUpdateServerDefResourceAccess(true);
}
void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
// Fail fast on GetStatus requests so we can get errors instead of timeouts
// when updating the cluster with a non-existent worker.
tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char local_device_name[] =
"/job:localhost/replica:0/task:0/device:CPU:0";
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
// Adding a non-existent remote worker to the cluster def. This should cause the
// UpdateServerDef call to fail.
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->mutable_job(0);
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{2, tensorflow::strings::StrCat("localhost:", port)});
server_def.set_task_index(0);
string serialized_update = server_def.SerializeAsString();
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
serialized_update.size(), status);
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Even after the previously failed cluster update, another update and op
// execution should work fine as long as the provided server_def is valid.
TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
tensorflow::unsetenv("GRPC_FAIL_FAST");
}
TEST(CAPI, RemoteExecuteUpdateServerDefWithFailures) {
TestRemoteExecuteUpdateServerDefWithFailures(false);
}
TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
TestRemoteExecuteUpdateServerDefWithFailures(true);
}
void TestConnectToCluster(bool keep_localhost_for_first_connect) {
// Fail fast on GetStatus requests so we can get errors instead of timeouts
// when updating the cluster with a non-existent worker.
tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
const string first_name =
keep_localhost_for_first_connect ? "localhost" : "abc";
tensorflow::ServerDef server_def = GetServerDef(first_name, 1);
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
const string dev0_name = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
EXPECT_NE(var_handle0, nullptr);
tensorflow::Status status2;
EXPECT_EQ(tensorflow::unwrap(var_handle0)->DeviceName(&status2), dev0_name);
// Rename local device
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const string dev1_name =
absl::StrCat("/job:", first_name, "/replica:0/task:0/device:CPU:0");
TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
EXPECT_NE(var_handle1, nullptr);
EXPECT_EQ(tensorflow::unwrap(var_handle1)->DeviceName(&status2), dev1_name);
// Another renaming of local device
const string second_name = "def";
server_def.set_job_name(second_name);
server_def.mutable_cluster()->mutable_job(0)->set_name(second_name);
(*server_def.mutable_cluster()->mutable_job(0)->mutable_tasks())[0] =
absl::StrCat(second_name, ":",
tensorflow::testing::PickUnusedPortOrDie());
serialized = server_def.SerializeAsString();
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const string dev2_name = "/job:def/replica:0/task:0/device:CPU:0";
TFE_TensorHandle* var_handle2 = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle2, nullptr);
EXPECT_EQ(tensorflow::unwrap(var_handle2)->DeviceName(&status2), dev2_name);
TFE_DeleteTensorHandle(var_handle0);
TFE_DeleteTensorHandle(var_handle1);
TFE_DeleteTensorHandle(var_handle2);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
tensorflow::unsetenv("GRPC_FAIL_FAST");
}
TEST(CAPI, ConnectToClusterLocalhostFirst) { TestConnectToCluster(false); }
TEST(CAPI, ConnectToClusterRenameFirst) { TestConnectToCluster(true); }
} // namespace

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/lib/monitoring/counter.h"
@ -638,3 +639,35 @@ TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
return tensorflow::wrap(
tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
}
TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
TFE_TensorHandle** handles,
int* num_handles,
TF_Status* status) {
std::vector<tensorflow::TensorHandle*> tensor_handles;
tensor_handles.reserve(*num_handles);
for (int i = 0; i < *num_handles; ++i) {
tensor_handles.push_back(
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(handles[i])));
}
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::TensorHandle* handle = nullptr;
status->status = tensorflow::TensorHandle::CreatePackedHandle(
std::move(tensor_handles), context, &handle);
return tensorflow::wrap(handle);
}
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetAllowSoftPlacement(enable);
}
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetLogDevicePlacement(enable);
}

View File

@ -431,6 +431,9 @@ TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
// A reference to an op's name -> attribute mapping
typedef struct TFE_OpAttrs TFE_OpAttrs;
// Fetch a reference to `op`'s attributes. The returned reference is only valid
// while `op` is alive.
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op);
// Add attributes in `attrs` to `op`.
//
// Does not overwrite or update existing attributes, but adds new ones.
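A short sketch of how the two calls compose, assuming `src` and `dst` are live TFE_Op handles built elsewhere:

void CopyOpAttrs(TFE_Op* src, TFE_Op* dst) {
  // The returned reference is only valid while `src` is alive.
  const TFE_OpAttrs* attrs = TFE_OpGetAttrs(src);
  // Adds `src`'s attributes to `dst` without overwriting existing ones.
  TFE_OpAddAttrs(dst, attrs);
}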
@ -538,6 +541,26 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx,
TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor(
TFE_Context* ctx, TF_Tensor* t, TF_Status* status);
// Create a packed TensorHandle with the given list of TensorHandles.
// If `handles` are on the same device, assign the same device to the packed
// handle; if `handles` are on different devices, assign a CompositeDevice to
// it.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
TF_Status* status);
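A minimal sketch of packing handles, mirroring the cluster test added in this change; `ctx`, `status`, and the per-task resource handles h0/h1/h2 are assumed to exist already:

TFE_TensorHandle* handles[] = {h0, h1, h2};
int num_handles = 3;
TFE_TensorHandle* packed =
    TFE_CreatePackedTensorHandle(ctx, handles, &num_handles, status);
// When the inputs live on different tasks, `packed` reports a
// .../device:COMPOSITE:0 device name.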
// Configure soft device placement policy for the eager executor. Note this
// policy is applied to any subsequent op executions.
TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx,
unsigned char enable,
TF_Status* status);
// Configure device placement policy logging for the eager executor. Note this
// policy is applied to any subsequent op executions.
TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
unsigned char enable,
TF_Status* status);
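A sketch of configuring both placement knobs on an existing context before running ops (assumes `ctx` and `status` were created earlier):

TFE_ContextSetSoftDevicePlacement(ctx, /*enable=*/1, status);
TFE_ContextSetLogDevicePlacement(ctx, /*enable=*/1, status);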
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -19,11 +19,16 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace {
@ -168,7 +173,11 @@ string MatMulFunction() {
return def.SerializeAsString();
}
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
// If heavy_load_on_streaming_rpc is true, send some RPC requests before the one
// that creates a remote input, to simulate a scenario in which the remote
// input is not ready when we start running an op or a function.
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
@ -193,47 +202,62 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
std::vector<TFE_TensorHandle*> handles_task0;
if (heavy_load_on_streaming_rpc) {
// Send 50 tensor copy requests to simulate that some RPC requests have
// already been enqueued.
for (int i = 0; i < 50; ++i) {
handles_task0.push_back(TestMatrixTensorHandle(ctx));
}
}
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
std::vector<TFE_TensorHandle*> handles_task2;
for (auto* h_task0 : handles_task0) {
handles_task2.push_back(
TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status));
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
auto* h1_task2 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* matmul = nullptr;
if (func) {
string function_def = MatMulFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h0_task0, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h1_task2, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else {
// Handles are on task0 (local), and task2, but op is on task1.
matmul = MatMulOp(ctx, h0_task0, h1_task2);
}
if (remote) {
TFE_OpSetDevice(matmul, task1_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else if (!async) {
// Set the local device to CPU to easily validate mirroring
string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
@ -243,7 +267,7 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// TODO(gjn): Add support for waiting on async local mirrors
if (!remote && !async) {
@ -255,10 +279,10 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retval_task0);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
@ -273,12 +297,18 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h1_task2);
TFE_DeleteTensorHandle(retvals[0]);
for (auto* h : handles_task0) {
TFE_DeleteTensorHandle(h);
}
for (auto* h : handles_task2) {
TFE_DeleteTensorHandle(h);
}
TFE_DeleteOp(matmul);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
if (func) {
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
@ -293,22 +323,435 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
}
TEST(CAPI, RemoteExecuteSilentCopies) {
TestRemoteExecuteSilentCopies(false, true, false);
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/true,
/*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(true, true, false);
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
TestRemoteExecuteSilentCopies(true, true, true);
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
TestRemoteExecuteSilentCopies(false, false, false);
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/false,
/*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
TestRemoteExecuteSilentCopies(true, false, false);
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false,
/*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
TestRemoteExecuteSilentCopies(true, false, true);
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
// A remote input may not be ready when we start running a function. Test that
// the function execution waits until the remote input is ready.
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/true);
}
// Add the values of three variables on three different tasks.
string AddVariablesFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'AddVariablesFunction'"
" input_arg {"
" name: 'var'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'sum'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:0/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'read1'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:1/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'read2'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:2/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add1'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read1:value:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add2'"
" op: 'Add'"
" input: 'add1:z:0'"
" input: 'read2:value:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'sum'"
" value: 'add2:z:0'"
" }",
&def));
return def.SerializeAsString();
}
void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(op, var_handle, status);
TFE_TensorHandle* is_initialized[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(op, &is_initialized[0], &num_retvals, status);
CHECK_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
bool initialized = false;
memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t));
EXPECT_EQ(initialized, true);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(is_initialized[0]);
TFE_DeleteOp(op);
delete status;
}
void TestFunctionWithPackedInput(const bool remote) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
// Create one variable per task.
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
// Add a sync point in order to make sure that variables have been initialized
// before the function execution starts.
// TODO(b/155789951): Remove once b/155789951 is fixed.
VarIsInitialized(ctx, h1);
VarIsInitialized(ctx, h2);
// Pack 3 variable handles into one TFE_TensorHandle.
int num_replicas = 3;
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
TFE_TensorHandle* packed_handle =
TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);
const string composite_device_name =
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
composite_device_name);
EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
composite_device_name);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// Register and run a function which returns the sum of 3 variables.
const string function_def = AddVariablesFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, packed_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
if (remote) {
TFE_OpSetDevice(func, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TFE_DeleteOp(func);
TFE_DeleteTensorHandle(packed_handle);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
float sum = 0;
EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(sum, 6.0);
TFE_DeleteTensorHandle(h0);
TFE_DeleteTensorHandle(h1);
TFE_DeleteTensorHandle(h2);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}
TEST(CAPI, TestLocalFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/false);
}
TEST(CAPI, TestRemoteFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/true);
}
string VariableAddFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'VariableAddFunction'"
" input_arg {"
" name: 'var0'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'var0_value'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read0:value:0'"
" device: '/job:localhost/task:1/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'identity'"
" op: 'Identity'"
" input: 'add:z:0'"
" device: '/job:localhost/task:0/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'var0_value'"
" value: 'identity:output:0'"
" }",
&def));
return def.SerializeAsString();
}
class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
public:
FunctionErrorInjectionPass(string error_node, string error_device)
: error_node_(error_node), error_device_(error_device) {}
tensorflow::Status Run(const tensorflow::DeviceSet& device_set,
const tensorflow::ConfigProto& config_proto,
std::unique_ptr<tensorflow::Graph>* graph,
tensorflow::FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) override {
// Inject a failure into function instantiation when we find a node that contains
// the given node name (error_node_) and requested device (error_device_).
for (const auto node : graph->get()->nodes()) {
if (node->name().find(error_node_) != string::npos &&
node->requested_device() == error_device_) {
return tensorflow::errors::Internal("Injected graph pass error.");
}
}
return tensorflow::Status::OK();
}
private:
const string error_node_;
const string error_device_;
};
void TestDistributedFunctionCancellation(bool inject_error) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
if (inject_error) {
// Inject a function optimization pass failure when it sees the 'read0' op
// having a requested device `dev2_name`. During execution:
// * task:0 processes the main function `VariableAddFunction` and places
// the read0 op on task:2
// * task:0 partitions the main function with a subgraph containing read0
// sent to task:2
// * task:2 graph pass reports an error when it sees read0 with dev2_name
tensorflow::function_optimization_registration::
FunctionOptimizationPassRegistration register_test_pass(
std::make_unique<FunctionErrorInjectionPass>("read0", dev2_name));
}
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle, nullptr);
const string function_def = VariableAddFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, var_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);
if (inject_error) {
ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
} else {
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
float sum = 0;
ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
ASSERT_EQ(sum, 4.0);
}
TFE_DeleteOp(func);
TFE_DeleteTensorHandle(var_handle);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}
TEST(CAPI, DistributedFunctionNoError) {
TestDistributedFunctionCancellation(false);
}
TEST(CAPI, DistributedFunctionCancelledOnError) {
TestDistributedFunctionCancellation(true);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
@ -381,150 +824,4 @@ TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
}
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
const std::vector<float>& expected_values) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
std::unique_ptr<float[]> actual_values(new float[expected_values.size()]);
EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t));
memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
for (int i = 0; i < expected_values.size(); i++) {
EXPECT_EQ(expected_values[i], actual_values[i])
<< "Mismatch in expected values at (zero-based) index " << i;
}
}
void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
const char* remote_device_name,
const char* local_device_name) {
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
TFE_OpSetDevice(matmul, remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* retval_task0 =
TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22});
TFE_DeleteTensorHandle(retval_task0);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteOp(matmul);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
TF_DeleteStatus(status);
}
void TestRemoteExecuteChangeServerDef(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
const char local_device_name[] =
"/job:localhost/replica:0/task:0/device:CPU:0";
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
// Update the server def with a new set of names (worker instead of
// localhost).
tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2);
serialized = updated_server_def.SerializeAsString();
updated_server_def.set_task_index(1);
tensorflow::Status s = tensorflow::GrpcServer::Create(
updated_server_def, tensorflow::Env::Default(), &worker_server);
ASSERT_TRUE(s.ok()) << s.error_message();
ASSERT_TRUE(worker_server->Start().ok());
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Create a new tensor_handle.
TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx);
// Check that copying it to the old remote device (named localhost) fails.
TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Copying and executing on the new remote device works.
const char new_remote_device_name[] =
"/job:worker/replica:0/task:1/device:CPU:0";
const char new_local_device_name[] =
"/job:worker/replica:0/task:0/device:CPU:0";
auto* h0_task1_new = TFE_TensorHandleCopyToDevice(
h0_task0_new, ctx, new_remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(h0_task0_new);
TFE_DeleteTensorHandle(h0_task1_new);
CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
new_local_device_name);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteChangeServerDef) {
TestRemoteExecuteChangeServerDef(false);
}
TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
TestRemoteExecuteChangeServerDef(true);
}
} // namespace

View File

@ -1132,51 +1132,6 @@ void BM_ExecuteFunction(int iters, int async) {
}
BENCHMARK(BM_ExecuteFunction)->Arg(0)->Arg(1);
TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
TF_Status* status) {
// Create the variable handle.
TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op, "shape", {}, 0, status);
TFE_OpSetAttrString(op, "container", "", 0);
TFE_OpSetAttrString(op, "shared_name", "", 0);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
TFE_Execute(op, &var_handle, &num_retvals, status);
TFE_DeleteOp(op);
if (TF_GetCode(status) != TF_OK) return nullptr;
CHECK_EQ(1, num_retvals);
// Assign 'value' to it.
op = TFE_NewOp(ctx, "AssignVariableOp", status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpAddInput(op, var_handle, status);
// Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
value_handle(TFE_NewTensorHandle(t.get(), status),
TFE_DeleteTensorHandle);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op, value_handle.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
num_retvals = 0;
TFE_Execute(op, nullptr, &num_retvals, status);
TFE_DeleteOp(op);
if (TF_GetCode(status) != TF_OK) return nullptr;
CHECK_EQ(0, num_retvals);
return var_handle;
}
TEST(CAPI, Variables) {
// Variables use resource handles, so this is really a test for resource
// tensor handling.
@ -1186,7 +1141,7 @@ TEST(CAPI, Variables) {
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status);
TFE_TensorHandle* var_handle = TestVariable(ctx, 12.0);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
@ -1227,7 +1182,7 @@ void BM_ReadVariable(int iters) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status);
TFE_TensorHandle* var_handle = TestVariable(ctx, 5.0);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
@ -1248,6 +1203,8 @@ void BM_ReadVariable(int iters) {
CHECK_EQ(0, TFE_TensorHandleNumDims(h, status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
h = nullptr;
TFE_OpAddInput(op, var_handle, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
}
tensorflow::testing::StopTiming();
TFE_DeleteOp(op);
@ -1591,15 +1548,11 @@ TEST(CAPI, TestTFE_OpAddAttrs) {
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
// There is currently no API to fetch attributes from an operation, fetching
// happens only as an implementation detail of custom devices.
tensorflow::EagerOperation* operation =
OperationFromInterface(tensorflow::unwrap(var_op));
TFE_OpAttrs attributes{&operation->Attrs()};
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
TFE_OpAddAttrs(copy_op, &attributes);
TFE_OpAddAttrs(copy_op, attributes);
unsigned char is_list = 0;
ASSERT_EQ(TF_ATTR_TYPE,
TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
@ -1631,14 +1584,10 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
// There is currently no API to fetch attributes from an operation, fetching
// happens only as an implementation detail of custom devices.
tensorflow::EagerOperation* operation =
OperationFromInterface(tensorflow::unwrap(var_op));
TFE_OpAttrs attributes{&operation->Attrs()};
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);
TF_Buffer* serialized_attr_values = TF_NewBuffer();
TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status);
TFE_OpAttrsSerialize(attributes, serialized_attr_values, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::NameAttrList name_and_attrs;
ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,

View File

@ -133,6 +133,58 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx) {
return th;
}
TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
const tensorflow::string& device_name) {
TF_Status* status = TF_NewStatus();
// Create the variable handle.
TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op, "shape", {}, 0, status);
TFE_OpSetAttrString(op, "container", "", 0);
TFE_OpSetAttrString(op, "shared_name", "", 0);
if (!device_name.empty()) {
TFE_OpSetDevice(op, device_name.c_str(), status);
}
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
TFE_Execute(op, &var_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_DeleteOp(op);
if (TF_GetCode(status) != TF_OK) return nullptr;
CHECK_EQ(1, num_retvals);
// Assign 'value' to it.
op = TFE_NewOp(ctx, "AssignVariableOp", status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpAddInput(op, var_handle, status);
// Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
value_handle(TFE_NewTensorHandle(t.get(), status),
TFE_DeleteTensorHandle);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op, value_handle.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
num_retvals = 0;
TFE_Execute(op, nullptr, &num_retvals, status);
TFE_DeleteOp(op);
if (TF_GetCode(status) != TF_OK) return nullptr;
CHECK_EQ(0, num_retvals);
TF_DeleteStatus(status);
return var_handle;
}
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
TF_Status* status = TF_NewStatus();

View File

@ -42,6 +42,11 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx);
// Return a tensor handle containing a 3x2 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx);
// Return a variable handle referring to a variable with the given initial value
// on the given device.
TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
const tensorflow::string& device_name = "");
// Return an add op adding `a` and `b`.
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);

View File

@ -17,6 +17,8 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
@ -26,6 +28,51 @@ using tensorflow::string;
using tensorflow::internal::OutputList;
using tensorflow::internal::unwrap;
namespace tensorflow {
namespace internal {
typedef absl::flat_hash_map<std::string, FactoryFunction> FactoriesMap;
static FactoriesMap& GetFactories() {
static FactoriesMap* factories = new FactoriesMap;
return *factories;
}
static const char* default_factory = "<unset>";
void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
assert(((!GetFactories().count(name)) ||
(GetFactories()[name] == factory)) &&
"Duplicate tracing factory registration");
GetFactories()[name] = factory;
}
void SetDefaultTracingEngine(const char* name) { default_factory = name; }
static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
TF_Status* s) {
auto entry = GetFactories().find(default_factory);
if (entry != GetFactories().end()) return entry->second(fn_name, s);
string msg = absl::StrCat(
"No tracing engine factory has been registered with the key '",
default_factory, "' (available: ");
// Ensure deterministic (sorted) order in the error message
std::set<string> factories_sorted;
for (const auto& factory : GetFactories())
factories_sorted.insert(factory.first);
const char* comma = "";
for (const string& factory : factories_sorted) {
msg += comma + factory;
comma = ", ";
}
msg += ")";
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
} // end namespace internal
} // end namespace tensorflow
// =============================================================================
// Public C API entry points
//
@ -36,6 +83,28 @@ using tensorflow::internal::unwrap;
//
// =============================================================================
void TF_SetTracingImplementation(const char* name) {
tensorflow::internal::SetDefaultTracingEngine(name);
}
// Creates a new TensorFlow function: it is an execution context attached to a
// given tracing context.
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) {
return wrap(tensorflow::internal::CreateTracingExecutionContext(fn_name, s));
}
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
TF_OutputList* outputs, TF_Status* s) {
auto* func = wrap(unwrap(ctx)->Finalize(unwrap(outputs), s));
TF_DeleteExecutionContext(ctx);
return func;
}
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
TF_DataType dtype, TF_Status* s) {
return wrap(unwrap(func)->AddParameter(dtype, s));
}
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
@ -58,6 +127,10 @@ int TF_OutputListNumOutputs(TF_OutputList* o) {
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
return wrap(unwrap(o)->outputs[i]);
}
void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
TF_Status* s) {
unwrap(o)->outputs.push_back(unwrap(tensor));
}
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
TF_Status* s) {

View File

@ -49,15 +49,26 @@ typedef struct TF_AbstractOp TF_AbstractOp;
// setting functional attributes of other composite ops e.g. control flow.
typedef struct TF_AbstractFunction TF_AbstractFunction;
// Creates a context for tracing the execution of operations into a function.
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s);
// This allows the client to swap the implementation of the tracing engine.
// Any future call to TF_CreateFunction will use the implementation defined
// here.
void TF_SetTracingImplementation(const char* name);
// Creates a new TensorFlow function. A Function is an execution context, and as
// such it can trace operations through TF_ExecuteOperation. After completing
// tracing, a function can be obtained by TF_FinalizeFunction.
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* status);
// Creates a context for eager execution of operations.
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
TF_Status* s);
void TF_DeleteExecutionContext(TF_ExecutionContext*);
// Add a new parameter to a TensorFlow Function.
// TODO(aminim): what about shape?
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
TF_DataType dtype, TF_Status* s);
// Create an operation suitable to use with the provided context. The operation
// requires its type (e.g. "AddV2") to be set independently.
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
@ -77,19 +88,21 @@ void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
// TF_OutputList holds the list of TF_AbstractTensor that results from executing
// an operation.
// It just lets us not specify the number of outputs of an operation
// beforehand. This forces a memory allocation in the runtime, which is bad, but
// it allows for generic code.
// TODO(aminim): the description above isn't clear with respect to
// TF_OutputListNumOutputs and the current eager implementation which requires
// the number of outputs to be set by the client.
// an operation, or provided to create a function.
// When executing an operation in an eager context, the expected number of
// outputs must be set beforehand with `TF_OutputListSetNumOutputs`.
typedef struct TF_OutputList TF_OutputList;
TF_OutputList* TF_NewOutputList();
void TF_DeleteOutputList(TF_OutputList* o);
void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
// Prepare tracing for the expected number of outputs of an operation.
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, TF_Status*);
// Return the number of outputs in the list.
int TF_OutputListNumOutputs(TF_OutputList* o);
// Return the `i`th output in the list.
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
// Append a tensor at the end of the output list, growing its size by one.
void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
TF_Status*);
// TF_ExecuteOperation will, if in eager mode, execute the operation; if in graph
// mode, it may capture some inputs and then add a node to the graph. The output tensors are
@ -100,13 +113,12 @@ void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_ExecutionContext* ctx, TF_Status* s);
// Creates a new TF_AbstractFunction from the current tracing states in the
// context. The returned TF_GraphToFunction must be deleted by the client.
// context. The provided `ctx` is consumed by this API call and deleted.
// The returned TF_AbstractFunction must be deleted by the client.
// TODO(aminim): clarify the contract on the state of the context after this
// call.
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status);
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
TF_OutputList*, TF_Status*);
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
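// A minimal usage sketch of the tracing API declared above (illustrative only,
// not part of this header): it traces `f(x) = x + x` and finalizes it into a
// TF_AbstractFunction, mirroring the updated unit test in
// c_api_unified_experimental_test.cc. The name `SketchTraceDouble` is
// hypothetical, and error checking/cleanup is elided for brevity.
static TF_AbstractFunction* SketchTraceDouble(TF_Status* s) {
  // Pick a registered tracing engine and start tracing a function named "double".
  TF_SetTracingImplementation("graphdef");
  TF_ExecutionContext* tracing_ctx = TF_CreateFunction("double", s);
  // Add a single float parameter.
  TF_AbstractTensor* x = TF_AddFunctionParameter(tracing_ctx, TF_FLOAT, s);
  // Trace one "Add" node computing x + x.
  TF_AbstractOp* add = TF_NewAbstractOp(tracing_ctx);
  TF_AbstractOpSetOpType(add, "Add", s);
  TF_AbstractOpSetOpName(add, "my_add", s);
  TF_AbstractTensor* inputs[2] = {x, x};
  TF_OutputList* outputs = TF_NewOutputList();
  TF_ExecuteOperation(add, 2, inputs, outputs, tracing_ctx, s);
  TF_DeleteAbstractOp(add);
  // Finalizing consumes and deletes `tracing_ctx`; the returned function can be
  // registered on an eager context with TF_ExecutionContextRegisterFunction.
  TF_AbstractFunction* fn = TF_FinalizeFunction(tracing_ctx, outputs, s);
  TF_DeleteOutputList(outputs);
  return fn;
}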

View File

@ -123,6 +123,17 @@ class EagerContext : public ExecutionContext {
}
}
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Can't add function parameter on an eager context.");
return nullptr;
}
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Can't use finalize function on an eager context.");
return nullptr;
}
void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
auto* func = afunc->GetTfFunction(s);
if (!func) {

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include <vector>
#include "absl/strings/str_cat.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
@ -114,12 +115,14 @@ struct GraphFunction : public AbstractFunction {
static constexpr AbstractFunctionKind kKind = kGraphFunc;
};
// GraphContext wraps a TF_Graph and manages the "execution" of operation, i.e.
// adding them to the graph.
// GraphContext wraps a TF_Graph modeling a single function and manages the
// "execution" of operations, i.e. adding them to the function.
class GraphContext : public ExecutionContext {
public:
GraphContext()
: ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
explicit GraphContext(const char* name)
: ExecutionContext(kKind),
graph_(new TF_Graph(), TF_DeleteGraph),
name_(name) {}
AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context.
@ -136,6 +139,10 @@ class GraphContext : public ExecutionContext {
return;
}
auto* tf_opdesc = graph_op->op_.release();
if (tf_opdesc == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT, "AbstractOp is incomplete.");
return;
}
for (int i = 0; i < num_inputs; ++i) {
auto* graph_tensor = dyncast<GraphTensor>(inputs[i]);
if (!graph_tensor) {
@ -164,24 +171,38 @@ class GraphContext : public ExecutionContext {
}
}
TF_Function* ToFunction(const char* fn_name, int num_inputs,
const GraphTensor* inputs, int num_outputs,
const GraphTensor* outputs, TF_Status* status) const {
std::vector<TF_Output> graph_inputs;
graph_inputs.resize(num_inputs);
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
TF_OperationDescription* opdesc =
TF_NewOperation(graph_.get(), "Placeholder",
absl::StrCat("_input_", inputs_.size()).c_str());
TF_SetAttrType(opdesc, "dtype", dtype);
auto* operation = TF_FinishOperation(opdesc, s);
if (!s->status.ok()) return nullptr;
inputs_.push_back(TF_Output{operation, 0});
return new GraphTensor(inputs_.back(), this);
}
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
std::unique_ptr<GraphFunction> func(new GraphFunction);
std::vector<TF_Output> graph_outputs;
graph_outputs.resize(num_outputs);
for (int i = 0; i < num_inputs; i++) {
graph_inputs[i] = inputs[i].output;
}
for (int i = 0; i < num_outputs; i++) {
graph_outputs[i] = outputs[i].output;
graph_outputs.reserve(outputs->outputs.size());
for (AbstractTensor* abstract_output : outputs->outputs) {
GraphTensor* output = dyncast<GraphTensor>(abstract_output);
if (!output) {
TF_SetStatus(s, TF_UNIMPLEMENTED,
"Returning a non-graph tensor from a function has not "
"been implemented yet.");
return nullptr;
}
graph_outputs.push_back(output->output);
}
return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr,
graph_inputs.size(), graph_inputs.data(),
graph_outputs.size(), graph_outputs.data(),
nullptr, nullptr, fn_name, status);
func->func = TF_GraphToFunction(
graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
if (TF_GetCode(s) != TF_OK) return nullptr;
return func.release();
}
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
@ -195,54 +216,20 @@ class GraphContext : public ExecutionContext {
private:
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
std::vector<TF_Output> inputs_;
const char* name_;
};
// Helper that converts the graph currently held in the context into a function.
static AbstractFunction* ExecutionContextToFunction(
const ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const AbstractTensor* inputs, int num_outputs,
const AbstractTensor* outputs, TF_Status* status) {
auto* graph_ctx = dyncast<const GraphContext>(fn_body);
if (graph_ctx == nullptr) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"fn_body is not a TF_GraphContext.");
return nullptr;
}
auto* graph_inputs = dyncast<const GraphTensor>(inputs);
if (!graph_inputs) {
TF_SetStatus(status, TF_INVALID_ARGUMENT, "inputs aren't GraphTensors.");
return nullptr;
}
auto* graph_outputs = dyncast<const GraphTensor>(outputs);
if (!graph_outputs) {
TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors.");
return nullptr;
}
GraphFunction* func = new GraphFunction;
func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs,
num_outputs, graph_outputs, status);
return func;
static ExecutionContext* GraphTracingFactory(const char* name, TF_Status* s) {
return new GraphContext(name);
}
// Register the tracing implemented in this file as the default tracing engine.
static bool register_tracing = [] {
RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
SetDefaultTracingEngine("graphdef");
return true;
}();
} // namespace internal
} // namespace tensorflow
// =============================================================================
// Public C API entry points
// These are only the entry points specific to the Graph API.
// =============================================================================
using tensorflow::internal::unwrap;
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) {
return wrap(new tensorflow::internal::GraphContext());
}
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status) {
return wrap(ExecutionContextToFunction(unwrap(fn_body), fn_name, num_inputs,
unwrap(inputs), num_outputs,
unwrap(outputs), status));
}

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace internal {
@ -57,7 +58,7 @@ T* dyncast(S source) {
// GraphContext and vice-versa).
class AbstractTensor {
protected:
enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor };
enum AbstractTensorKind { kMlirTensor, kGraphTensor, kEagerTensor };
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
public:
@ -100,7 +101,7 @@ class AbstractFunction {
// on a given context, with the same or different input tensors.
class AbstractOp {
protected:
enum AbstractOpKind { kGraphOp, kEagerOp };
enum AbstractOpKind { kMlirOp, kGraphOp, kEagerOp };
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
public:
@ -128,7 +129,7 @@ class AbstractOp {
// eager implementation or to a graph implementation.
struct ExecutionContext {
protected:
enum ExecutionContextKind { kGraphContext, kEagerContext };
enum ExecutionContextKind { kMlirContext, kGraphContext, kEagerContext };
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
public:
@ -148,6 +149,17 @@ struct ExecutionContext {
// Creates an empty AbstractOperation suitable to use with this context.
virtual AbstractOp* CreateOperation() = 0;
// Add a function parameter and return the corresponding tensor.
// This is only valid with an ExecutionContext obtained from a TracingContext;
// it will always error out with an eager context.
virtual AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) = 0;
// Finalize this context and make a function out of it. The context is in an
// invalid state after this call and must be destroyed.
// This is only valid with an ExecutionContext obtained from a TracingContext;
// it will always error out with an eager context.
virtual AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) = 0;
// Registers a function with this context; after this, the function is
// available to be called/referenced by its name in this context.
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
@ -156,6 +168,11 @@ struct ExecutionContext {
const ExecutionContextKind k;
};
typedef ExecutionContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
void SetDefaultTracingEngine(const char* name);
void RegisterTracingEngineFactory(const ::tensorflow::string& name,
FactoryFunction factory);
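// Illustrative sketch (not part of this header): a hypothetical tracing backend
// would subclass ExecutionContext and register a factory under a name, mirroring
// the "graphdef" registration in c_api_unified_experimental_graph.cc. The names
// `MyTracingContext`, `MyTracerFactory`, and "my_tracer" are hypothetical.
//
//   static ExecutionContext* MyTracerFactory(const char* fn_name, TF_Status* s) {
//     return new MyTracingContext(fn_name);
//   }
//   static bool my_tracer_registered = [] {
//     RegisterTracingEngineFactory("my_tracer", MyTracerFactory);
//     SetDefaultTracingEngine("my_tracer");
//     return true;
//   }();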
// Create utilities to wrap/unwrap: this convert from the C opaque types to the
// C++ implementation, and back.
#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \

View File

@ -29,7 +29,12 @@ using tensorflow::string;
namespace tensorflow {
namespace {
TEST(UnifedCAPI, TestBasicEager) {
class UnifiedCAPI : public ::testing::TestWithParam<const char*> {
protected:
void SetUp() override { TF_SetTracingImplementation(GetParam()); }
};
TEST_P(UnifiedCAPI, TestBasicEager) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -81,33 +86,18 @@ TEST(UnifedCAPI, TestBasicEager) {
TF_DeleteExecutionContext(ctx);
}
TEST(UnifedCAPI, TestBasicGraph) {
TEST_P(UnifiedCAPI, TestBasicGraph) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
// Start a new function / execution context.
string fn_name = "double";
TF_ExecutionContext* graph_ctx =
TF_CreateFunction(fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
auto* placeholder_t =
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_OutputList* placeholder_outputs = TF_NewOutputList();
// Execute.
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
// Delete placeholder op.
TF_DeleteAbstractOp(placeholder_op);
// Build an abstract operation.
auto* add_op = TF_NewAbstractOp(graph_ctx);
@ -123,16 +113,13 @@ TEST(UnifedCAPI, TestBasicGraph) {
// Execute.
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractTensor* output_t = TF_OutputListGet(add_outputs, 0);
// Clean up operation and inputs.
TF_DeleteAbstractOp(add_op);
string fn_name = "double";
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
graph_ctx, fn_name.c_str(), 1, placeholder_t, 1, output_t, status.get());
TF_DeleteAbstractTensor(placeholder_t);
TF_DeleteAbstractTensor(output_t);
TF_AbstractFunction* func =
TF_FinalizeFunction(graph_ctx, add_outputs, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -173,18 +160,161 @@ TEST(UnifedCAPI, TestBasicGraph) {
ASSERT_EQ(*f_value, 4.0);
TF_DeleteOutputList(add_outputs);
TF_DeleteOutputList(placeholder_outputs);
TF_DeleteAbstractOp(fn_op);
TF_DeleteAbstractTensor(input_t);
TF_DeleteAbstractTensor(final_result);
TF_DeleteTensor(f_t);
TF_DeleteAbstractFunction(func);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContext(eager_execution_ctx);
}
TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Status* s = status.get();
// Start a new function / execution context.
string fn_name = "two_adds";
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Create a first "Add" computing `arg0 + arg1`.
TF_AbstractTensor* add_output1;
{
// Build an abstract operation, inputs and output.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(add_op, "my_add1", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg0, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
// Trace the operation now (create a node in the graph).
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(add_op);
// Extract the resulting tensor.
add_output1 = TF_OutputListGet(add_outputs, 0);
TF_DeleteOutputList(add_outputs);
}
// Same with a second "Add" computing `arg1 + arg1`.
TF_AbstractTensor* add_output2;
{
// Build an abstract operation, inputs and output.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(add_op, "my_add2", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg1, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
// Trace the operation now (create a node in the graph).
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(add_op);
// Extract the resulting tensor.
add_output2 = TF_OutputListGet(add_outputs, 0);
TF_DeleteOutputList(add_outputs);
}
// Finalize the function by providing the returned values.
TF_AbstractFunction* func;
{
// We want to return the output of both add operations, so create a new list
// and populate it.
TF_OutputList* func_outputs = TF_NewOutputList();
TF_OutputListPushBack(func_outputs, add_output1, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_OutputListPushBack(func_outputs, add_output2, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
func = TF_FinalizeFunction(graph_ctx, func_outputs, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteOutputList(func_outputs);
}
/**
* We traced so far this function:
*
* def two_adds(a, b):
* my_add1 = a + b
* my_add2 = b + b
* return my_add1, my_add2
*
* Now we will execute this function with an eager context:
*
* output1, output2 = two_adds(2.0, 3.0)
*
* and check that we got 5.0 and 6.0 as results.
*/
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContext* eager_execution_ctx =
TF_NewEagerExecutionContext(opts, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TFE_DeleteContextOptions(opts);
TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Build the abstract op to run the function.
TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Build two abstract input tensors as function arguments.
std::vector<TF_AbstractTensor*> func_args;
{
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(eager_execution_ctx);
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
input_eager = TestScalarTensorHandle(eager_ctx, 3.0f);
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
}
TF_OutputList* func_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(func_outputs, 2, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs,
eager_execution_ctx, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(fn_op);
for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t);
ASSERT_EQ(2, TF_OutputListNumOutputs(func_outputs));
float results[2];
for (int idx = 0; idx < 2; ++idx) {
TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
TFE_TensorHandle* handle = TF_AbstractTensorGetEagerTensor(result, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Tensor* f_t = TFE_TensorHandleResolve(handle, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
results[idx] = *static_cast<float*>(TF_TensorData(f_t));
TF_DeleteTensor(f_t);
}
ASSERT_EQ(results[0], 5.0);
ASSERT_EQ(results[1], 6.0);
for (int idx = 0; idx < 2; ++idx) {
TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
TF_DeleteAbstractTensor(result);
}
TF_DeleteOutputList(func_outputs);
TF_DeleteExecutionContext(eager_execution_ctx);
TF_DeleteAbstractFunction(func);
}
TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -192,18 +322,15 @@ TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
ctx, nullptr, 0, nullptr, 0, nullptr, status.get());
TF_AbstractFunction* func = TF_FinalizeFunction(ctx, nullptr, status.get());
ASSERT_EQ(nullptr, func);
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
TF_DeleteExecutionContext(ctx);
}
TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
@ -221,10 +348,10 @@ TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
TF_DeleteExecutionContext(graph_ctx);
}
TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
@ -242,7 +369,7 @@ TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
TF_DeleteExecutionContext(graph_ctx);
}
TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
TEST_P(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
// Build an Eager context.
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -272,7 +399,8 @@ TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build a Graph context.
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute eager op using graph context.
@ -288,10 +416,11 @@ TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
TF_DeleteExecutionContext(graph_ctx);
}
TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
@ -348,5 +477,8 @@ TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
TF_DeleteExecutionContext(eager_execution_ctx);
}
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
::testing::Values("graphdef", "mlir"));
} // namespace
} // namespace tensorflow

View File

@ -59,6 +59,20 @@ class AbstractContextInterface {
virtual AbstractTensorInterface* CreateTensor(
DataType dtype, absl::Span<const int64> dim_sizes) = 0;
typedef void (*MemoryReleaser)(void* data, size_t len, void* arg);
// Create a tensor instance from the given data buffer and description.
// `memory_releaser` will be called on destruction, and it's responsible for
// cleaning up the underlying buffer. `convert_string` indicates whether it
// has to handle tstring conversion. Expected to be removed once tstring
// migration is done.
virtual AbstractTensorInterface* CreateTensor(DataType dtype,
const int64_t* dims,
int num_dims, void* data,
size_t len, bool convert_string,
MemoryReleaser memory_releaser,
void* memory_releaser_arg) = 0;
// Create a handle to wrap and manage a Tensor
virtual AbstractTensorHandleInterface* CreateLocalHandle(
AbstractTensorInterface* t) = 0;
@ -79,6 +93,17 @@ class AbstractContextInterface {
// List attributes of available devices
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
virtual void ClearCachesAndThreadExecutors() = 0;
// Initialize the step resource container for a training step. This is used
// in the current TF runtime. For tfrt, it is used by the fallback op handler.
virtual void StartStep() = 0;
// Destroy the step resource container for a training step.
virtual void EndStep() = 0;
// Block until all pending nodes are finished.
virtual Status AsyncWait() = 0;
protected:
virtual ~AbstractContextInterface() {}
};
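// Illustrative sketch (not part of this header): a MemoryReleaser matching the
// typedef above, handed to the buffer-based CreateTensor overload so the runtime
// can free the buffer once the tensor is destroyed. `FreeFloatBuffer` and
// `MakeFloatScalar` are hypothetical helper names.
static void FreeFloatBuffer(void* data, size_t len, void* arg) {
  delete[] static_cast<float*>(data);  // `len` and `arg` are unused here.
}
static AbstractTensorInterface* MakeFloatScalar(AbstractContextInterface* ctx) {
  float* buf = new float[1]{42.0f};
  return ctx->CreateTensor(DT_FLOAT, /*dims=*/nullptr, /*num_dims=*/0, buf,
                           sizeof(float), /*convert_string=*/false,
                           FreeFloatBuffer, /*memory_releaser_arg=*/nullptr);
}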

View File

@ -27,6 +27,7 @@ cc_library(
name = "parallel_device",
srcs = [":sources"],
hdrs = [":headers"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
@ -43,6 +44,7 @@ tf_cc_test(
srcs = ["parallel_device_test.cc"],
deps = [
":parallel_device",
":parallel_device_ops",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
@ -52,3 +54,19 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
# Note: ParallelDevice-specific ops are experimental and not currently linked in
# to TensorFlow by default; they are just used in a few tests.
filegroup(
name = "parallel_device_ops_srcs",
srcs = ["parallel_device_ops.cc"],
visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
)
cc_library(
name = "parallel_device_ops",
srcs = [":parallel_device_ops_srcs"],
visibility = ["//tensorflow:internal"],
deps = ["//tensorflow/core:framework"],
alwayslink = 1,
)

View File

@ -92,6 +92,10 @@ class ParallelDevice {
TFE_TensorHandle* tensor,
TF_Status* status) const;
// A parallel tensor with scalar integers numbering component devices.
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
TF_Status* status) const;
// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors (or
@ -208,6 +212,46 @@ std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
status);
}
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
TFE_Context* context, TF_Status* status) const {
// TODO(allenl): We could cache DeviceIDs (keyed by context).
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
int64_t* device_id = new int64_t;
*device_id = device_index;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int64_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int64_t*>(data);
},
nullptr),
TF_DeleteTensor);
// TODO(allenl): Here and when executing regular operations, we could hold
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
// device names repeatedly.
OpPtr const_op(TFE_NewOp(context, "Const", status));
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(device_handle);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
@ -282,6 +326,13 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
}
result.emplace(std::move(outputs));
return result;
} else if (operation_name == std::string("DeviceID")) {
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
result_content.push_back(DeviceIDs(context, status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
maybe_parallel_results(

View File

@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
// TODO(allenl): Figure out if we need this op, and if so whether we should move
// it to core TF. Right now the eager C API does some checking of op
// registrations before calling into custom devices, but we may be able to avoid
// that.
REGISTER_OP("DeviceID")
.Output("device_id: int64")
.SetIsStateful()
.SetShapeFn(tensorflow::shape_inference::ScalarShape);

View File

@ -278,14 +278,15 @@ TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
}
// Expect that `handle` is a scalar equal to `expected_value`.
void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
template <typename value_type>
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(expected_value,
*static_cast<float*>(TF_TensorData(value_zero.get())));
EXPECT_EQ(expected_value,
*static_cast<value_type*>(TF_TensorData(value_zero.get())));
}
template <std::size_t num_devices>
@ -343,8 +344,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(components[0].get(), 20.);
AssertScalarFloatEq(components[1].get(), 20.);
ExpectScalarEq<float>(components[0].get(), 20.);
ExpectScalarEq<float>(components[1].get(), 20.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
@ -373,8 +374,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(components[0].get(), 23.);
AssertScalarFloatEq(components[1].get(), 18.);
ExpectScalarEq<float>(components[0].get(), 23.);
ExpectScalarEq<float>(components[1].get(), 18.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
@ -383,6 +384,32 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Compute the device ID twice and verify the result
for (int i = 0; i < 2; ++i) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetDevice(op.get(), device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, result_handle, &components, status.get());
TFE_DeleteTensorHandle(result_handle);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<int64_t>(components[0].get(), 0);
ExpectScalarEq<int64_t>(components[1].get(), 1);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
}
TEST(PARALLEL_DEVICE, TestBasicCPU) {
@ -498,8 +525,8 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// The value of the original tensor is replicated on each device.
AssertScalarFloatEq(components[0].get(), 3.);
AssertScalarFloatEq(components[1].get(), 3.);
ExpectScalarEq<float>(components[0].get(), 3.);
ExpectScalarEq<float>(components[1].get(), 3.);
// Verify that the mirrors are placed on the component devices.
std::string first_device =
@ -630,7 +657,7 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
&second_components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(second_components[1].get(), 9.);
ExpectScalarEq<float>(second_components[1].get(), 9.);
// Verify that the mirrors are placed on the component devices.
std::string first_device = TFE_TensorHandleBackingDeviceName(
@ -644,8 +671,8 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
std::array<TensorHandlePtr, 2> first_components;
ExtractPerDeviceValues(context.get(), second_components[0].get(),
&first_components, status.get());
AssertScalarFloatEq(first_components[0].get(), 3.);
AssertScalarFloatEq(first_components[1].get(), 6.);
ExpectScalarEq<float>(first_components[0].get(), 3.);
ExpectScalarEq<float>(first_components[1].get(), 6.);
first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
status.get());
@ -806,8 +833,8 @@ TEST(PARALLEL_DEVICE, TestCollective) {
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(result_components[0].get(), 3.);
AssertScalarFloatEq(result_components[1].get(), 3.);
ExpectScalarEq<float>(result_components[0].get(), 3.);
ExpectScalarEq<float>(result_components[1].get(), 3.);
}
void RegisterCollectiveMulFunction(TFE_Context* context,
@ -909,8 +936,8 @@ TEST(PARALLEL_DEVICE, TestFunction) {
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(result_components[0].get(), 7. * 9.);
AssertScalarFloatEq(result_components[1].get(), 7. * 9.);
ExpectScalarEq<float>(result_components[0].get(), 7. * 9.);
ExpectScalarEq<float>(result_components[1].get(), 7. * 9.);
std::string first_device = TFE_TensorHandleBackingDeviceName(
result_components[0].get(), status.get());

View File

@ -15,33 +15,21 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <vector>
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/framework/attr_value.pb.h"
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
// that sometimes do not require serialization.
typedef struct TFE_OpAttrs TFE_OpAttrs;
typedef struct TFE_Context TFE_Context;
typedef struct TFE_Op TFE_Op;
struct TFE_OpAttrs {
explicit TFE_OpAttrs() : attributes(nullptr) {}
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value)
: attributes(value) {}
const tensorflow::AttrBuilder* attributes;
};
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AttrBuilder, TFE_OpAttrs);
// Set an AttrValue on the op. Doesn't handle the list types.
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value,

View File

@ -85,17 +85,36 @@ class ModularFileSystemTest : public ::testing::TestWithParam<std::string> {
const std::string test_name = tensorflow::str_util::StringReplace(
::testing::UnitTest::GetInstance()->current_test_info()->name(), "/",
"_", /*replace_all=*/true);
root_dir_ = tensorflow::io::JoinPath(
::testing::TempDir(),
tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name));
if (!cloud_path_.empty()) {
// We have to join path for non-local filesystem manually to make sure
// that this test will run on Windows since `tensorflow::io::JoinPath`
// behaves differently on Windows. `tmp_dir` should be something like
// `path/to/tmp/dir/`. After joining path, we will have
// `/path/to/tmp/dir/tf_fs_rng_name/`
root_dir_ = tensorflow::strings::StrCat(
"/", tmp_dir_,
tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name), "/");
} else {
root_dir_ = tensorflow::io::JoinPath(
tmp_dir_,
tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name));
}
if (!GetParam().empty()) {
root_dir_ = tensorflow::strings::StrCat(GetParam(), "://", cloud_path_,
root_dir_);
}
env_ = Env::Default();
}
void SetUp() override {
if (mkdir(root_dir_.c_str(), 0755) != 0) {
int error_code = errno;
GTEST_SKIP() << "Cannot create working directory: "
<< tensorflow::IOError(root_dir_, error_code);
FileSystem* fs = nullptr;
Status s = env_->GetFileSystemForFile(root_dir_, &fs);
if (fs == nullptr || !s.ok())
GTEST_SKIP() << "No filesystem registered: " << s;
s = fs->CreateDir(root_dir_);
if (!s.ok()) {
GTEST_SKIP() << "Cannot create working directory: " << s;
}
}
@ -115,9 +134,10 @@ class ModularFileSystemTest : public ::testing::TestWithParam<std::string> {
std::string GetURIForPath(StringPiece path) {
const std::string translated_name =
tensorflow::io::JoinPath(root_dir_, path);
if (GetParam().empty()) return translated_name;
return tensorflow::strings::StrCat(GetParam(), "://", translated_name);
// We have already checked `GetParam().empty()` in
// `ModularFileSystemTest()`. root_dir_ should contain `GetParam() + "://"`
// if it isn't empty.
return translated_name;
}
// Converts absolute paths to paths relative to root_dir_.
@ -133,15 +153,28 @@ class ModularFileSystemTest : public ::testing::TestWithParam<std::string> {
rng_val_ = distribution(gen);
}
static void SetCloudPath(const std::string& cloud_path) {
cloud_path_ = cloud_path;
if (cloud_path_.back() == '/') cloud_path_.pop_back();
}
static void SetTmpDir(const std::string& tmp_dir) {
tmp_dir_ = tmp_dir.empty() ? ::testing::TempDir() : tmp_dir;
}
protected:
Env* env_;
private:
std::string root_dir_;
static int rng_val_;
static std::string cloud_path_;
static std::string tmp_dir_;
};
int ModularFileSystemTest::rng_val_;
std::string ModularFileSystemTest::cloud_path_;
std::string ModularFileSystemTest::tmp_dir_;
// As some of the implementations might be missing, the tests should still pass
// if the returned `Status` signals the unimplemented state.
@ -1729,6 +1762,20 @@ static bool GetURIScheme(const std::string& scheme) {
return true;
}
// This function is used for cloud filesystems.
// `S3` and `GCS` require `root_dir_` to contain a bucket name, and `HDFS`
// requires `root_dir_` to contain a namenode.
// `root_dir_ = scheme + "://" + cloud_path_ + root_dir_`
static bool SetCloudPath(const std::string& cloud_path_) {
ModularFileSystemTest::SetCloudPath(cloud_path_);
return true;
}
static bool SetTmpDir(const std::string& tmp_dir_) {
ModularFileSystemTest::SetTmpDir(tmp_dir_);
return true;
}
} // namespace
} // namespace tensorflow
@ -1741,7 +1788,12 @@ GTEST_API_ int main(int argc, char** argv) {
tensorflow::Flag("dso", tensorflow::LoadDSO, "",
"Path to shared object to load"),
tensorflow::Flag("scheme", tensorflow::GetURIScheme, "",
"URI scheme to test")};
"URI scheme to test"),
tensorflow::Flag("cloud_path", tensorflow::SetCloudPath, "",
"Path for cloud filesystem (namenode for hdfs, "
"bucketname for s3/gcs)"),
tensorflow::Flag("tmp_dir", tensorflow::SetTmpDir, "",
"Temporary directory to store test data.")};
if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) {
std::cout << tensorflow::Flags::Usage(argv[0], flag_list);
return -1;
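For reference, a minimal sketch of how the final test root ends up being assembled when a scheme and cloud path are supplied. This is illustrative only; the scheme, bucket, directory, and test names below are placeholders, not values from the test binary.
// Illustrative only; mirrors the string handling in the fixture constructor above.
#include <iostream>
#include <string>

int main() {
  std::string scheme = "s3";                  // hypothetical --scheme value
  std::string cloud_path = "my-bucket";       // hypothetical --cloud_path (trailing '/' stripped)
  std::string tmp_dir = "path/to/tmp/dir/";   // hypothetical --tmp_dir value
  std::string test_name = "SomeTest";
  int rng_val = 42;
  // Manual join used for non-local filesystems:
  std::string root_dir =
      "/" + tmp_dir + "tf_fs_" + std::to_string(rng_val) + "_" + test_name + "/";
  // Final URI handed to the filesystem plugin:
  std::string uri = scheme + "://" + cloud_path + root_dir;
  std::cout << uri << "\n";  // s3://my-bucket/path/to/tmp/dir/tf_fs_42_SomeTest/
  return 0;
}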

View File

@ -108,7 +108,7 @@ class CServerFactory : public ServerFactory {
delete_function_(delete_function),
rendezvous_builder_(rendezvous_builder) {}
Status NewServer(const ServerDef& server_def,
Status NewServer(const ServerDef& server_def, const Options& options,
std::unique_ptr<ServerInterface>* out_server) override {
TF_RETURN_IF_ERROR(CGrpcServer::Create(
server_def, init_function_, start_function_, stop_function_,

View File

@ -31,9 +31,6 @@ cc_library(
"//tensorflow/c/experimental/saved_model/public:concrete_function.h",
],
copts = tf_copts(),
# TODO(bmzhao): Remove this as we refactor C API to granular targets,
# so that we can depend on c/eager/c_api_unified_experimental.h.
features = ["-layering_check"],
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
@ -41,6 +38,8 @@ cc_library(
":concrete_function_type",
":function_metadata",
":function_metadata_type",
":tensorhandle_list",
":tensorhandle_list_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal",
@ -156,10 +155,43 @@ cc_library(
"saved_model_api_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
],
)
cc_library(
name = "tensorhandle_list",
srcs = [
"tensorhandle_list.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:tensorhandle_list.h",
],
copts = tf_copts(),
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":tensorhandle_list_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:tensor_handle_interface",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
],
)
cc_library(
name = "tensorhandle_list_type",
hdrs = [
"tensorhandle_list_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/eager:tensor_handle_interface",
],
)
tf_cc_test(
name = "saved_model_api_test",
size = "small",

View File

@ -15,12 +15,12 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
extern "C" {
@ -29,10 +29,9 @@ TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
&tensorflow::unwrap(func)->GetFunctionMetadata()));
}
TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) {
// TODO(bmzhao): Refactor TF_OutputList struct definition into a separate
// internal header, and implement this function.
return nullptr;
const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures(
TF_ConcreteFunction* func) {
return tensorflow::wrap(&tensorflow::unwrap(func)->GetCaptures());
}
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) {

View File

@ -41,7 +41,7 @@ TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
if (!status->status.ok()) {
return nullptr;
}
return new TF_SavedModel{std::move(result)};
return tensorflow::wrap(result.release());
}
TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
@ -60,17 +60,19 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
if (!status->status.ok()) {
return nullptr;
}
return new TF_SavedModel{std::move(result)};
return tensorflow::wrap(result.release());
}
void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; }
void TF_DeleteSavedModel(TF_SavedModel* model) {
delete tensorflow::unwrap(model);
}
TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
char* function_path,
const char* function_path,
TF_Status* status) {
tensorflow::ConcreteFunction* result = nullptr;
tensorflow::Status get_function_status =
model->saved_model->GetFunction(function_path, &result);
tensorflow::unwrap(model)->GetFunction(function_path, &result);
status->status.Update(get_function_status);
if (!get_function_status.ok()) {
return nullptr;
@ -79,10 +81,11 @@ TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
}
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
TF_SavedModel* model, char* signature_def_key, TF_Status* status) {
TF_SavedModel* model, const char* signature_def_key, TF_Status* status) {
tensorflow::ConcreteFunction* result = nullptr;
tensorflow::Status get_function_status =
model->saved_model->GetSignatureDefFunction(signature_def_key, &result);
tensorflow::unwrap(model)->GetSignatureDefFunction(signature_def_key,
&result);
status->status.Update(get_function_status);
if (!get_function_status.ok()) {
return nullptr;
@ -91,7 +94,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
}
TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
return new TF_ConcreteFunctionList{model->saved_model->ListFunctions()};
return new TF_ConcreteFunctionList{
tensorflow::unwrap(model)->ListFunctions()};
}
} // end extern "C"

View File

@ -18,13 +18,18 @@ limitations under the License.
#include <memory>
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
// Internal structures used by the SavedModel C API. These are likely to change
// and should not be depended on.
struct TF_SavedModel {
std::unique_ptr<tensorflow::SavedModelAPI> saved_model;
};
typedef struct TF_SavedModel TF_SavedModel;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::SavedModelAPI, TF_SavedModel)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
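The DEFINE_CONVERSION_FUNCTIONS macro from conversion_macros.h is what lets the .cc files above call tensorflow::wrap and tensorflow::unwrap on TF_SavedModel. A rough sketch of the pair of functions it provides for this type, assuming the usual reinterpret_cast-based conversion (the real macro expansion may differ in detail):
// Sketch only -- not the actual macro expansion.
namespace tensorflow {
inline TF_SavedModel* wrap(SavedModelAPI* p) {
  return reinterpret_cast<TF_SavedModel*>(p);
}
inline SavedModelAPI* unwrap(TF_SavedModel* p) {
  return reinterpret_cast<SavedModelAPI*>(p);
}
}  // namespace tensorflow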

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
#include <stddef.h>
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
extern "C" {
size_t TF_TensorHandleListSize(const TF_TensorHandleList* list) {
return tensorflow::unwrap(list)->size();
}
TFE_TensorHandle* TF_TensorHandleListGet(const TF_TensorHandleList* list,
int i) {
return tensorflow::wrap((*tensorflow::unwrap(list))[i]);
}
} // end extern "C"

View File

@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSORHANDLE_LIST_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSORHANDLE_LIST_TYPE_H_
#include <vector>
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
// Internal structures used by the SavedModel C API. These are likely to
// change and should not be depended on.
typedef struct TF_TensorHandleList TF_TensorHandleList;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(
std::vector<tensorflow::AbstractTensorHandleInterface*>,
TF_TensorHandleList)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSORHANDLE_LIST_TYPE_H_

View File

@ -24,6 +24,7 @@ exports_files(
"concrete_function_list.h",
"function_metadata.h",
"saved_model_api.h",
"tensorhandle_list.h",
],
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
)
@ -39,6 +40,7 @@ cc_library(
":concrete_function_list",
":function_metadata",
":saved_model_api",
":tensorhandle_list",
],
)
@ -61,3 +63,8 @@ alias(
name = "saved_model_api",
actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api",
)
alias(
name = "tensorhandle_list",
actual = "//tensorflow/c/experimental/saved_model/internal:tensorhandle_list",
)

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
// IWYU pragma: end_exports
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_

View File

@ -17,9 +17,9 @@ limitations under the License.
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
#ifdef __cplusplus
extern "C" {
@ -36,7 +36,7 @@ TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
TF_ConcreteFunction* func);
// Returns a list of TensorHandles implicitly captured by this function.
TF_CAPI_EXPORT extern TF_OutputList* TF_ConcreteFunctionGetCaptures(
TF_CAPI_EXPORT extern const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures(
TF_ConcreteFunction* func);
// Returns a TFE_Op suitable for executing this function.

View File

@ -21,19 +21,27 @@ limitations under the License.
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type that acts like a list of TF_ConcreteFunction pointers.
typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList;
// Returns the size of `list`.
TF_CAPI_EXPORT size_t
TF_ConcreteFunctionListSize(TF_ConcreteFunctionList* list);
TF_CAPI_EXPORT extern size_t TF_ConcreteFunctionListSize(
TF_ConcreteFunctionList* list);
// Returns the `i`th TF_ConcreteFunction in the list.
TF_CAPI_EXPORT TF_ConcreteFunction* TF_ConcreteFunctionListGet(
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_ConcreteFunctionListGet(
TF_ConcreteFunctionList* list, int i);
// Deletes `list`.
TF_CAPI_EXPORT void TF_DeleteConcreteFunctionList(
TF_CAPI_EXPORT extern void TF_DeleteConcreteFunctionList(
TF_ConcreteFunctionList* list);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
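A short usage sketch for this list type, assuming `model` is a TF_SavedModel* that has already been loaded (see saved_model_api.h below):
// Sketch: enumerate all concrete functions in a loaded SavedModel.
TF_ConcreteFunctionList* functions = TF_ListSavedModelFunctions(model);
for (size_t i = 0; i < TF_ConcreteFunctionListSize(functions); ++i) {
  TF_ConcreteFunction* fn =
      TF_ConcreteFunctionListGet(functions, static_cast<int>(i));
  TF_FunctionMetadata* metadata = TF_ConcreteFunctionGetMetadata(fn);
  (void)metadata;  // inspect metadata as needed
}
TF_DeleteConcreteFunctionList(functions);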

View File

@ -80,7 +80,7 @@ TF_CAPI_EXPORT extern void TF_DeleteSavedModel(TF_SavedModel* model);
// "conceptually" bound to `model`. Once `model` is deleted, all
// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(
TF_SavedModel* model, char* function_path, TF_Status* status);
TF_SavedModel* model, const char* function_path, TF_Status* status);
// Retrieve a function from the TF SavedModel via a SignatureDef key.
//
@ -94,7 +94,7 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(
// TF_ConcreteFunction instance. Once `model` is deleted, all
// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
TF_SavedModel* model, char* signature_def_key, TF_Status* status);
TF_SavedModel* model, const char* signature_def_key, TF_Status* status);
// Returns a list of all ConcreteFunctions stored in this SavedModel.
// The lifetime of the returned list is bound to `model`.
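End-to-end, the updated const char* signatures are used roughly as sketched below; the SavedModel path and function name are placeholders, not values from this change.
// Sketch: load a SavedModel and look up a function by path.
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
TFE_DeleteContextOptions(opts);
TF_SavedModel* model = TF_LoadSavedModel("/path/to/saved_model", ctx, status);
if (TF_GetCode(status) == TF_OK) {
  TF_ConcreteFunction* fn =
      TF_GetSavedModelConcreteFunction(model, "my_function", status);
  (void)fn;  // e.g. fetch its call op or captures
}
TF_DeleteSavedModel(model);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);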

View File

@ -0,0 +1,43 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
#include <stddef.h>
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/eager/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type that acts like a list of TFE_TensorHandle pointers.
typedef struct TF_TensorHandleList TF_TensorHandleList;
// Returns the size of `list`.
TF_CAPI_EXPORT extern size_t TF_TensorHandleListSize(
const TF_TensorHandleList* list);
// Returns the `i`th TFE_TensorHandle in the list.
TF_CAPI_EXPORT extern TFE_TensorHandle* TF_TensorHandleListGet(
const TF_TensorHandleList* list, int i);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
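A minimal usage sketch for this header, assuming `func` is a TF_ConcreteFunction* obtained from a loaded SavedModel:
// Sketch: walk the tensor handles captured by a concrete function.
const TF_TensorHandleList* captures = TF_ConcreteFunctionGetCaptures(func);
for (size_t i = 0; i < TF_TensorHandleListSize(captures); ++i) {
  TFE_TensorHandle* handle =
      TF_TensorHandleListGet(captures, static_cast<int>(i));
  (void)handle;  // e.g. inspect its device or resolve it to a tensor
}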

View File

@ -178,7 +178,7 @@ cc_library_with_android_deps(
name = "ops",
srcs = ["framework/ops.cc"],
hdrs = ["framework/ops.h"],
android_deps = ["//tensorflow/core:android_tensorflow_lib"],
android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
deps = [
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
@ -197,7 +197,7 @@ cc_library_with_android_deps(
"framework/scope_internal.h",
],
hdrs = ["framework/scope.h"],
android_deps = ["//tensorflow/core:android_tensorflow_lib"],
android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
common_deps = [
":ops",
],
@ -237,7 +237,7 @@ cc_library_with_android_deps(
name = "client_session",
srcs = ["client/client_session.cc"],
hdrs = ["client/client_session.h"],
android_deps = ["//tensorflow/core:android_tensorflow_lib"],
android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
common_deps = [
":ops",
":scope",
@ -275,7 +275,7 @@ cc_library_with_android_deps(
srcs = ["ops/const_op.cc"],
hdrs = ["ops/const_op.h"],
android_deps = [
"//tensorflow/core:android_tensorflow_lib",
"//tensorflow/core:portable_tensorflow_lib",
],
common_deps = [
":ops",
@ -304,7 +304,7 @@ cc_library_with_android_deps(
srcs = ["ops/while_loop.cc"],
hdrs = ["ops/while_loop.h"],
android_deps = [
"//tensorflow/core:android_tensorflow_lib",
"//tensorflow/core:portable_tensorflow_lib",
],
common_deps = [
":cc_ops",

View File

@ -0,0 +1,78 @@
# Experimental C++ APIs for TensorFlow.
# New TF C++ APIs under the tensorflow::cc namespace aim to guarantee ABI stability.
# Users are expected to compile against public c++ headers, and link against
# libtensorflow (https://www.tensorflow.org/install/lang_c).
# We aim to achieve ABI stability in new C++ APIs by only using types
# on the API surface that:
# 1. Have a header-only implementation
# 2. Are std:: types
# 3. Wrap an opaque C type
package(
# This is intentionally public
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "runtime",
hdrs = [
"runtime.h",
],
deps = [
":status",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
],
)
cc_library(
name = "runtime_builder",
hdrs = [
"runtime_builder.h",
],
deps = [
":runtime",
":status",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
],
)
cc_library(
name = "status",
hdrs = [
"status.h",
],
deps = [
"//tensorflow/c:tf_status",
],
)
cc_library(
name = "tensor",
hdrs = [
"tensor.h",
],
deps = [
":status",
"//tensorflow/c:tf_datatype",
"//tensorflow/c:tf_tensor",
],
)
cc_library(
name = "tensorhandle",
hdrs = [
"tensorhandle.h",
],
deps = [
":runtime",
":status",
":tensor",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
],
)

View File

@ -0,0 +1,71 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
#include <memory>
#include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// Runtime represents an opaque instance of a Tensorflow runtime, with its own
// resources, threadpools, etc. Clients are expected to construct a Runtime
// object through tensorflow::cc::RuntimeBuilder::Build, after setting any
// relevant configuration options. Many Tensorflow functions take a reference to
// the runtime as an argument (e.g. tensorflow::cc::SavedModelAPI::Load), and
// may have different implementations depending on the runtime. For many of
// these Runtime-attached objects (such as tensorflow::cc::TensorHandle), the
// Runtime must outlive these objects.
class Runtime {
public:
// Runtime is movable, but not copyable.
Runtime(Runtime&&) = default;
Runtime& operator=(Runtime&&) = default;
private:
friend class RuntimeBuilder;
friend class SavedModelAPI;
friend class TensorHandle;
// Wraps a TFE_Context. Takes ownership of ctx.
explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {}
// Deletes the currently wrapped TFE_Context, swaps it with ctx,
// and takes ownership of ctx.
void Reset(TFE_Context* ctx) { ctx_.reset(ctx); }
// Returns the TFE_Context that this object wraps. This object
// retains ownership of the pointer.
TFE_Context* GetTFEContext() const { return ctx_.get(); }
// Runtime is not copyable
Runtime(const Runtime&) = delete;
Runtime& operator=(const Runtime&) = delete;
struct TFEContextDeleter {
void operator()(TFE_Context* p) const { TFE_DeleteContext(p); }
};
std::unique_ptr<TFE_Context, TFEContextDeleter> ctx_;
};
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_

View File

@ -0,0 +1,86 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
#include <memory>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/status.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
// Use this to set configuration options, like threadpool size, etc.
class RuntimeBuilder {
public:
RuntimeBuilder() : options_(TFE_NewContextOptions()) {}
// If `use_tfrt` is true, we will use the new Tensorflow Runtime
// (https://blog.tensorflow.org/2020/04/tfrt-new-tensorflow-runtime.html) as
// our runtime implementation.
RuntimeBuilder& SetUseTFRT(bool use_tfrt);
// Build a Tensorflow Runtime.
//
// Params:
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a
// unique_ptr<tensorflow::cc::Runtime>.
std::unique_ptr<Runtime> Build(Status* status);
// RuntimeBuilder is movable, but not copyable.
RuntimeBuilder(RuntimeBuilder&&) = default;
RuntimeBuilder& operator=(RuntimeBuilder&&) = default;
private:
// RuntimeBuilder is not copyable
RuntimeBuilder(const RuntimeBuilder&) = delete;
RuntimeBuilder& operator=(const RuntimeBuilder&) = delete;
struct TFEContextOptionsDeleter {
void operator()(TFE_ContextOptions* p) const {
TFE_DeleteContextOptions(p);
}
};
std::unique_ptr<TFE_ContextOptions, TFEContextOptionsDeleter> options_;
};
inline RuntimeBuilder& RuntimeBuilder::SetUseTFRT(bool use_tfrt) {
TFE_ContextOptionsSetTfrt(options_.get(), use_tfrt);
return *this;
}
inline std::unique_ptr<Runtime> RuntimeBuilder::Build(Status* status) {
TFE_Context* result = TFE_NewContext(options_.get(), status->GetTFStatus());
if (!status->ok()) {
return nullptr;
}
// We can't use std::make_unique here because of its interaction with a
// private constructor: https://abseil.io/tips/134
return std::unique_ptr<Runtime>(new Runtime(result));
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
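A minimal sketch of the intended usage of RuntimeBuilder, with error handling trimmed; the helper name below is hypothetical.
#include <memory>

#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
#include "tensorflow/cc/experimental/base/public/status.h"

using tensorflow::experimental::cc::Runtime;
using tensorflow::experimental::cc::RuntimeBuilder;
using tensorflow::experimental::cc::Status;

// Hypothetical helper: build a Runtime with default options.
std::unique_ptr<Runtime> BuildDefaultRuntime() {
  Status status;
  RuntimeBuilder builder;
  builder.SetUseTFRT(false);  // stick with the classic runtime
  std::unique_ptr<Runtime> runtime = builder.Build(&status);
  if (!status.ok()) return nullptr;
  return runtime;
}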

View File

@ -0,0 +1,96 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
#include <memory>
#include <string>
#include "tensorflow/c/tf_status.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// Status is a wrapper around an error code and an optional error message.
// The set of error codes are defined here:
// https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/c/tf_status.h#L39-L60
// Many Tensorflow APIs return a Status, or take a Status as an out parameter.
// Clients should check for status.ok() after calling these APIs, and either
// handle or propagate the error appropriately.
// TODO(bmzhao): Add a detailed code example before moving out of experimental.
class Status {
public:
// Create a success status
Status() : status_(TF_NewStatus()) {}
// Return the status code
TF_Code code() const;
// Returns the error message in Status.
std::string message() const;
// Returns true if the status is OK, i.e. the error code is TF_OK.
bool ok() const;
// Record <code, msg> in Status. Any previous information is lost.
// A common use is to clear a status: SetStatus(TF_OK, "");
void SetStatus(TF_Code code, const std::string& msg);
// Status is movable, but not copyable.
Status(Status&&) = default;
Status& operator=(Status&&) = default;
private:
friend class RuntimeBuilder;
friend class Runtime;
friend class SavedModelAPI;
friend class TensorHandle;
// Wraps a TF_Status*, and takes ownership of it.
explicit Status(TF_Status* status) : status_(status) {}
// Status is not copyable
Status(const Status&) = delete;
Status& operator=(const Status&) = delete;
// Returns the TF_Status that this object wraps. This object
// retains ownership of the pointer.
TF_Status* GetTFStatus() const { return status_.get(); }
struct TFStatusDeleter {
void operator()(TF_Status* p) const { TF_DeleteStatus(p); }
};
std::unique_ptr<TF_Status, TFStatusDeleter> status_;
};
inline TF_Code Status::code() const { return TF_GetCode(status_.get()); }
inline std::string Status::message() const {
return std::string(TF_Message(status_.get()));
}
inline bool Status::ok() const { return code() == TF_OK; }
inline void Status::SetStatus(TF_Code code, const std::string& msg) {
TF_SetStatus(status_.get(), code, msg.c_str());
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
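For illustration, a small sketch of how a caller typically interacts with Status; the Validate helper is hypothetical.
#include "tensorflow/cc/experimental/base/public/status.h"

using tensorflow::experimental::cc::Status;

// Hypothetical helper that reports a failure through an out-parameter Status.
void Validate(int value, Status* status) {
  if (value < 0) {
    status->SetStatus(TF_INVALID_ARGUMENT, "value must be non-negative");
    return;
  }
  status->SetStatus(TF_OK, "");  // clear any previous error
}

// Caller side:
//   Status status;
//   Validate(-1, &status);
//   if (!status.ok()) { /* log status.message() and bail out */ }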

View File

@ -0,0 +1,175 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
#include <stddef.h>
#include <stdint.h>
#include <functional>
#include <memory>
#include <vector>
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/cc/experimental/base/public/status.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// Tensor represents an n-dimensional array of values.
class Tensor {
public:
using DeleterCallback = std::function<void(void*, size_t)>;
// Constructs a Tensor from user provided buffer.
//
// Params:
// dtype - The dtype of the tensor's data.
// shape - A shape vector, where each element corresponds to the size of
// the tensor's corresponding dimension.
// data - Pointer to a buffer of memory to construct a Tensor out of.
// len - The length (in bytes) of `data`
// deleter - A std::function to be called when the Tensor no longer needs the
// memory in `data`. This can be used to free `data`, or
// perhaps decrement a refcount associated with `data`, etc.
// status - Set to OK on success and an error on failure.
// Returns:
// If an error occurred, status->ok() will be false, and the returned
// Tensor must not be used.
// TODO(bmzhao): Add Runtime as an argument to this function so we can swap to
// a TFRT backed tensor.
// TODO(bmzhao): Add benchmarks on overhead for this function; we can
// consider using int64_t* + length rather than vector.
static Tensor FromBuffer(TF_DataType dtype, const std::vector<int64_t>& shape,
void* data, size_t len, DeleterCallback deleter,
Status* status);
// TODO(bmzhao): In the case we construct a tensor from non-owned memory,
// we should offer a way to deep copy the tensor into a new tensor, which
// owns the underlying memory. This could be a .deepcopy()/clone() method.
// TODO(bmzhao): In the future, we want to relax the non-copyability
// constraint. To do so, we can add a C API function that acts like
// CopyFrom:
// https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/core/framework/tensor.h#L301-L311
// Tensor is movable, but not copyable
Tensor(Tensor&&) = default;
Tensor& operator=(Tensor&&) = default;
// Returns the number of dimensions in the tensor. Can be -1, which represents
// unknown rank.
int dims() const;
// Returns the number of elements in dimension `d`.
// REQUIRES: `0 <= d < dims()`
int64_t dim_size(int d) const;
// Returns a pointer to the underlying data buffer.
void* data() const;
// Returns the data type of the tensor.
TF_DataType dtype() const;
// Returns the number of elements in the tensor. For a tensor with a partially
// defined shape, -1 means not fully defined.
int64_t num_elements() const;
// Returns the size of the underlying data in bytes.
size_t num_bytes() const;
private:
friend class TensorHandle;
friend class Runtime;
// Wraps a TF_Tensor. Takes ownership of handle.
explicit Tensor(TF_Tensor* tensor) : tensor_(tensor) {}
// Tensor is not copyable
Tensor(const Tensor&) = delete;
Tensor& operator=(const Tensor&) = delete;
// Returns the underlying TF_Tensor that this object wraps.
// This object retains ownership of the pointer.
TF_Tensor* GetTFTensor() const { return tensor_.get(); }
struct DeleterStruct {
std::function<void(void*, size_t)> deleter;
};
static void DeleterFunction(void* memory, size_t len, void* deleter_struct) {
DeleterStruct* deleter = reinterpret_cast<DeleterStruct*>(deleter_struct);
deleter->deleter(memory, len);
delete deleter;
}
struct TFTensorDeleter {
void operator()(TF_Tensor* p) const { TF_DeleteTensor(p); }
};
std::unique_ptr<TF_Tensor, TFTensorDeleter> tensor_;
};
inline void* Tensor::data() const { return TF_TensorData(tensor_.get()); }
inline int Tensor::dims() const { return TF_NumDims(tensor_.get()); }
inline int64_t Tensor::dim_size(int d) const {
return TF_Dim(tensor_.get(), d);
}
inline TF_DataType Tensor::dtype() const {
return TF_TensorType(tensor_.get());
}
inline int64_t Tensor::num_elements() const {
return TF_TensorElementCount(tensor_.get());
}
inline size_t Tensor::num_bytes() const {
return TF_TensorByteSize(tensor_.get());
}
inline Tensor Tensor::FromBuffer(TF_DataType dtype,
const std::vector<int64_t>& shape, void* data,
size_t len, DeleterCallback deleter,
Status* status) {
// Credit to apassos@ for this technique:
// Despite the fact that our API takes a std::function deleter, we are able
// to maintain ABI stability because:
// 1. Only a function pointer is sent across the C API (&DeleterFunction)
// 2. DeleterFunction is defined in the same build artifact that constructed
// the std::function (so there isn't confusion about std::function ABI).
// Note that 2. is satisfied by the fact that this is a header-only API, where
// the function implementations are inline.
DeleterStruct* deleter_struct = new DeleterStruct{deleter};
TF_Tensor* tensor = TF_NewTensor(dtype, shape.data(), shape.size(), data, len,
&DeleterFunction, deleter_struct);
if (tensor == nullptr) {
status->SetStatus(TF_INVALID_ARGUMENT,
"Failed to create tensor for input buffer");
return Tensor(nullptr);
}
return Tensor(tensor);
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
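A minimal sketch of constructing a Tensor over caller-owned memory, mirroring the tests further down; the helper name is hypothetical.
#include <cstdint>
#include <vector>

#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/experimental/base/public/tensor.h"

using tensorflow::experimental::cc::Status;
using tensorflow::experimental::cc::Tensor;

// Hypothetical helper: wrap a caller-owned float vector as a rank-1 Tensor.
Tensor MakeVectorTensor(std::vector<float>* values, Status* status) {
  std::vector<int64_t> shape = {static_cast<int64_t>(values->size())};
  // A no-op deleter is fine here because `values` outlives the Tensor.
  return Tensor::FromBuffer(TF_FLOAT, shape, values->data(),
                            values->size() * sizeof(float),
                            [](void*, size_t) {}, status);
}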

View File

@ -0,0 +1,98 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
#include <memory>
#include <vector>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/experimental/base/public/tensor.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// An opaque representation of a tensor computed/managed by the Tensorflow
// runtime (tensorflow::cc::Runtime). Unlike a tensor, a TensorHandle may refer
// to tensors placed in memory of different devices or remote address spaces.
// Note that tensorflow::cc::Runtime MUST outlive all TensorHandles created
// from it.
class TensorHandle {
public:
// Unwraps a Tensor from the given TensorHandle. If an error occurred,
// status->ok() will be false, and the returned Tensor must not be used.
Tensor Resolve(Status* status);
// Constructs a TensorHandle from a Tensor. If an error occurred,
// status->ok() will be false, and the returned TensorHandle must not be used.
static TensorHandle FromTensor(const Tensor& tensor, const Runtime& runtime,
Status* status);
// TensorHandle is movable, but not copyable
TensorHandle(TensorHandle&&) = default;
TensorHandle& operator=(TensorHandle&&) = default;
private:
// Wraps a TFE_TensorHandle. Takes ownership of handle.
explicit TensorHandle(TFE_TensorHandle* handle) : handle_(handle) {}
// TensorHandle is not copyable
TensorHandle(const TensorHandle&) = delete;
TensorHandle& operator=(const TensorHandle&) = delete;
// Returns the underlying TFE_TensorHandle that this object wraps.
// This object retains ownership of the pointer.
TFE_TensorHandle* GetTFETensorHandle() const { return handle_.get(); }
// Deletes the currently wrapped TFE_TensorHandle, and swaps it with handle,
// and takes ownership of handle.
void Reset(TFE_TensorHandle* handle) { handle_.reset(handle); }
struct TFETensorHandleDeleter {
void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); }
};
std::unique_ptr<TFE_TensorHandle, TFETensorHandleDeleter> handle_;
};
inline Tensor TensorHandle::Resolve(Status* status) {
TF_Tensor* tensor =
TFE_TensorHandleResolve(handle_.get(), status->GetTFStatus());
if (!status->ok()) {
return Tensor(nullptr);
}
return Tensor(tensor);
}
inline TensorHandle TensorHandle::FromTensor(const Tensor& tensor,
const Runtime& runtime,
Status* status) {
TFE_TensorHandle* tensor_handle = TFE_NewTensorHandleFromTensor(
runtime.GetTFEContext(), tensor.GetTFTensor(), status->GetTFStatus());
if (!status->ok()) {
return TensorHandle(nullptr);
}
return TensorHandle(tensor_handle);
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
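Putting the pieces together, a rough round-trip sketch assuming the Runtime and Tensor were created as sketched earlier; the helper name is hypothetical.
#include <cstdint>

#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/experimental/base/public/tensor.h"
#include "tensorflow/cc/experimental/base/public/tensorhandle.h"

using tensorflow::experimental::cc::Runtime;
using tensorflow::experimental::cc::Status;
using tensorflow::experimental::cc::Tensor;
using tensorflow::experimental::cc::TensorHandle;

// Hypothetical helper: place `tensor` on `runtime`, read it back, and return
// the number of elements observed (or -1 on error).
int64_t RoundTripNumElements(const Tensor& tensor, const Runtime& runtime) {
  Status status;
  TensorHandle handle = TensorHandle::FromTensor(tensor, runtime, &status);
  if (!status.ok()) return -1;
  Tensor resolved = handle.Resolve(&status);
  if (!status.ok()) return -1;
  return resolved.num_elements();
}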

View File

@ -0,0 +1,50 @@
# Tests for the C++ header-only base types.
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "tensor_types_test_util",
testonly = True,
hdrs = ["tensor_types_test_util.h"],
deps = [
"//tensorflow/c:tf_datatype",
],
)
tf_cc_test(
name = "tensor_test",
srcs = [
"tensor_test.cc",
],
deps = [
":tensor_types_test_util",
"//tensorflow/c:tf_datatype",
"//tensorflow/cc/experimental/base/public:status",
"//tensorflow/cc/experimental/base/public:tensor",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "tensorhandle_test",
srcs = [
"tensorhandle_test.cc",
],
deps = [
":tensor_types_test_util",
"//tensorflow/c:tf_datatype",
"//tensorflow/cc/experimental/base/public:runtime",
"//tensorflow/cc/experimental/base/public:runtime_builder",
"//tensorflow/cc/experimental/base/public:status",
"//tensorflow/cc/experimental/base/public:tensor",
"//tensorflow/cc/experimental/base/public:tensorhandle",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@ -0,0 +1,163 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/experimental/base/public/tensor.h"
#include <stddef.h>
#include <stdint.h>
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
namespace {
using tensorflow::experimental::cc::Status;
using tensorflow::experimental::cc::Tensor;
using SimpleTypes = ::testing::Types<
tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
template <typename T>
class ConstructScalarTensorTest : public ::testing::Test {};
TYPED_TEST_SUITE(ConstructScalarTensorTest, SimpleTypes);
// This test constructs a scalar tensor for each of the types in "SimpleTypes",
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) {
Status status;
TF_DataType dtype = TypeParam::kDType;
typename TypeParam::type value = 42;
Tensor tensor = Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
/*data=*/&value,
/*len=*/sizeof(value),
/*deleter=*/[](void*, size_t) {}, &status);
ASSERT_TRUE(status.ok()) << status.message();
EXPECT_EQ(tensor.dims(), 0);
EXPECT_EQ(tensor.dtype(), dtype);
EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
EXPECT_EQ(tensor.num_elements(), 1);
}
template <typename T>
class Construct1DTensorTest : public ::testing::Test {};
TYPED_TEST_SUITE(Construct1DTensorTest, SimpleTypes);
// This test constructs a 1D tensor for each of the types in "SimpleTypes",
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
Status status;
TF_DataType dtype = TypeParam::kDType;
// This is our 1D tensor of varying dtype.
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
// Shape is Rank 1 vector.
std::vector<int64_t> shape;
shape.push_back(value.size());
Tensor tensor = Tensor::FromBuffer(
/*dtype=*/dtype, /*shape=*/shape,
/*data=*/value.data(),
/*len=*/value.size() * sizeof(typename TypeParam::type),
/*deleter=*/[](void*, size_t) {}, &status);
ASSERT_TRUE(status.ok()) << status.message();
EXPECT_EQ(tensor.dims(), 1);
EXPECT_EQ(tensor.dtype(), dtype);
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
EXPECT_EQ(tensor_view[0], 42);
EXPECT_EQ(tensor_view[1], 100);
EXPECT_EQ(tensor_view[2], 0);
EXPECT_EQ(tensor_view[3], 1);
EXPECT_EQ(tensor_view[4], 4);
EXPECT_EQ(tensor_view[5], 29);
EXPECT_EQ(tensor.num_bytes(),
value.size() * sizeof(typename TypeParam::type));
EXPECT_EQ(tensor.num_elements(), value.size());
}
template <typename T>
class Construct2DTensorTest : public ::testing::Test {};
TYPED_TEST_SUITE(Construct2DTensorTest, SimpleTypes);
// This test constructs a 2D tensor for each of the types in "SimpleTypes",
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
Status status;
TF_DataType dtype = TypeParam::kDType;
// These are the values of our 2D tensor of varying dtype, in row-major order.
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
// Shape is Rank 2 vector with shape 2 x 3.
std::vector<int64_t> shape({2, 3});
Tensor tensor = Tensor::FromBuffer(
/*dtype=*/dtype, /*shape=*/shape,
/*data=*/value.data(),
/*len=*/value.size() * sizeof(typename TypeParam::type),
/*deleter=*/[](void*, size_t) {}, &status);
ASSERT_TRUE(status.ok()) << status.message();
EXPECT_EQ(tensor.dims(), 2);
EXPECT_EQ(tensor.dtype(), dtype);
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
EXPECT_EQ(tensor_view[0], 42);
EXPECT_EQ(tensor_view[1], 100);
EXPECT_EQ(tensor_view[2], 0);
EXPECT_EQ(tensor_view[3], 1);
EXPECT_EQ(tensor_view[4], 4);
EXPECT_EQ(tensor_view[5], 29);
EXPECT_EQ(tensor.num_bytes(),
value.size() * sizeof(typename TypeParam::type));
EXPECT_EQ(tensor.num_elements(), value.size());
}
TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
bool done = false;
Status status;
std::vector<int32_t> data_vector({12, 14, 20, 18, 39, 42, 100});
{
// data_vector is a rank 1 tensor.
std::vector<int64_t> shape;
shape.push_back(data_vector.size());
Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
done = true;
};
Tensor tensor =
Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
/*data=*/data_vector.data(),
/*len=*/data_vector.size() * sizeof(int32_t),
/*deleter=*/callback, &status);
ASSERT_TRUE(status.ok()) << status.message();
}
// At this point, tensor has been destroyed, and the deleter callback should
// have run.
EXPECT_TRUE(done);
}
} // namespace

View File

@ -0,0 +1,76 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
#include <stdint.h>
#include "tensorflow/c/tf_datatype.h"
namespace tensorflow {
// Each of the following struct types have two members: a kDType that
// corresponds to a TF_Datatype enum value, and a typedef "type"
// of its corresponding C++ type. These types allow us to write Dtype-agnostic
// tests via GoogleTest's TypedTests:
// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
struct FloatType {
using type = float;
static constexpr TF_DataType kDType = TF_FLOAT;
};
struct DoubleType {
using type = double;
static constexpr TF_DataType kDType = TF_DOUBLE;
};
struct Int32Type {
using type = int32_t;
static constexpr TF_DataType kDType = TF_INT32;
};
struct UINT8Type {
using type = uint8_t;
static constexpr TF_DataType kDType = TF_UINT8;
};
struct INT8Type {
using type = int8_t;
static constexpr TF_DataType kDType = TF_INT8;
};
struct INT64Type {
using type = int64_t;
static constexpr TF_DataType kDType = TF_INT64;
};
struct UINT16Type {
using type = uint16_t;
static constexpr TF_DataType kDType = TF_UINT16;
};
struct UINT32Type {
using type = uint32_t;
static constexpr TF_DataType kDType = TF_UINT32;
};
struct UINT64Type {
using type = uint64_t;
static constexpr TF_DataType kDType = TF_UINT64;
};
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_

View File

@ -0,0 +1,184 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/experimental/base/public/tensorhandle.h"
#include <stddef.h>
#include <stdint.h>
#include <memory>
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
#include "tensorflow/cc/experimental/base/public/tensor.h"
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
using tensorflow::experimental::cc::Runtime;
using tensorflow::experimental::cc::RuntimeBuilder;
using tensorflow::experimental::cc::Status;
using tensorflow::experimental::cc::Tensor;
using tensorflow::experimental::cc::TensorHandle;
using SimpleTypes = ::testing::Types<
tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
template <typename T>
class ConstructScalarTensorHandleTest : public ::testing::Test {};
TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes);
// This test constructs a scalar tensor for each of the types in "SimpleTypes",
// then wraps it in a TensorHandle. We then unwrap it back into a Tensor, and
// verify the expected dims, dtype, value, num bytes, and num elements.
TYPED_TEST(ConstructScalarTensorHandleTest,
ValidTensorAttributesAfterConstruction) {
Status status;
RuntimeBuilder runtime_builder;
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
TF_DataType dtype = TypeParam::kDType;
typename TypeParam::type value = 42;
Tensor original_tensor =
Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
/*data=*/&value,
/*len=*/sizeof(value),
/*deleter=*/[](void*, size_t) {}, &status);
ASSERT_TRUE(status.ok()) << status.message();
TensorHandle handle =
TensorHandle::FromTensor(original_tensor, *runtime, &status);
ASSERT_TRUE(status.ok()) << status.message();
Tensor tensor = handle.Resolve(&status);
ASSERT_TRUE(status.ok()) << status.message();
EXPECT_EQ(tensor.dims(), 0);
EXPECT_EQ(tensor.dtype(), dtype);
EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
EXPECT_EQ(tensor.num_elements(), 1);
}
template <typename T>
class Construct1DTensorHandleTest : public ::testing::Test {};
TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes);
// This test constructs a 1D tensor for each of the types in "SimpleTypes",
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(Construct1DTensorHandleTest,
ValidTensorAttributesAfterConstruction) {
Status status;
RuntimeBuilder runtime_builder;
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
TF_DataType dtype = TypeParam::kDType;
// This is our 1D tensor of varying dtype.
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
// Shape is Rank 1 vector.
std::vector<int64_t> shape;
shape.push_back(value.size());
Tensor original_tensor = Tensor::FromBuffer(
/*dtype=*/dtype, /*shape=*/shape,
/*data=*/value.data(),
/*len=*/value.size() * sizeof(typename TypeParam::type),
/*deleter=*/[](void*, size_t) {}, &status);
ASSERT_TRUE(status.ok()) << status.message();
TensorHandle handle =
TensorHandle::FromTensor(original_tensor, *runtime, &status);
ASSERT_TRUE(status.ok()) << status.message();
Tensor tensor = handle.Resolve(&status);
ASSERT_TRUE(status.ok()) << status.message();
EXPECT_EQ(tensor.dims(), 1);
EXPECT_EQ(tensor.dtype(), dtype);
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
EXPECT_EQ(tensor_view[0], 42);
EXPECT_EQ(tensor_view[1], 100);
EXPECT_EQ(tensor_view[2], 0);
EXPECT_EQ(tensor_view[3], 1);
EXPECT_EQ(tensor_view[4], 4);
EXPECT_EQ(tensor_view[5], 29);
EXPECT_EQ(tensor.num_bytes(),
value.size() * sizeof(typename TypeParam::type));
EXPECT_EQ(tensor.num_elements(), value.size());
}
template <typename T>
class Construct2DTensorHandleTest : public ::testing::Test {};
TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes);
// This test constructs a 2D tensor for each of the types in "SimpleTypes",
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(Construct2DTensorHandleTest,
ValidTensorAttributesAfterConstruction) {
Status status;
RuntimeBuilder runtime_builder;
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
TF_DataType dtype = TypeParam::kDType;
// These are the values of our 2D tensor of varying dtype, in row-major order.
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
// Shape is Rank 2 vector with shape 2 x 3.
std::vector<int64_t> shape({2, 3});
Tensor original_tensor = Tensor::FromBuffer(
/*dtype=*/dtype, /*shape=*/shape,
/*data=*/value.data(),
/*len=*/value.size() * sizeof(typename TypeParam::type),
/*deleter=*/[](void*, size_t) {}, &status);
ASSERT_TRUE(status.ok()) << status.message();
TensorHandle handle =
TensorHandle::FromTensor(original_tensor, *runtime, &status);
ASSERT_TRUE(status.ok()) << status.message();
Tensor tensor = handle.Resolve(&status);
ASSERT_TRUE(status.ok()) << status.message();
EXPECT_EQ(tensor.dims(), 2);
EXPECT_EQ(tensor.dtype(), dtype);
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
EXPECT_EQ(tensor_view[0], 42);
EXPECT_EQ(tensor_view[1], 100);
EXPECT_EQ(tensor_view[2], 0);
EXPECT_EQ(tensor_view[3], 1);
EXPECT_EQ(tensor_view[4], 4);
EXPECT_EQ(tensor_view[5], 29);
EXPECT_EQ(tensor.num_bytes(),
value.size() * sizeof(typename TypeParam::type));
EXPECT_EQ(tensor.num_elements(), value.size());
}
} // namespace
} // namespace tensorflow

View File

@ -4,7 +4,6 @@
load(
"//tensorflow:tensorflow.bzl",
"if_android",
"if_ios",
"if_mobile",
"if_not_mobile",
"tf_cc_test",
@ -85,7 +84,7 @@ cc_library(
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
]) + if_android([
"//tensorflow/core:android_tensorflow_lib",
"//tensorflow/core:portable_tensorflow_lib",
]),
)

View File

@ -0,0 +1,58 @@
# Experimental C++ SavedModel Header Only APIs. See RFC
# https://github.com/tensorflow/community/pull/207
package(
# This is intentionally public
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "concrete_function",
hdrs = [
"concrete_function.h",
],
deps = [
":function_metadata",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/experimental/saved_model/public:concrete_function",
"//tensorflow/cc/experimental/base/public:status",
],
)
cc_library(
name = "concrete_function_list",
hdrs = [
"concrete_function_list.h",
],
deps = [
":concrete_function",
"//tensorflow/c/experimental/saved_model/public:concrete_function_list",
],
)
cc_library(
name = "function_metadata",
hdrs = [
"function_metadata.h",
],
deps = [
"//tensorflow/c/experimental/saved_model/public:function_metadata",
],
)
cc_library(
name = "saved_model_api",
hdrs = [
"saved_model_api.h",
],
deps = [
":concrete_function",
":concrete_function_list",
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
"//tensorflow/cc/experimental/base/public:runtime",
"//tensorflow/cc/experimental/base/public:status",
],
)

View File

@ -0,0 +1,61 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
#include <vector>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/saved_model/experimental/public/function_metadata.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// ConcreteFunction is an executable "function" loaded from a SavedModelAPI.
class ConcreteFunction final {
public:
  // TODO(bmzhao): Add ConcreteFunction::Run in a subsequent CL, since it
  // depends on tensorflow::cc::Tensor and tensorflow::cc::TensorHandle.
// Returns FunctionMetadata associated with this ConcreteFunction.
const FunctionMetadata* GetFunctionMetadata();
private:
friend class SavedModelAPI;
friend class ConcreteFunctionList;
// TODO(bmzhao): Consider adding a macro for wrapping/unwrapping
// when moving out of experimental.
static ConcreteFunction* wrap(TF_ConcreteFunction* p) {
return reinterpret_cast<ConcreteFunction*>(p);
}
static TF_ConcreteFunction* unwrap(ConcreteFunction* p) {
return reinterpret_cast<TF_ConcreteFunction*>(p);
}
};
inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() {
return FunctionMetadata::wrap(TF_ConcreteFunctionGetMetadata(unwrap(this)));
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_

View File

@ -0,0 +1,63 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
#include <vector>
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// ConcreteFunctionList helps convert an opaque pointer to a list of
// ConcreteFunctions into a std::vector of ConcreteFunction pointers.
class ConcreteFunctionList {
public:
// Converts this object to a std::vector<ConcreteFunction*>
std::vector<ConcreteFunction*> ToVector();
private:
friend class SavedModelAPI;
// Wraps a TF_ConcreteFunctionList. Takes ownership of list.
explicit ConcreteFunctionList(TF_ConcreteFunctionList* list) : list_(list) {}
struct TFConcreteFunctionListDeleter {
void operator()(TF_ConcreteFunctionList* p) const {
TF_DeleteConcreteFunctionList(p);
}
};
std::unique_ptr<TF_ConcreteFunctionList, TFConcreteFunctionListDeleter> list_;
};
inline std::vector<ConcreteFunction*> ConcreteFunctionList::ToVector() {
int size = TF_ConcreteFunctionListSize(list_.get());
std::vector<ConcreteFunction*> result;
result.reserve(size);
for (int i = 0; i < size; ++i) {
result.push_back(
ConcreteFunction::wrap(TF_ConcreteFunctionListGet(list_.get(), i)));
}
return result;
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_

View File

@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
#include <memory>
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// FunctionMetadata stores additional function information, including
// optional SignatureDef feeds/fetches (for TF1-based ConcreteFunctions),
// a valid function path (for TF2-based ConcreteFunctions), and
// the types + number of inputs and outputs.
class FunctionMetadata final {
// TODO(bmzhao): Add getters here as necessary.
private:
friend class ConcreteFunction;
static FunctionMetadata* wrap(TF_FunctionMetadata* p) {
return reinterpret_cast<FunctionMetadata*>(p);
}
static TF_FunctionMetadata* unwrap(FunctionMetadata* p) {
return reinterpret_cast<TF_FunctionMetadata*>(p);
}
};
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_

View File

@ -0,0 +1,162 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// SavedModelAPI offers a way to load Tensorflow Saved Models
// (https://www.tensorflow.org/guide/saved_model) and execute saved
// tf.functions or legacy SignatureDefs in a TF2-idiomatic fashion.
// See RFC 207
// (https://github.com/tensorflow/community/blob/master/rfcs/20200218-tf-c-saved-model.md)
// TODO(bmzhao): Add an e2e example here, once ConcreteFunction::Run is added.
class SavedModelAPI {
public:
  // Load a SavedModel from `saved_model_path`.
//
// Params:
  //   saved_model_path - The directory path where the SavedModel is stored.
  //   runtime - A runtime used to load the SavedModel. `runtime` must outlive
  //     the returned SavedModelAPI.
// tags - Optional set of tags. If tags = nullptr, we expect the SavedModel
// to contain a single Metagraph (as for those exported from TF2's
// `tf.saved_model.save`). If tags != nullptr, we load the metagraph
// matching the tags:
// https://github.com/tensorflow/tensorflow/blob/428cdeda09aef81e958eeb274b83d27ad635b57b/tensorflow/core/protobuf/meta_graph.proto#L50-L56
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr.
static std::unique_ptr<SavedModelAPI> Load(
const std::string& saved_model_path, const Runtime& runtime,
Status* status, const std::unordered_set<std::string>* tags = nullptr);
// Retrieve a function from the TF2 SavedModel via function path.
//
// Params:
// function_path - A string containing the path from the root saved python
// object to a tf.function method.
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a
// tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer
  //   is bound to the SavedModelAPI it was loaded from.
ConcreteFunction* GetConcreteFunction(const std::string& function_path,
Status* status);
// Retrieve a function from the TF SavedModel via a SignatureDef key.
//
// Params:
// signature_def_key - String key of SignatureDef map of a SavedModel:
// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a
// tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer
  //   is bound to the SavedModelAPI it was loaded from.
  ConcreteFunction* GetSignatureDefFunction(
      const std::string& signature_def_key, Status* status);
  // Lists all Concrete Functions available from the SavedModel.
std::vector<ConcreteFunction*> ListFunctions();
// SavedModelAPI is movable, but not copyable.
SavedModelAPI(SavedModelAPI&&) = default;
SavedModelAPI& operator=(SavedModelAPI&&) = default;
private:
SavedModelAPI(const SavedModelAPI&) = delete;
SavedModelAPI& operator=(const SavedModelAPI&) = delete;
explicit SavedModelAPI(TF_SavedModel* model) : saved_model_(model) {}
struct TFSavedModelDeleter {
void operator()(TF_SavedModel* p) const { TF_DeleteSavedModel(p); }
};
std::unique_ptr<TF_SavedModel, TFSavedModelDeleter> saved_model_;
};
inline std::unique_ptr<SavedModelAPI> SavedModelAPI::Load(
const std::string& saved_model_path, const Runtime& runtime, Status* status,
const std::unordered_set<std::string>* tags) {
TF_SavedModel* saved_model = nullptr;
if (tags == nullptr) {
saved_model =
TF_LoadSavedModel(saved_model_path.c_str(), runtime.GetTFEContext(),
status->GetTFStatus());
} else {
std::vector<const char*> tags_vector;
tags_vector.reserve(tags->size());
for (const std::string& tag : *tags) {
tags_vector.push_back(tag.c_str());
}
saved_model = TF_LoadSavedModelWithTags(
saved_model_path.c_str(), runtime.GetTFEContext(), tags_vector.data(),
tags_vector.size(), status->GetTFStatus());
}
if (!status->ok()) {
return nullptr;
}
// We can't use std::make_unique here because of its interaction with a
// private constructor: https://abseil.io/tips/134
return std::unique_ptr<SavedModelAPI>(new SavedModelAPI(saved_model));
}
inline ConcreteFunction* SavedModelAPI::GetConcreteFunction(
const std::string& function_path, Status* status) {
TF_ConcreteFunction* function = TF_GetSavedModelConcreteFunction(
saved_model_.get(), function_path.c_str(), status->GetTFStatus());
if (!status->ok()) {
return nullptr;
}
return ConcreteFunction::wrap(function);
}
inline ConcreteFunction* SavedModelAPI::GetSignatureDefFunction(
    const std::string& signature_def_key, Status* status) {
  TF_ConcreteFunction* function = TF_GetSavedModelSignatureDefFunction(
      saved_model_.get(), signature_def_key.c_str(), status->GetTFStatus());
if (!status->ok()) {
return nullptr;
}
return ConcreteFunction::wrap(function);
}
inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
ConcreteFunctionList list(TF_ListSavedModelFunctions(saved_model_.get()));
return list.ToVector();
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
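A minimal end-to-end usage sketch for the header above (not part of the change): it reuses the experimental Runtime/RuntimeBuilder/Status wrappers that appear elsewhere in this diff, while the model directory "/tmp/my_model" and the function path "serving_fn" are illustrative assumptions; ConcreteFunction::Run is not available yet, per the TODO.

// Hedged sketch only; see the note above for which names are assumptions.
#include <memory>
#include <string>
#include <unordered_set>

#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/saved_model/experimental/public/saved_model_api.h"

namespace cc = tensorflow::experimental::cc;

int main() {
  cc::Status status;
  cc::RuntimeBuilder builder;
  std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
  if (!status.ok()) return 1;

  // Load a TF2 SavedModel; pass a tag set instead of nullptr for TF1 models.
  std::unique_ptr<cc::SavedModelAPI> model =
      cc::SavedModelAPI::Load("/tmp/my_model", *runtime, &status);
  if (!status.ok()) return 1;

  // Look up a tf.function by its path in the saved object graph.
  cc::ConcreteFunction* fn = model->GetConcreteFunction("serving_fn", &status);
  if (!status.ok()) return 1;
  (void)fn;  // ConcreteFunction::Run is not implemented yet (see TODO above).
  return 0;
}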

View File

@ -0,0 +1,22 @@
# Tests for the C++ header-only SavedModelAPI.
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
package(
licenses = ["notice"], # Apache 2.0
)
tf_cc_test(
name = "saved_model_api_test",
srcs = [
"saved_model_api_test.cc",
],
deps = [
"//tensorflow/cc/experimental/base/public:runtime",
"//tensorflow/cc/experimental/base/public:runtime_builder",
"//tensorflow/cc/experimental/base/public:status",
"//tensorflow/cc/saved_model/experimental/public:saved_model_api",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@ -0,0 +1,100 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/saved_model/experimental/public/saved_model_api.h"
#include <memory>
#include <string>
#include <unordered_set>
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
namespace {
using tensorflow::experimental::cc::Runtime;
using tensorflow::experimental::cc::RuntimeBuilder;
using tensorflow::experimental::cc::SavedModelAPI;
using tensorflow::experimental::cc::Status;
constexpr char kTestData[] = "cc/saved_model/testdata";
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
return tensorflow::io::JoinPath(tensorflow::testing::TensorFlowSrcRoot(),
kTestData, saved_model_dir);
}
// This value-parameterized test allows us to test both TFRT
// and non-TFRT runtimes.
// https://github.com/google/googletest/blob/dcc92d0ab6c4ce022162a23566d44f673251eee4/googletest/docs/advanced.md#value-parameterized-tests
class CPPSavedModelAPITest : public ::testing::TestWithParam<bool> {};
TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
Status status;
RuntimeBuilder builder;
bool use_tfrt = GetParam();
if (use_tfrt) {
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
builder.SetUseTFRT(use_tfrt);
std::unique_ptr<Runtime> runtime = builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
std::unordered_set<std::string> tags = {"serve"};
std::unique_ptr<SavedModelAPI> model =
SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
// That unblocks writing other tests that require a TF_SavedModel*,
// like loading a ConcreteFunction. This test at least checks that the
// C API builds and can be minimally run.
EXPECT_EQ(status.code(), TF_UNIMPLEMENTED);
}
TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
Status status;
RuntimeBuilder builder;
bool use_tfrt = GetParam();
if (use_tfrt) {
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
builder.SetUseTFRT(use_tfrt);
std::unique_ptr<Runtime> runtime = builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
std::unique_ptr<SavedModelAPI> model =
SavedModelAPI::Load(model_dir, *runtime, &status);
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
// That unblocks writing other tests that require a TF_SavedModel*,
// like loading a ConcreteFunction. This test at least checks that the
// C API builds and can be minimally run.
EXPECT_EQ(status.code(), TF_UNIMPLEMENTED);
}
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,
CPPSavedModelAPITest, ::testing::Bool());
} // namespace

View File

@ -131,6 +131,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type));
std::vector<string> dim_vars;
string dim_sizes, indices;
int count = 1;
if (shape.rank() == 0 ||
(shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) {
dim_sizes = "[1]";
@ -140,6 +141,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
dim_vars.push_back(absl::StrCat("size_t dim", dim));
dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]");
indices += absl::StrCat("[dim", dim, "]");
count *= shape.dimensions(dim);
}
}
rewrites->push_back({"{{I}}", absl::StrCat(i)});
@ -147,6 +149,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
rewrites->push_back({"{{INDICES}}", indices});
rewrites->push_back({"{{COUNT}}", absl::StrCat(count)});
return Status::OK();
}
@ -199,6 +202,12 @@ Status GenArgMethods(const tf2xla::Config& config,
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
arg_data({{I}}))){{INDICES}};
}
int arg{{NAME}}_size() const {
return {{COUNT}} * sizeof({{TYPE}});
}
int arg{{NAME}}_count() const {
return {{COUNT}};
}
)";
*methods += RewriteWithName(absl::StrCat(i), code, rewrites);
if (!config.feed(i).name().empty()) {
@ -246,6 +255,12 @@ Status GenResultMethods(const tf2xla::Config& config,
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
result_data({{I}}))){{INDICES}};
}
int result{{NAME}}_size() const {
return {{COUNT}} * sizeof({{TYPE}});
}
int result{{NAME}}_count() const {
return {{COUNT}};
}
)";
*methods += RewriteWithName(absl::StrCat(i), code, rewrites);
if (!config.fetch(i).name().empty()) {
@ -281,6 +296,12 @@ Status GenVariableMethods(const tf2xla::Config& config,
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
arg_data({{I}}))){{INDICES}};
}
int var_{{NAME}}_size() const {
return {{COUNT}} * sizeof({{TYPE}});
}
int var_{{NAME}}_count() const {
return {{COUNT}};
}
)";
const tf2xla::Variable& var = config.variable(i - config.feed_size());
rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");

View File

@ -138,6 +138,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const float(*)[1][2]>(
arg_data(0)))[dim0][dim1];
}
int arg0_size() const {
return 2 * sizeof(float);
}
int arg0_count() const {
return 2;
}
void set_arg_myfeed_data(const void* data) {
set_arg_data(0, data);
@ -156,6 +162,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const float(*)[1][2]>(
arg_data(0)))[dim0][dim1];
}
int arg_myfeed_size() const {
return 2 * sizeof(float);
}
int arg_myfeed_count() const {
return 2;
}
void set_arg1_data(const void* data) {
set_arg_data(1, data);
@ -174,6 +186,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const tensorflow::int64(*)[3][4]>(
arg_data(1)))[dim0][dim1];
}
int arg1_size() const {
return 12 * sizeof(tensorflow::int64);
}
int arg1_count() const {
return 12;
}
// Result methods for managing output buffers. Buffers are in row-major order.
// Must only be called after a successful Run call. There is a set of methods
@ -204,6 +222,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
result_data(0)))[dim0][dim1];
}
int result0_size() const {
return 30 * sizeof(tensorflow::uint32);
}
int result0_count() const {
return 30;
}
tensorflow::uint32* result_myfetch_data() {
return static_cast<tensorflow::uint32*>(result_data(0));
@ -219,6 +243,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
result_data(0)))[dim0][dim1];
}
int result_myfetch_size() const {
return 30 * sizeof(tensorflow::uint32);
}
int result_myfetch_count() const {
return 30;
}
// Methods for managing variable buffers. Buffers are in row-major order.
//
@ -261,6 +291,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const float(*)[1]>(
arg_data(2)))[0];
}
int var_myvar_readonly_size() const {
return 1 * sizeof(float);
}
int var_myvar_readonly_count() const {
return 1;
}
void set_var_myvar_data(float* data) {
set_arg_data(3, data);
@ -279,6 +315,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const float(*)[1]>(
arg_data(3)))[0];
}
int var_myvar_size() const {
return 1 * sizeof(float);
}
int var_myvar_count() const {
return 1;
}
void set_var_myvar2_data(tensorflow::int32* data) {
set_arg_data(4, data);
@ -297,6 +339,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const tensorflow::int32(*)[5]>(
arg_data(4)))[dim0];
}
int var_myvar2_size() const {
return 5 * sizeof(tensorflow::int32);
}
int var_myvar2_count() const {
return 5;
}
private:
// Number of buffers for the compiled computation.
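A hedged usage sketch for the new *_size()/*_count() accessors (not part of the golden file): it assumes the generated `MyClass` shown above and the `Run()` entry point inherited from XlaCompiledCpuFunction.

// Sketch: feed arg_myfeed (1 x 2 floats) and copy out result_myfetch
// (5 x 6 uint32s), sizing host buffers with the generated count/size helpers.
#include <cstring>
#include <vector>

bool RunMyClassOnce() {
  MyClass computation;
  std::vector<float> input(computation.arg_myfeed_count(), 1.0f);  // 2 floats
  computation.set_arg_myfeed_data(input.data());
  if (!computation.Run()) return false;
  std::vector<tensorflow::uint32> output(computation.result_myfetch_count());
  std::memcpy(output.data(), computation.result_myfetch_data(),
              computation.result_myfetch_size());  // 30 * sizeof(uint32)
  return true;
}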

View File

@ -20,7 +20,7 @@ load(
"tf_cc_test",
"tf_copts",
)
load("//tensorflow:tensorflow.bzl", "tfcompile_extra_flags")
load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu")
def tf_library(
name,
@ -42,7 +42,8 @@ def tf_library(
mlir_components = "None",
deps = None,
tags = []):
"""Runs tfcompile to compile a TensorFlow graph into executable code.
"""Runs tfcompile to compile a TensorFlow graph into executable code with fast
    math enabled on CPU.
Given an invocation of tf_library(name="foo", ...), generates the following
build targets:
@ -187,7 +188,9 @@ def tf_library(
# `find` on such an object.
need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1
flags = tfcompile_extra_flags() + flags
target_cpu = tfcompile_target_cpu()
extra_flags = "--target_cpu=" + target_cpu + " " if target_cpu else " "
flags = extra_flags + flags
if enable_xla_hlo_profiling:
profiling_flag = "--xla_hlo_profile"
@ -207,6 +210,15 @@ def tf_library(
srcs.append(debug_info)
debug_info_flag = " --debug_info=$(location " + debug_info + ")"
default_fast_math_xla_flags = ("XLA_FLAGS='" +
"--xla_cpu_enable_fast_math=true " +
"--xla_cpu_fast_math_honor_nans=false " +
"--xla_cpu_fast_math_honor_infs=false " +
"--xla_cpu_fast_math_honor_functions=false " +
"--xla_cpu_fast_math_honor_division=false " +
"--xla_cpu_enable_fast_min_max=true " +
"$${XLA_FLAGS:-}' ")
native.genrule(
name = ("gen_" + name),
srcs = srcs,
@ -216,6 +228,7 @@ def tf_library(
function_object_file,
],
cmd = (
default_fast_math_xla_flags +
"CUDA_VISIBLE_DEVICES='' " +
"$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +
@ -256,6 +269,7 @@ def tf_library(
session_module_pb,
],
cmd = (
default_fast_math_xla_flags +
"CUDA_VISIBLE_DEVICES='' " +
"$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +

View File

@ -67,6 +67,8 @@ int main(int argc, char** argv) {
flags.entry_point = "entry";
flags.debug_info_path_begin_marker = "";
// Note that tfcompile.bzl's tf_library macro sets fast math flags as that is
// generally the preferred case.
std::vector<tensorflow::Flag> flag_list;
AppendMainFlags(&flag_list, &flags);
xla::AppendDebugOptionsFlags(&flag_list);

View File

@ -251,7 +251,7 @@ cc_library(
visibility = [":friends"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib",
"//tensorflow/core:portable_tensorflow_lib",
],
"//conditions:default": [
"//tensorflow/core:graph",
@ -505,6 +505,7 @@ cc_library(
name = "shape_inference",
srcs = ["shape_inference.cc"],
hdrs = ["shape_inference.h"],
visibility = [":friends"],
deps = [
":shape_inference_helpers",
"//tensorflow/compiler/xla:statusor",

View File

@ -2034,6 +2034,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"TensorArraySplitV3",
"TensorArrayV3",
"TensorArrayWriteV3",
"TensorListConcatV2",
"TensorListElementShape",
"TensorListFromTensor",
"TensorListGather",
@ -2043,6 +2044,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"TensorListPushBack",
"TensorListReserve",
"TensorListSetItem",
"TensorListSplit",
"TensorListStack",
"TensorScatterAdd",
"TensorScatterSub",
@ -2078,6 +2080,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"XlaSend",
"XlaSharding",
"XlaSort",
"XlaSpmdFullToShardShape",
"XlaSpmdShardToFullShape",
"XlaSvd",
"XlaWhile",
"_Arg",

View File

@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/public/version.h"
@ -277,29 +278,25 @@ Status XlaCompilationCache::CompileSingleOp(
const NodeDef& node_def = ctx->op_kernel().def();
TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
bool are_params = absl::c_all_of(args, [](const XlaCompiler::Argument arg) {
return arg.kind == XlaCompiler::Argument::kParameter;
});
bool are_args_supported =
absl::c_all_of(args, [](const XlaCompiler::Argument arg) {
return arg.kind == XlaCompiler::Argument::kConstant ||
arg.kind == XlaCompiler::Argument::kParameter;
});
const ConfigProto* config = ctx->function_library()->config_proto();
bool use_mlir = config && config->experimental().enable_mlir_bridge();
// Use MLIR bridge if all the arguments are parameters.
// TODO(hinsu): Support other argument types instead of silently falling
// back to the XLA compiler.
if (!are_params || !use_mlir) {
// TODO(b/155596779): Understand the source of other argument types and
// depending on the source either support those or avoid these codepath.
if (!use_mlir || !are_args_supported) {
return compiler->CompileGraph(compile_options, node_def.name(),
std::move(graph), args, result);
}
absl::InlinedVector<TensorShape, 4> arg_shapes;
arg_shapes.reserve(args.size());
for (const XlaCompiler::Argument& arg : args) {
arg_shapes.push_back(absl::get<TensorShape>(arg.shape));
}
GraphDebugInfo debug_info;
return CompileGraphToXlaHlo(
*graph, {arg_shapes.data(), arg_shapes.size()},
options.device_type.type_string(), compile_options.use_tuple_arg,
*options.flib_def, debug_info, options.shape_representation_fn, result);
*graph, {args.data(), args.size()}, options.device_type.type_string(),
compile_options.use_tuple_arg, *options.flib_def, debug_info,
options.shape_representation_fn, result);
};
return CompileImpl(options, name, args, compile_op,
/*compile_threshold=*/absl::nullopt,

View File

@ -180,12 +180,10 @@ class XlaAssignVariableOp : public OpKernel {
data::MakeIteratorOp); \
REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \
data::AnonymousIteratorHandleOp); \
REGISTER_KERNEL_BUILDER( \
Name("AnonymousIteratorV2").Device(DEVICE).HostMemory("deleter"), \
data::AnonymousIteratorHandleOp); \
REGISTER_KERNEL_BUILDER( \
Name("DeleteIterator").Device(DEVICE).HostMemory("deleter"), \
data::DeleteIteratorOp); \
REGISTER_KERNEL_BUILDER(Name("AnonymousIteratorV2").Device(DEVICE), \
data::AnonymousIteratorHandleOp); \
REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE), \
data::DeleteIteratorOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \
data::IteratorGetNextOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \

View File

@ -104,6 +104,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
],
alwayslink = 1,

View File

@ -31,7 +31,7 @@ filegroup(
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
],
)
@ -260,6 +260,41 @@ cc_library(
],
)
cc_library(
name = "tftext_utils",
srcs = [
"utils/tftext_utils.cc",
],
hdrs = [
"utils/tftext_utils.h",
],
copts = ["-std=c++14"],
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
tf_cc_test(
name = "tftext_utils_test",
size = "small",
srcs = ["utils/lstm_utils_test.cc"],
deps = [
":lstm_utils",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
cc_library(
name = "stateful_ops_utils",
srcs = [
@ -320,6 +355,7 @@ cc_library(
":lstm_utils",
":stateful_ops_utils",
":tensorflow_lite",
":tftext_utils",
":validators",
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
@ -523,7 +559,6 @@ cc_library(
"@flatbuffers",
"@llvm-project//llvm:analysis",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:TransformUtils",
],
@ -696,9 +731,9 @@ cc_library(
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LoopOpsTransforms",
"@llvm-project//mlir:MlirTranslateMain",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Translation",
@ -710,6 +745,8 @@ tf_cc_binary(
name = "flatbuffer_translate",
deps = [
":flatbuffer_translate_registeration",
# TODO(b/155809683): Link only necessary dialects.
"@llvm-project//mlir:AllPassesAndDialects",
],
)
@ -758,6 +795,13 @@ tf_cc_binary(
":tf_tfl_passes",
":tf_tfl_translate_cl_options",
":tf_to_tfl_flatbuffer",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
# TODO(b/155809683): Link only necessary dialects.
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"//tensorflow/compiler/mlir:init_mlir",
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
"//tensorflow/core:protos_all_cc",
@ -765,11 +809,6 @@ tf_cc_binary(
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
)
@ -781,17 +820,19 @@ tf_cc_binary(
deps = [
":flatbuffer_translate_lib",
":flatbuffer_translate_registeration",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
# TODO(b/155809683): Link only necessary dialects.
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Support",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"//tensorflow/lite:framework",
"//tensorflow/lite/delegates/flex:delegate",
"//tensorflow/lite/kernels:builtin_ops",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Support",
],
)

View File

@ -799,11 +799,6 @@ Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
Optional<CustomOptionsOffset> Translator::CreateCustomOpCustomOptions(
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
std::string node_def_str;
if (!node_def.SerializeToString(&node_def_str)) {
return emitError(loc, "failed to serialize tensorflow node_def"),
llvm::None;
}
auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc);
return builder_.CreateVector(flex_builder->GetBuffer());
}
@ -813,9 +808,13 @@ Translator::CreateFlexBuilderWithNodeAttrs(
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
auto flex_builder = absl::make_unique<flexbuffers::Builder>();
size_t map_start = flex_builder->StartMap();
for (const auto& pair : node_def.attr()) {
using Item = std::pair<std::string, ::tensorflow::AttrValue>;
std::vector<Item> attrs(node_def.attr().begin(), node_def.attr().end());
std::sort(attrs.begin(), attrs.end(),
[](Item& p1, Item& p2) -> bool { return p1.first < p2.first; });
for (const Item& pair : attrs) {
const char* key = pair.first.c_str();
const auto& attr = pair.second;
const ::tensorflow::AttrValue& attr = pair.second;
switch (attr.value_case()) {
case ::tensorflow::AttrValue::kS:
flex_builder->String(key, attr.s());
@ -1017,10 +1016,10 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
inst->getName().print(os);
    // Print out attributes except for large elements attributes (which should
    // rarely be the reason why the legalization didn't happen).
if (!inst->getAttrList().getAttrs().empty()) {
if (!inst->getMutableAttrDict().getAttrs().empty()) {
os << " {";
bool first = true;
for (auto& named_attr : inst->getAttrList().getDictionary()) {
for (auto& named_attr : inst->getAttrDictionary()) {
os << (!first ? ", " : "");
first = false;
named_attr.first.print(os);

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"

File diff suppressed because it is too large

View File

@ -146,6 +146,10 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
saved_model_exported_names.begin(), saved_model_exported_names.end());
absl::Span<std::string> exported_names(exported_names_in_vector);
if (exported_names.size() != 1) {
return errors::Unimplemented("Only support a single exported name.");
}
TF_ASSIGN_OR_RETURN(auto module,
ImportSavedModel(model_flags.saved_model_dir(),
model_flags.saved_model_version(), tags,

View File

@ -121,6 +121,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
return DT_STRING;
case toco::IODataType::BOOL:
return DT_BOOL;
case toco::IODataType::COMPLEX64:
return DT_COMPLEX64;
default:
return DT_INVALID;
}
@ -252,7 +254,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
std::string error_message;
auto output = mlir::openOutputFile(filename, &error_message);
if (!error_message.empty()) {
return errors::InvalidArgument("Failed to open file in %s.", filename);
return errors::InvalidArgument("Failed to open file in ", filename);
}
mlir::PassManager pm(module.getContext());
pm.addPass(mlir::createPrintOpGraphPass(output->os()));

View File

@ -32,6 +32,7 @@ namespace mlir {
namespace quant {
constexpr int k8Bits = 8;
constexpr int k32Bits = 32;
constexpr unsigned kSigned = quant::QuantizationFlags::Signed;
DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
@ -39,20 +40,20 @@ DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
i8_ = IntegerType::get(k8Bits, ctx_);
i8_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k8Bits);
i8_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k8Bits);
i32_ = IntegerType::get(k32Bits, ctx_);
i32_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k32Bits);
i32_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k32Bits);
any_ = AnyQuantizedType();
qi8_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_, i8_max_);
qi8n_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_ + 1, i8_max_);
qi32_ = AnyQuantizedType::get(kSigned, i32_, f32_, i32_min_, i32_max_);
assert(qi8n_ == qi8n_);
}
Optional<KernelSpec> DeviceTarget::GetKernelSpec(QuantizeRegionOp op) const {
auto kernel_specs_it = specs_.find(op.logical_kernel());
Optional<KernelSpec> DeviceTarget::GetKernelSpec(
llvm::StringRef kernel, const KernelSpecs::Signature& signature) const {
auto kernel_specs_it = specs_.find(kernel);
if (kernel_specs_it == specs_.end()) return llvm::None;
KernelSpecs::Signature signature;
signature.reserve(op.input_specs().size() + op.output_specs().size());
AppendToSignature(op.input_specs(), &signature);
AppendToSignature(op.output_specs(), &signature);
return kernel_specs_it->getValue().Find(signature);
}
@ -62,31 +63,38 @@ ScaleDecomposeFn DeviceTarget::GetDecomposeFn(QuantizeRegionOp op) const {
return kernel_specs_it->second.GetDecomposeFn();
}
void DeviceTarget::AppendToSignature(Type spec,
KernelSpecs::Signature* signature) {
if (auto quant = spec.dyn_cast_or_null<UniformQuantizedType>()) {
signature->push_back(AnyQuantizedType::get(
quant.getFlags(), quant.getStorageType(), quant.getExpressedType(),
quant.getStorageTypeMin(), quant.getStorageTypeMax()));
} else if (auto any = spec.dyn_cast_or_null<AnyQuantizedType>()) {
signature->push_back(any);
} else { // float
signature->push_back(AnyQuantizedType());
}
}
LogicalResult DeviceTarget::RegisterKernel(
llvm::StringRef kernel, const KernelSpecs::Signature& signature,
const ScaleFn& fn, const ScaleDecomposeFn& dfn) {
return specs_[kernel].Add(signature, {ScaleConstraintType::CustomScale, fn});
}
namespace ph = std::placeholders;
LogicalResult DeviceTarget::RegisterKernel(
llvm::StringRef kernel, const KernelSpecs::Signature& signature,
const ScaleConstraintType constraint) {
return specs_[kernel].Add(signature, {constraint, {}});
}
void DeviceTarget::AppendToSignature(ArrayAttr specs_attr,
KernelSpecs::Signature* signature) const {
for (auto attr : specs_attr) {
Type spec = attr.cast<TypeAttr>().getValue();
if (auto quant = spec.dyn_cast<UniformQuantizedType>()) {
signature->push_back(AnyQuantizedType::get(
quant.getFlags(), quant.getStorageType(), quant.getExpressedType(),
quant.getStorageTypeMin(), quant.getStorageTypeMax()));
} else if (auto any = spec.dyn_cast<AnyQuantizedType>()) {
signature->push_back(any);
} else { // float
signature->push_back({});
}
if (failed(specs_[kernel].Add(signature, {constraint, {}}))) return failure();
switch (constraint) {
case ScaleConstraintType::OutputInputSameScale:
specs_[kernel].WithImpl(std::bind(&DeviceTarget::DecomposeSameScale,
ph::_1, ph::_2, ph::_3, ph::_4));
return success();
default:
return failure();
}
}
@ -119,7 +127,7 @@ LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale(
input_multipliers->append(3, kUnitQuantizedMultiplier);
// output multipliers
double real_multiplier = o_spec.getScale() / scale_product;
double real_multiplier = scale_product / o_spec.getScale();
output_multipliers->push_back(quant::QuantizeMultiplier(real_multiplier));
// output ranges
@ -134,5 +142,40 @@ LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale(
return success();
}
LogicalResult DeviceTarget::DecomposeSameScale(
Operation* op, quant::QuantizedMultipliers* input_multipliers,
quant::QuantizedMultipliers* output_multipliers,
quant::QuantizedRanges* output_ranges) {
auto rop = llvm::dyn_cast<quant::QuantizeRegionOp>(op);
if (!rop) return failure();
// input multipliers
for (int i = 0; i < op->getNumOperands(); ++i) {
input_multipliers->push_back(kUnitQuantizedMultiplier);
}
// output multipliers
for (int i = 0; i < op->getNumResults(); ++i) {
output_multipliers->push_back(kUnitQuantizedMultiplier);
}
auto o_spec = rop.output_specs()[0]
.cast<TypeAttr>()
.getValue()
.dyn_cast<quant::UniformQuantizedType>();
if (!o_spec) return failure();
// output ranges
auto min = rop.getAttrOfType<FloatAttr>("min");
auto max = rop.getAttrOfType<FloatAttr>("max");
output_ranges->push_back(quant::CalculateQuantizedRange(
o_spec.getScale(), o_spec.getZeroPoint(),
(min ? absl::optional<double>(min.getValueAsDouble()) : absl::nullopt),
(max ? absl::optional<double>(max.getValueAsDouble()) : absl::nullopt),
o_spec.getStorageTypeMin(), o_spec.getStorageTypeMax()));
return success();
}
} // namespace quant
} // namespace mlir

View File

@ -134,11 +134,18 @@ class DeviceTarget {
explicit DeviceTarget(MLIRContext* ctx);
// Retrieves the kernel spec for the quant region op.
Optional<KernelSpec> GetKernelSpec(quant::QuantizeRegionOp op) const;
Optional<KernelSpec> GetKernelSpec(
llvm::StringRef kernel, const KernelSpecs::Signature& signature) const;
// Retrieves the scale decomposition function for the quant region op.
ScaleDecomposeFn GetDecomposeFn(quant::QuantizeRegionOp op) const;
// converts specification to signature:
// - UniformedQuantizedType -> AnyQuantizedType
// - AnyQuantizedType (int) -> AnyQuantizedType
// - Float -> {}
static void AppendToSignature(Type spec, KernelSpecs::Signature* signature);
protected:
// Adds the kernel spec with the custom scale function for the kernel.
LogicalResult RegisterKernel(llvm::StringRef kernel,
@ -154,13 +161,6 @@ class DeviceTarget {
// added before.
KernelSpecs& RegisterKernel(llvm::StringRef kernel) { return specs_[kernel]; }
// converts specification to signature:
// - UniformedQuantizedType -> AnyQuantizedType
// - AnyQuantizedType (int) -> AnyQuantizedType
// - Float -> {}
void AppendToSignature(ArrayAttr specs_attr,
KernelSpecs::Signature* signature) const;
// For "mulmat->add" type of kernels, convert the scales of all the ports to
// multipliers.
static LogicalResult DecomposeMultiplyAccumulateScale(
@ -168,11 +168,17 @@ class DeviceTarget {
quant::QuantizedMultipliers* output_multipliers,
quant::QuantizedRanges* output_ranges);
// For "reshape" type of kernels.
static LogicalResult DecomposeSameScale(
Operation* op, quant::QuantizedMultipliers* input_multipliers,
quant::QuantizedMultipliers* output_multipliers,
quant::QuantizedRanges* output_ranges);
// A set of parameters are required to build the signatures.
FloatType f32_;
IntegerType i8_;
int64_t i8_min_, i8_max_;
AnyQuantizedType any_, qi8_, qi8n_;
IntegerType i8_, i32_;
int64_t i8_min_, i8_max_, i32_min_, i32_max_;
AnyQuantizedType any_, qi8_, qi8n_, qi32_;
private:
// Maps the kernel names to all the available kernels.
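A hedged sketch of how a device target subclass might use the RegisterKernel overloads above (the subclass name and the "generic.reshape" kernel name are illustrative assumptions):

// Sketch only: registers a same-scale, reshape-like kernel; the
// constraint-based overload wires in DecomposeSameScale automatically.
class MyDeviceTarget : public DeviceTarget {
 public:
  explicit MyDeviceTarget(MLIRContext* ctx) : DeviceTarget(ctx) {
    KernelSpecs::Signature signature = {qi8_, qi8_};  // one input, one output
    (void)RegisterKernel("generic.reshape", signature,
                         ScaleConstraintType::OutputInputSameScale);
  }
};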

View File

@ -72,5 +72,6 @@ tf_cc_binary(
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AllPassesAndDialects",
],
)

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace mlir {
namespace lite {
@ -38,7 +39,9 @@ namespace lite {
TfLiteStatus QuantizeModel(
const tflite::ModelT& input_model, const tflite::TensorType& input_type,
const tflite::TensorType& output_type,
const std::unordered_set<std::string>& operator_names, bool fully_quantize,
const tflite::TensorType& inference_type,
const std::unordered_set<std::string>& operator_names,
bool disable_per_channel, bool fully_quantize,
flatbuffers::FlatBufferBuilder* builder,
tflite::ErrorReporter* error_reporter) {
// TODO(b/142502494): remove this restriction by improving the `emit_adaptor`
@ -72,15 +75,18 @@ TfLiteStatus QuantizeModel(
// Apply quantization passes
PassManager pm(module->getContext());
TFL::QuantizationSpecs quant_specs;
quant_specs.inference_type = tensorflow::DT_QINT8;
quant_specs.inference_type = tflite::TflTypeToTfType(inference_type);
quant_specs.post_training_quantization = true;
quant_specs.disable_per_channel = disable_per_channel;
bool emit_adaptor = false;
auto input_tf_type = tflite::TflTypeToTfType(input_type);
if (input_tf_type == tensorflow::DT_FLOAT) {
emit_adaptor = true;
} else if (input_tf_type == tensorflow::DT_UINT8) {
quant_specs.inference_type = tensorflow::DT_QUINT8;
} else if (input_tf_type == tensorflow::DT_UINT8 ||
input_tf_type == tensorflow::DT_INT8 ||
input_tf_type == tensorflow::DT_INT16) {
quant_specs.inference_type = input_tf_type;
}
pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs));

View File

@ -26,12 +26,15 @@ namespace mlir {
namespace lite {
// Quantize the `input_model` and write the result to a flatbuffer `builder`.
// The `input_type` and `output_type` can be float32/qint8/int8.
// The `input_type`, `output_type` and `inference_type` can be
// float32/qint8/int8/int16.
// Return partially quantized model if `fully_quantize` is false.
TfLiteStatus QuantizeModel(
const tflite::ModelT& input_model, const tflite::TensorType& input_type,
const tflite::TensorType& output_type,
const std::unordered_set<std::string>& operator_names, bool fully_quantize,
const tflite::TensorType& inference_type,
const std::unordered_set<std::string>& operator_names,
bool disable_per_channel, bool fully_quantize,
flatbuffers::FlatBufferBuilder* builder,
tflite::ErrorReporter* error_reporter);
} // namespace lite

View File

@ -46,7 +46,9 @@ TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer,
tflite::StderrReporter error_reporter;
return mlir::lite::QuantizeModel(
*model, tflite::TensorType_INT8, tflite::TensorType_INT8, {},
*model, tflite::TensorType_INT8, tflite::TensorType_INT8,
tflite::TensorType_INT8, {},
/*disable_per_channel=*/false,
/*fully_quantize=*/true, builder, &error_reporter);
}

View File

@ -46,6 +46,12 @@ struct QuantizationSpecs {
// post-training quantization. We need to deprecate the `weight_quantization`.
bool post_training_quantization = false;
// When set to true, quantization will be done per-tensor. Currently, this
// option is only valid when the quantization parameters need to be created by
// scanning the constant content (post-training quantization or QAT without
// weight FakeQuant).
bool disable_per_channel = false;
// The node type when the model is exported. Currently this is limited to
// DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the
// `weight_quantization` flag needs to set to true. When DT_QUINT8 is used,
@ -84,7 +90,7 @@ struct QuantizationSpecs {
bool RunWeightQuantization() const { return weight_quantization; }
// Whether this inference type represents a signed storage type.
bool IsSignedInferenceType() {
bool IsSignedInferenceType() const {
switch (inference_type) {
case tensorflow::DT_QUINT8:
case tensorflow::DT_QUINT16:
@ -96,7 +102,7 @@ struct QuantizationSpecs {
// Gets the width of this quantization type. Returns 0 if it isn't a
// quantization type.
int64_t GetQuantizationTypeWidth() {
int64_t GetQuantizationTypeWidth() const {
switch (inference_type) {
case tensorflow::DT_QINT8:
case tensorflow::DT_QUINT8:

View File

@ -64,10 +64,23 @@ std::vector<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
return all_ops;
}
KernelSpecs::Signature QuantizeContext::GetSignature(QuantizeRegionOp op) {
KernelSpecs::Signature signature;
signature.reserve(op.input_specs().size() + op.output_specs().size());
for (int i = 0; i < op.getNumOperands(); ++i) {
DeviceTarget::AppendToSignature(GetOperandParams(op, i), &signature);
}
for (int i = 0; i < op.getNumResults(); ++i) {
DeviceTarget::AppendToSignature(GetResultParams(op, i), &signature);
}
return signature;
}
LogicalResult QuantizeContext::Handle(
quant::QuantizeRegionOp op, llvm::SmallVectorImpl<Operation *> *new_items,
bool *changed) {
auto spec = target_spec_.GetKernelSpec(op);
auto signature = GetSignature(op);
auto spec = target_spec_.GetKernelSpec(op.logical_kernel(), signature);
if (!spec.hasValue()) {
    op.emitWarning(
        "Couldn't find kernel from the registration for quantization.");

View File

@ -107,6 +107,9 @@ class QuantizeContext {
return states_manager_.GetOperandParams(op, index);
}
// Return the signature of the op.
KernelSpecs::Signature GetSignature(QuantizeRegionOp op);
// A heuristic to get quantization parameters satisfies the same scale
// constraints:
// - If there are immutable states,

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
@ -35,6 +36,7 @@ limitations under the License.
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
namespace mlir {
@ -363,6 +365,54 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
}
};
// Folds extra requantize ops if the preceding op has a free scale requirement.
template <typename RQ>
struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
explicit FoldTrivalRequantizeOp(MLIRContext* context)
: OpRewritePattern<RQ>(context, 1) {}
LogicalResult matchAndRewrite(RQ op,
PatternRewriter& rewriter) const override {
Value pre_quantized = op.input();
auto pre_quantized_type =
quant::QuantizedType::getQuantizedElementType(pre_quantized.getType());
if (!pre_quantized_type) return failure();
Operation* def = pre_quantized.getDefiningOp();
if (!def) return failure();
if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() ||
def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
return failure();
}
op.emitWarning("Remove trivial `rescale` op. Please fix the source graph.");
llvm::SmallVector<Type, 4> new_output_types;
for (auto result : def->getResults()) {
if (result.hasOneUse() && *result.getUsers().begin() == op) {
new_output_types.push_back(op.qtype());
} else {
new_output_types.push_back(result.getType());
}
}
// Remove this rescale op.
rewriter.replaceOp(op, {pre_quantized});
// Replace the output scale of the preceding op.
rewriter.setInsertionPointAfter(def);
OperationState new_state(def->getLoc(), def->getName().getStringRef(),
def->getOperands(), new_output_types,
def->getAttrs());
Operation* new_op = rewriter.createOperation(new_state);
rewriter.replaceOp(def, new_op->getResults());
return success();
}
};
// Given a quantized type `input`, magnifies its scales by the factor stored in
// `factor`. If `input` isn't a quantized type or the `factor` doesn't match the
// dimension size of `input` or isn't floating-point, nullptr will be returned.

View File

@ -11,9 +11,9 @@ func @reshape_removeAdjacent(tensor<4x4x4xf32>) -> tensor<64xf32> {
return %1 : tensor<64xf32>
// CHECK-LABEL: func @reshape_removeAdjacent
// CHECK: %cst = constant dense<64> : tensor<1xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: return
// CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32>
// CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: return %[[RESHAPE]]
}
// Checks that tfl.reshape should be removed if its output has more than one
@ -29,11 +29,11 @@ func @reshape_removeAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> tensor<64xf32>
return %3 : tensor<64xf32>
// CHECK-LABEL: func @reshape_removeAdjacentWithMultipleUse
// CHECK: %cst = constant dense<64> : tensor<1xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: %1 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: %2 = addf %0, %1
// CHECK: return %2
// CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32>
// CHECK: %[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: %[[RESULT:.*]] = addf %[[RESHAPE_1]], %[[RESHAPE_2]]
// CHECK: return %[[RESULT]]
}
// Checks that tfl.reshape should be kept if its output has more than one
@ -47,11 +47,11 @@ func @reshape_keepAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> (tensor<16x4xf32
return %0, %1 : tensor<16x4xf32>, tensor<64xf32>
// CHECK-LABEL: func @reshape_keepAdjacentWithMultipleUse
// CHECK: %cst = constant dense<[16, 4]> : tensor<2xi32>
// CHECK: %cst_0 = constant dense<64> : tensor<1xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32>
// CHECK: %1 = "tfl.reshape"(%arg0, %cst_0) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: return %0, %1
// CHECK: %[[CST:.*]] = constant dense<[16, 4]> : tensor<2xi32>
// CHECK: %[[CST_0:.*]] = constant dense<64> : tensor<1xi32>
// CHECK: %[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32>
// CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST_0]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: return %[[RESHAPE_1]], %[[RESHAPE_2]]
}
// Checks that tfl.reshape should be removed if its output type is the same

View File

@ -8,13 +8,13 @@ func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>,
%2 = constant dense< 3.5> : tensor<4xf32>
%3 = constant dense<-0.5> : tensor<4xf32>
// CHECK: %cst = constant dense<3.500000e+00> : tensor<4xf32>
// CHECK: %cst_0 = constant dense<-5.000000e-01> : tensor<4xf32>
// CHECK: %cst_1 = constant dense<6.000000e+00> : tensor<f32>
// CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32>
// CHECK: %cst_3 = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %cst_4 = constant dense<3.000000e+00> : tensor<4xf32>
// CHECK: %0 = tfl.add %cst, %cst_0 {fused_activation_function = "SIGN_BIT"} : tensor<4xf32>
// CHECK: %[[CST:.*]] = constant dense<3.500000e+00> : tensor<4xf32>
// CHECK: %[[CST_0:.*]] = constant dense<-5.000000e-01> : tensor<4xf32>
// CHECK: %[[CST_1:.*]] = constant dense<6.000000e+00> : tensor<f32>
// CHECK: %[[CST_2:.*]] = constant dense<4.000000e+00> : tensor<4xf32>
// CHECK: %[[CST_3:.*]] = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %[[CST_4:.*]] = constant dense<3.000000e+00> : tensor<4xf32>
// CHECK: %0 = tfl.add %[[CST]], %[[CST_0]] {fused_activation_function = "SIGN_BIT"} : tensor<4xf32>
%5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
%6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
@ -33,10 +33,10 @@ func @add_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
%2 = constant dense< 4> : tensor<4xi32>
%3 = constant dense<-2> : tensor<4xi32>
// CHECK: %cst = constant dense<9> : tensor<i32>
// CHECK: %cst_0 = constant dense<6> : tensor<4xi32>
// CHECK: %cst_1 = constant dense<5> : tensor<4xi32>
// CHECK: %cst_2 = constant dense<2> : tensor<4xi32>
// CHECK: %[[CST:.*]] = constant dense<9> : tensor<i32>
// CHECK: %[[CST_0:.*]] = constant dense<6> : tensor<4xi32>
// CHECK: %[[CST_1:.*]] = constant dense<5> : tensor<4xi32>
// CHECK: %[[CST_2:.*]] = constant dense<2> : tensor<4xi32>
%5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
%6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
@ -54,10 +54,10 @@ func @sub_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
%2 = constant dense< 3.5> : tensor<4xf32>
%3 = constant dense<-0.5> : tensor<4xf32>
// CHECK: %cst = constant dense<3.000000e+00> : tensor<f32>
// CHECK: %cst_0 = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %cst_1 = constant dense<2.000000e+00> : tensor<4xf32>
// CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32>
// CHECK: %[[CST:.*]] = constant dense<3.000000e+00> : tensor<f32>
// CHECK: %[[CST_0:.*]] = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %[[CST_1:.*]] = constant dense<2.000000e+00> : tensor<4xf32>
// CHECK: %[[CST_2:.*]] = constant dense<4.000000e+00> : tensor<4xf32>
%5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
%6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
@ -75,10 +75,10 @@ func @sub_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
%2 = constant dense< 4> : tensor<4xi32>
%3 = constant dense<-2> : tensor<4xi32>
// CHECK: %cst = constant dense<7> : tensor<i32>
// CHECK: %cst_0 = constant dense<10> : tensor<4xi32>
// CHECK: %cst_1 = constant dense<3> : tensor<4xi32>
// CHECK: %cst_2 = constant dense<6> : tensor<4xi32>
// CHECK: %[[CST:.*]] = constant dense<7> : tensor<i32>
// CHECK: %[[CST_0:.*]] = constant dense<10> : tensor<4xi32>
// CHECK: %[[CST_1:.*]] = constant dense<3> : tensor<4xi32>
// CHECK: %[[CST_2:.*]] = constant dense<6> : tensor<4xi32>
%5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
%6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
@ -96,10 +96,10 @@ func @mul_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
%2 = constant dense< 3.5> : tensor<4xf32>
%3 = constant dense<-0.5> : tensor<4xf32>
// CHECK: %cst = constant dense<6.750000e+00> : tensor<f32>
// CHECK: %cst_0 = constant dense<-2.250000e+00> : tensor<4xf32>
// CHECK: %cst_1 = constant dense<5.250000e+00> : tensor<4xf32>
// CHECK: %cst_2 = constant dense<-1.750000e+00> : tensor<4xf32>
// CHECK: %[[CST:.*]] = constant dense<6.750000e+00> : tensor<f32>
// CHECK: %[[CST_0:.*]] = constant dense<-2.250000e+00> : tensor<4xf32>
// CHECK: %[[CST_1:.*]] = constant dense<5.250000e+00> : tensor<4xf32>
// CHECK: %[[CST_2:.*]] = constant dense<-1.750000e+00> : tensor<4xf32>
%5 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
%6 = "tfl.mul"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
@ -170,8 +170,8 @@ func @add_dense_splat_int() -> tensor<4xi32> {
return %2 : tensor<4xi32>
// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_splat_dense_int
@ -183,8 +183,8 @@ func @add_splat_dense_int() -> tensor<4xi32> {
return %2 : tensor<4xi32>
// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_int_same_shape
@ -196,8 +196,8 @@ func @add_dense_dense_int_same_shape() -> tensor<4xi32> {
return %2 : tensor<4xi32>
// CHECK: %cst = constant dense<[5, 22, -2, 98]> : tensor<4xi32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[5, 22, -2, 98]> : tensor<4xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_int_trailing_dim
@ -212,10 +212,10 @@ func @add_dense_dense_int_trailing_dim() -> (tensor<2x2xi32>, tensor<2x2x2xi32>,
return %0, %1, %2 : tensor<2x2xi32>, tensor<2x2x2xi32>, tensor<2x2x2xi32>
// CHECK: %cst = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32>
// CHECK: %cst_0 = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32>
// CHECK: %cst_1 = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32>
// CHECK: return %cst, %cst_0, %cst_1
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32>
// CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32>
// CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32>
// CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]]
}
// CHECK-LABEL: @add_dense_dense_int_mixing_1_n
@ -226,8 +226,8 @@ func @add_dense_dense_int_mixing_1_n() -> tensor<2x2xi32> {
%0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
// CHECK: %cst = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_splat_float
@ -239,8 +239,8 @@ func @add_dense_splat_float() -> tensor<4xf32> {
return %2 : tensor<4xf32>
// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_splat_dense_float
@ -252,8 +252,8 @@ func @add_splat_dense_float() -> tensor<4xf32> {
return %2 : tensor<4xf32>
// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_float_same_shape
@ -265,8 +265,8 @@ func @add_dense_dense_float_same_shape() -> (tensor<4xf32>) {
return %2 : tensor<4xf32>
// CHECK: %cst = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_float_trailing_dim
@ -281,10 +281,10 @@ func @add_dense_dense_float_trailing_dim() -> (tensor<2x2xf32>, tensor<2x2x2xf32
return %0, %1, %2 : tensor<2x2xf32>, tensor<2x2x2xf32>, tensor<2x2x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32>
// CHECK: %cst_0 = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32>
// CHECK: %cst_1 = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32>
// CHECK: return %cst, %cst_0, %cst_1
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32>
// CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32>
// CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32>
// CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]]
}
// CHECK-LABEL: @add_dense_dense_float_mixfng_1_n
@ -296,24 +296,24 @@ func @add_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
return %0 : tensor<2x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @rank
func @rank() -> tensor<1xi32> {
%cst = constant dense<[[1], [2]]> : tensor<2x1xi32>
// CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32>
// CHECK: return %[[CST]]
%0 = "tfl.rank"(%cst) : (tensor<2x1xi32>) -> tensor<1xi32>
return %0 : tensor<1xi32>
}
// CHECK-LABEL: @rank_input_known_rank
func @rank_input_known_rank(%arg0 : tensor<2x1xi32>) -> tensor<1xi32> {
// CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32>
// CHECK: return %[[CST]]
%0 = "tfl.rank"(%arg0) : (tensor<2x1xi32>) -> tensor<1xi32>
return %0 : tensor<1xi32>
}
@ -323,8 +323,8 @@ func @reshape() -> tensor<4xi32> {
%input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
%shape = constant dense<[4]> : tensor<1xi32>
// CHECK: [[cst:%.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK: return %[[CST]]
%0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
@ -334,8 +334,8 @@ func @reshape_dynamic_output() -> tensor<?xi32> {
%input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
%shape = constant dense<[4]> : tensor<1xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
%0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
@ -343,8 +343,8 @@ func @reshape_dynamic_output() -> tensor<?xi32> {
// CHECK-LABEL: @pseudo_const
func @pseudo_const() -> tensor<i32> {
// CHECK: [[cst:%.*]] = constant dense<1> : tensor<i32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<1> : tensor<i32>
// CHECK: return %[[CST]]
%0 = "tfl.pseudo_const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
return %0 : tensor<i32>
}
@ -356,8 +356,8 @@ func @range_int() -> tensor<?xi32> {
%cst_1 = constant dense<4> : tensor<i32>
%cst_2 = constant dense<1> : tensor<i32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
@ -368,8 +368,8 @@ func @range_float() -> tensor<?xf32> {
%cst_1 = constant dense<4.0> : tensor<f32>
%cst_2 = constant dense<1.0> : tensor<f32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -381,8 +381,8 @@ func @range_float_neg_delta() -> tensor<?xf32> {
%cst_1 = constant dense<-4.0> : tensor<f32>
%cst_2 = constant dense<-1.0> : tensor<f32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -393,8 +393,8 @@ func @range_float_nonzero_base() -> tensor<?xf32> {
%cst_1 = constant dense<7.0> : tensor<f32>
%cst_2 = constant dense<1.5> : tensor<f32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -414,8 +414,8 @@ func @transpose_1d() -> tensor<3xi32> {
%cst = constant dense<[1, 2, 3]> : tensor<3xi32>
%cst_perm = constant dense<0> : tensor<1xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
@ -425,8 +425,8 @@ func @transpose_dynamic() -> tensor<?xi32> {
%cst = constant dense<[1, 2, 3]> : tensor<3xi32>
%cst_perm = constant dense<0> : tensor<1xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
@ -436,8 +436,8 @@ func @transpose_2d() -> tensor<2x2xi32> {
%cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
%cst_perm = constant dense<[1, 0]> : tensor<2xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -447,8 +447,8 @@ func @transpose_2d_identity() -> tensor<2x2xi32> {
%cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
%cst_perm = constant dense<[0, 1]> : tensor<2xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -460,8 +460,8 @@ func @transpose_3d() -> tensor<4x2x3xi32> {
%cst = constant dense<[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]> : tensor<2x3x4xi32>
%cst_perm = constant dense<[2, 0, 1]> : tensor<3xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x3x4xi32>, tensor<3xi32>) -> tensor<4x2x3xi32>
return %0 : tensor<4x2x3xi32>
}
@ -473,8 +473,8 @@ func @ConstantFoldBinaryOpDynamicOutput() -> tensor<?xi32> {
%87 = "tfl.sub"(%cst_0, %cst) {fused_activation_function = "NONE"} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
return %87 : tensor<?xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_int_same_shape_dynamic
@ -486,8 +486,8 @@ func @add_dense_dense_int_same_shape_dynamic() -> tensor<?xi32> {
return %2 : tensor<?xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @concat_2_tensors_1_empty
@ -497,8 +497,8 @@ func @concat_2_tensors_1_empty() -> tensor<2xi32> {
%3 = "tfl.concatenation"(%1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<0xi32>) -> tensor<2xi32>
return %3 : tensor<2xi32>
// CHECK: [[cst:%.*]] = constant dense<1> : tensor<2xi32>
// CHECK: return [[cst]] : tensor<2xi32>
// CHECK: %[[CST:.*]] = constant dense<1> : tensor<2xi32>
// CHECK: return %[[CST]] : tensor<2xi32>
}
// CHECK-LABEL: @concat_3_tensors_1_empty
@ -509,7 +509,7 @@ func @concat_3_tensors_1_empty() -> tensor<?xi32> {
%3 = "tfl.concatenation"(%0, %1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>, tensor<0xi32>) -> tensor<?xi32>
return %3 : tensor<?xi32>
// CHECK: %0 = "tfl.concatenation"(%cst, %cst) {axis = 0 : i32, fused_activation_function = "NONE"}
// CHECK: %0 = "tfl.concatenation"(%[[CST]], %[[CST]]) {axis = 0 : i32, fused_activation_function = "NONE"}
// CHECK: return %0 : tensor<?xi32>
}
@ -520,10 +520,10 @@ func @concatConstantTensorsFirstDim() -> tensor<2x2x3xi32> {
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<2x2x3xi32>
return %0 : tensor<2x2x3xi32>
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32>
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32>
// CHECK-NOT: constant-dense
// CHECK-NOT: "tfl.concatenation"
// CHECK: return [[cst]]
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @concatConstantTensorsMiddleDim
@ -533,10 +533,10 @@ func @concatConstantTensorsMiddleDim() -> tensor<1x4x3xi32> {
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x4x3xi32>
return %0 : tensor<1x4x3xi32>
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32>
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32>
// CHECK-NOT: constant-dense
// CHECK-NOT: "tfl.concatenation"
// CHECK: return [[cst]]
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @concatConstantTensorsLastDim
@ -546,10 +546,10 @@ func @concatConstantTensorsLastDim() -> tensor<1x2x6xi32> {
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 2 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x2x6xi32>
return %0 : tensor<1x2x6xi32>
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32>
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32>
// CHECK-NOT: constant-dense
// CHECK-NOT: "tfl.concatenation"
// CHECK: return [[cst]]
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @div_dense_dense_float_mixfng_1_n
@ -561,8 +561,8 @@ func @div_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
return %0 : tensor<2x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @div_dense_different_rank
@ -574,6 +574,6 @@ func @div_dense_different_rank() -> tensor<1x2x2xf32> {
return %0 : tensor<1x2x2xf32>
// CHECK: %cst = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
// CHECK: return %[[CST]]
}
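Every case in this constant-folding file follows the same shape: all operands of the arithmetic op are constants, the folder replaces the op with a single constant, and the test only needs to capture that constant and check the return. A minimal self-contained sketch of one such case, with illustrative values that are not taken from this file:
// CHECK-LABEL: @fold_add_example
func @fold_add_example() -> tensor<i32> {
  %0 = constant dense<2> : tensor<i32>
  %1 = constant dense<3> : tensor<i32>
  %2 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
  return %2 : tensor<i32>
  // CHECK: %[[CST:.*]] = constant dense<5> : tensor<i32>
  // CHECK: return %[[CST]]
}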

View File

@ -0,0 +1,101 @@
# RUN: tf_tfl_translate -tf-input-arrays=Placeholder,Placeholder_1 -tf-input-shapes=2,5,3:3,7 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-output-arrays=MatMul -output-mlir %s -o - 2>&1 | FileCheck %s
node {
name: "Placeholder"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 2
}
dim {
size: 5
}
dim {
size: 3
}
}
}
}
}
node {
name: "Placeholder_1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 3
}
dim {
size: 7
}
}
}
}
}
node {
name: "MatMul"
op: "BatchMatMulV2"
input: "Placeholder"
input: "Placeholder_1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "adj_x"
value {
b: false
}
}
attr {
key: "adj_y"
value {
b: false
}
}
}
versions {
producer: 175
}
# CHECK: func @main(%[[VAL_0:.*]]: tensor<2x5x3xf32>, %[[VAL_1:.*]]: tensor<3x7xf32>) -> tensor<2x5x7xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "Placeholder,Placeholder_1", outputs = "MatMul"}} {
# CHECK: %[[VAL_2:.*]] = constant dense<[1, 0]> : tensor<2xi32>
# CHECK: %[[VAL_3:.*]] = constant dense<[5, 3]> : tensor<2xi32>
# CHECK: %[[VAL_4:.*]] = constant dense<[3, 7]> : tensor<2xi32>
# CHECK: %[[VAL_5:.*]] = constant unit
# CHECK: %[[VAL_6:.*]] = constant dense<[1, 0, 0]> : tensor<3xi32>
# CHECK: %[[VAL_7:.*]] = constant dense<[1, 5, 3]> : tensor<3xi32>
# CHECK: %[[VAL_8:.*]] = constant dense<0> : tensor<3xi32>
# CHECK: %[[VAL_9:.*]] = constant dense<[1, 3, 7]> : tensor<3xi32>
# CHECK: %[[VAL_10:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_8]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>
# CHECK: %[[VAL_11:.*]] = "tfl.reshape"(%[[VAL_10]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32>
# CHECK: %[[VAL_12:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_6]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>
# CHECK: %[[VAL_13:.*]] = "tfl.reshape"(%[[VAL_12]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32>
# CHECK: %[[VAL_14:.*]] = "tfl.reshape"(%[[VAL_1]], %[[VAL_9]]) : (tensor<3x7xf32>, tensor<3xi32>) -> tensor<1x3x7xf32>
# CHECK: %[[VAL_15:.*]] = "tfl.slice"(%[[VAL_14]], %[[VAL_8]], %[[VAL_9]]) : (tensor<1x3x7xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x3x7xf32>
# CHECK: %[[VAL_16:.*]] = "tfl.reshape"(%[[VAL_15]], %[[VAL_4]]) : (tensor<1x3x7xf32>, tensor<2xi32>) -> tensor<3x7xf32>
# CHECK: %[[VAL_17:.*]] = "tfl.transpose"(%[[VAL_16]], %[[VAL_2]]) : (tensor<3x7xf32>, tensor<2xi32>) -> tensor<7x3xf32>
# CHECK: %[[VAL_18:.*]] = "tfl.fully_connected"(%[[VAL_11]], %[[VAL_17]], %[[VAL_5]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32>
# CHECK: %[[VAL_19:.*]] = "tfl.fully_connected"(%[[VAL_13]], %[[VAL_17]], %[[VAL_5]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32>
# CHECK: %[[VAL_20:.*]] = "tfl.pack"(%[[VAL_18]], %[[VAL_19]]) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<2x5x7xf32>
# CHECK: return %[[VAL_20]] : tensor<2x5x7xf32>
# CHECK: }
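Reading the captures above: the converter slices each batch element out of the 2x5x3 left-hand side (begin [0, 0, 0] and [1, 0, 0], size [1, 5, 3]), reshapes it to 5x3, multiplies it against the transposed 3x7 right-hand side via tfl.fully_connected with no bias, and packs the two 5x7 results back into the 2x5x7 output. For orientation, a hand-written sketch of roughly what the imported graph looks like before this lowering (TF-dialect MLIR written for illustration, not dumped from the tool):
func @main(%arg0: tensor<2x5x3xf32>, %arg1: tensor<3x7xf32>) -> tensor<2x5x7xf32> {
  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x5x3xf32>, tensor<3x7xf32>) -> tensor<2x5x7xf32>
  return %0 : tensor<2x5x7xf32>
}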

View File

@ -1,15 +1,15 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// Ensure lstm roundtrips exactly
func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> {
func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> {
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %24 : tensor<1x4xf32>
// CHECK-LABEL: main
// separate lines since there is no region for this op. third_party/tensorflow/compiler/mlir/lite/ir/tfl_ops.td: 3252
// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg22, %arg23, %arg18, %arg19, %arg20, %arg21) ( {
// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
// CHECK: return %[[RES0]]
}

View File

@ -0,0 +1,14 @@
// RUN: tf-opt -tfl-prepare-composite-funcs-tf -tfl-fuse-tftext=true %s -split-input-file | FileCheck %s --dump-input-on-failure
module {
func @_whitespace_func(%arg0: tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._GrapplerSpecializedFunc = true, tf._input_shapes = [#tf.shape<1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {
%0 = "tf.op1"(%arg0) : (tensor<1x!tf.string>) -> (tensor<?x!tf.string>)
%1 = "tf.Const"() {value = dense<-1> : tensor<i64>} : () -> tensor<?xi64>
%2:2 = "tf.op2"(%arg0, %1) : (tensor<1x!tf.string>, tensor<?xi64>) -> (tensor<?x!tf.string>, tensor<?xi64>)
return %2#0, %2#1 : tensor<?x!tf.string>, tensor<?xi64>
}
// CHECK: func @_whitespace_func(%arg0: tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._GrapplerSpecializedFunc = true, tf._input_shapes = [#tf.shape<1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {
// CHECK: "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>)
// CHECK: return %0#0, %0#1 : tensor<?x!tf.string>, tensor<?xi64>
}

View File

@ -1048,6 +1048,15 @@ func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
}
func @concatv2I64Axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
%0 = "tf.Const"() { value = dense<-1> : tensor<i64> } : () -> tensor<i64>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i64>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
// CHECK-LABEL: concatv2I64Axis
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
}
func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
%0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
@ -1213,15 +1222,14 @@ func @resize_nearest_neighbor(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi3
%0 = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: resize_nearest_neighbor
// CHECK: "tfl.resize_nearest_neighbor"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
// CHECK: "tfl.resize_nearest_neighbor"(%arg0, %arg1) {align_corners = true, half_pixel_centers = false} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
}
// Note: half_pixel_centers isn't supported by TFLite, so it's not legalized.
func @resize_nearest_neighbor_with_half_pixel_centers(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
%0 = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = true, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
%0 = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = false, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: resize_nearest_neighbor_with_half_pixel_centers
// CHECK: "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = true, half_pixel_centers = true}
// CHECK: "tfl.resize_nearest_neighbor"(%arg0, %arg1) {align_corners = false, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
}
func @sparse_to_dense_with_scalar_sparse_indices(%arg0: tensor<i32>, %arg1: tensor<3xi32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<?x?x?xf32> {
@ -1497,3 +1505,27 @@ func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
// CHECK: return [[MUL]] : tensor<3x3xi32>
}
func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} :
(tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32>
return %0 : tensor<10x17xf32>
// CHECK-LABEL: matmul_batch
// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32>
}
func @matmul_batchv2(%arg0: tensor<2x10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<2x10x17xf32> {
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} :
(tensor<2x10x15xf32>, tensor<15x17xf32>) -> tensor<2x10x17xf32>
return %0 : tensor<2x10x17xf32>
// CHECK-LABEL: matmul_batchv2
// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x10x15xf32>, tensor<15x17xf32>) -> tensor<2x10x17xf32>
}
func @matmul_batchv2_unknown_dim(%arg0: tensor<?x10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<?x10x17xf32> {
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} :
(tensor<?x10x15xf32>, tensor<15x17xf32>) -> tensor<?x10x17xf32>
return %0 : tensor<?x10x17xf32>
// CHECK-LABEL: matmul_batchv2_unknown_dim
// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<?x10x15xf32>, tensor<15x17xf32>) -> tensor<?x10x17xf32>
}

View File

@ -65,7 +65,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: opcode_index: 1,
// CHECK-NEXT: inputs: [ 2, 1 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: custom_options: [ 105, 110, 116, 95, 97, 116, 116, 114, 0, 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 2, 33, 43, 2, 1, 2, 11, 2, 20, 4, 4, 36, 1 ]
// CHECK-NEXT: custom_options: [ 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 105, 110, 116, 95, 97, 116, 116, 114, 0, 2, 42, 11, 2, 1, 2, 20, 2, 20, 4, 4, 36, 1 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 2,
// CHECK-NEXT: inputs: [ 3 ],

View File

@ -1,6 +1,6 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> {
func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
@ -72,21 +72,21 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 10,
// CHECK-NEXT: name: "arg9",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 11,
// CHECK-NEXT: name: "arg10",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 12,
// CHECK-NEXT: name: "arg11",
// CHECK-NEXT: quantization: {
@ -100,21 +100,21 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 14,
// CHECK-NEXT: name: "arg13",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 15,
// CHECK-NEXT: name: "arg14",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 16,
// CHECK-NEXT: name: "arg15",
// CHECK-NEXT: quantization: {
@ -128,7 +128,7 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 18,
// CHECK-NEXT: name: "arg17",
// CHECK-NEXT: quantization: {
@ -261,9 +261,9 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
// CHECK-EMPTY:
^bb0(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>):
^bb0(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>):
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %24 : tensor<1x4xf32>
}

Some files were not shown because too many files have changed in this diff.