Merge branch 'master' into 42129-tf.image.crop_and_resize

commit 1816c43041
Author: Mihai Maruseac (committed by GitHub)
Date: 2020-10-18 11:19:38 -07:00
Signature: no known key found for this signature in database (GPG Key ID: 4AEE18F83AFDEB23)
4626 changed files with 193539 additions and 96385 deletions

File: .bazelrc

@@ -5,6 +5,7 @@
 # Android options:
 # android:
 # android_arm:
+# android_arm64:
 # android_x86:
 # android_x86_64:
 #
@@ -46,10 +47,6 @@
 # using_cuda: CUDA is available to build system.
 # cuda: Build with full cuda support.
 # rocm: Build with AMD GPU support (rocm).
-# sycl: Build with SYCL support.
-# sycl_nodouble:
-# sycl_asan:
-# sycl_trisycl:
 # mkl: Enable full mkl support.
 # tensorrt: Enable Tensorrt support.
 # ngraph: Enable ngraph support.
@@ -89,6 +86,7 @@
 # release_cpu_linux: Toolchain and CUDA options for Linux CPU builds.
 # release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds.
 # release_gpu_linux: Toolchain and CUDA options for Linux GPU builds.
+# release_gpu_linux_cuda_10_1: Toolchain and CUDA options for CUDA 10.1 Linux GPU builds.
 # release_cpu_windows: Toolchain and CUDA options for Windows CPU builds.
 # release_gpu_windows: Toolchain and CUDA options for Windows GPU builds.
@@ -161,13 +159,11 @@ build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain
 # environment variable "TF_MKL_ROOT" every time before build.
 build:mkl --define=build_with_mkl=true --define=enable_mkl=true
 build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
-build:mkl --define=build_with_mkl_dnn_v1_only=true
 build:mkl -c opt
 # config to build OneDNN backend with a user specified threadpool.
 build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
 build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
-build:mkl_threadpool --define=build_with_mkl_dnn_v1_only=true
 build:mkl_threadpool --define=build_with_mkl_opensource=true
 build:mkl_threadpool --define=build_with_mkldnn_threadpool=true
 build:mkl_threadpool -c opt
@@ -175,10 +171,15 @@ build:mkl_threadpool -c opt
 # Config setting to build with oneDNN and without the binary blob
 build:mkl_opensource_only --define=build_with_mkl=true --define=enable_mkl=true
 build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0
-build:mkl_opensource_only --define=build_with_mkl_dnn_v1_only=true
 build:mkl_opensource_only --define=build_with_mkl_opensource=true
 build:mkl_opensource_only -c opt
 
+# Config setting to build with oneDNN for Arm.
+build:mkl_aarch64 --define=build_with_mkl_aarch64=true --define=enable_mkl=true
+build:mkl_aarch64 --define=tensorflow_mkldnn_contraction_kernel=0
+build:mkl_aarch64 --define=build_with_mkl_opensource=true
+build:mkl_aarch64 -c opt
+
 # This config refers to building with CUDA available. It does not necessarily
 # mean that we build CUDA op kernels.
 build:using_cuda --define=using_cuda=true
@@ -216,19 +217,6 @@ build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
 build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
 build:rocm --action_env TF_NEED_ROCM=1
 
-build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain
-build:sycl --define=using_sycl=true
-build:sycl --action_env TF_NEED_OPENCL_SYCL=1
-
-build:sycl_nodouble --config=sycl
-build:sycl_nodouble --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE
-
-build:sycl_nodouble --config=sycl
-build:sycl_asan --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address
-
-build:sycl_nodouble --config=sycl
-build:sycl_trisycl --define=using_trisycl=true
-
 # Options extracted from configure script
 build:ngraph --define=with_ngraph_support=true
 build:numa --define=with_numa_support=true
@@ -293,6 +281,7 @@ build:ios --noenable_platform_specific_config
 build:android --copt=-w
 build:ios --copt=-w
 build:linux --copt=-w
+build:linux --host_copt=-w
 build:macos --copt=-w
 build:windows --copt=/w
@@ -334,6 +323,11 @@ build:windows --host_copt=-DWIN32_LEAN_AND_MEAN
 build:windows --copt=-DNOGDI
 build:windows --host_copt=-DNOGDI
+
+# MSVC (Windows): Standards-conformant preprocessor mode
+# See https://docs.microsoft.com/en-us/cpp/preprocessor/preprocessor-experimental-overview
+build:windows --copt=/experimental:preprocessor
+build:windows --host_copt=/experimental:preprocessor
 
 # Misc build options we need for windows.
 build:windows --linkopt=/DEBUG
 build:windows --host_linkopt=/DEBUG
@@ -358,6 +352,7 @@ build --config=short_logs
 # TODO(gunan): Create a feature in toolchains for avx/avx2 to
 # avoid having to define linux/win separately.
 build:avx_linux --copt=-mavx
+build:avx_linux --host_copt=-mavx
 build:avx2_linux --copt=-mavx2
 build:native_arch_linux --copt=-march=native
 build:avx_win --copt=/arch=AVX
@@ -411,9 +406,12 @@ build:rbe_linux --config=avx_linux
 build:rbe_linux --config=short_logs
 # TODO(gunan): Check why we need this specified in rbe, but not in other builds.
 build:rbe_linux --linkopt=-lrt
+build:rbe_linux --host_linkopt=-lrt
 build:rbe_linux --linkopt=-lm
+build:rbe_linux --host_linkopt=-lm
 
 build:rbe_cpu_linux --config=rbe_linux
+build:rbe_cpu_linux --host_crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
 build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
 build:rbe_cpu_linux --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:cc-toolchain-k8"
 build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
@@ -431,6 +429,7 @@ test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/
 build:rbe_linux_cuda10.1_nvcc_base --config=rbe_linux_cuda_base
 build:rbe_linux_cuda10.1_nvcc_base --define=using_cuda_nvcc=true
+build:rbe_linux_cuda10.1_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
 build:rbe_linux_cuda10.1_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
 build:rbe_linux_cuda10.1_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
 build:rbe_linux_cuda10.1_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
@@ -447,6 +446,7 @@ build:rbe_linux_cuda10.1_nvcc_py3.8 --config=rbe_linux_cuda10.1_nvcc_base --repo
 build:rbe_linux_cuda11.0_nvcc_base --config=rbe_linux_cuda_base
 build:rbe_linux_cuda11.0_nvcc_base --define=using_cuda_nvcc=true
+build:rbe_linux_cuda11.0_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
 build:rbe_linux_cuda11.0_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
 build:rbe_linux_cuda11.0_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain-linux-x86_64"
 build:rbe_linux_cuda11.0_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
@@ -587,7 +587,7 @@ build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-11.0"
 build:release_gpu_common --action_env=TF_CUDA_VERSION="11"
 build:release_gpu_common --action_env=TF_CUDNN_VERSION="8"
 build:release_gpu_common --action_env=TF_NEED_TENSORRT="1"
-build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_37,sm_52,sm_60,sm_61,compute_70"
+build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80"
 build:release_gpu_common --action_env=TENSORRT_INSTALL_PATH="/usr/local/tensorrt"
 build:release_gpu_common --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
 build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5"
@@ -603,3 +603,8 @@ build:release_windows_common --announce_rc
 build:release_cpu_windows --config=release_windows_common
 build:release_gpu_windows --config=release_windows_common
+
+build:release_gpu_linux_cuda_10_1 --config=release_gpu_linux
+build:release_gpu_linux_cuda_10_1 --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1"
+build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDA_VERSION="10"
+build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDNN_VERSION="7"

File: GitHub bot configuration (YAML; path not shown in this capture)

@@ -12,12 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-#
-# THIS IS A GENERATED DOCKERFILE.
-#
-# This file was assembled from multiple pieces, whose use is documented
-# throughout. Please refer to the TensorFlow dockerfiles documentation
-# for more information.
 # A list of assignees
 assignees:
@@ -40,6 +34,22 @@ segfault_memory:
 # assignees
 filesystem_security_assignee:
   - mihaimaruseac
 
+tflite_micro_path:
+  - tensorflow/lite/micro
+
+tflite_micro_comment: >
+  Thanks for contributing to TensorFlow Lite Micro.
+
+  To keep this process moving along, we'd like to make sure that you have completed the items on this list:
+    * Read the [contributing guidelines for TensorFlow Lite Micro](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/CONTRIBUTING.md)
+    * Created a [TF Lite Micro Github issue](https://github.com/tensorflow/tensorflow/issues/new?labels=comp%3Amicro&template=70-tflite-micro-issue.md)
+    * Linked to the issue from the PR description
+
+  We would like to have a discussion on the Github issue first to determine the best path forward, and then proceed to the PR review.
+
 # Cuda Comment
 cuda_comment: >
   From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries:

File: GitHub Actions workflow "Set nightly branch to master HEAD" (the diff pairs it with a removed release shell script; exact paths not shown in this capture)

@@ -1,4 +1,3 @@
-#!/bin/bash
 # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,14 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# ==============================================================================
-set -e
-set -x
-
-source tensorflow/tools/ci_build/release/common.sh
-
-# Rename to tensorflow_cpu
-for f in $(ls py_test_dir/tensorflow-*cp3*-cp3*m-win_amd64.whl); do
-  copy_to_new_project_name "${f}" tensorflow_cpu
-  rm "${f}"
-done
+# ============================================================================
+
+on:
+  workflow_dispatch:  # Allow manual triggers
+  schedule:
+    - cron: '0 4 * * *'  # 4am UTC is 9pm PDT and 8pm PST
+
+name: Set nightly branch to master HEAD
+
+jobs:
+  master-to-nightly:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: zofrex/mirror-branch@v1
+        name: Set nightly branch to master HEAD
+        with:
+          target-branch: 'nightly'

File: ADOPTERS.md (deleted)

@@ -1,10 +0,0 @@
-# TensorFlow Adopters
-
-This page contains a list of people and organizations who are using TensorFlow. If you'd like to be included
-here, please send a pull request which modifies this file.
-
-We intend to use this list to contact you for surveys, and to find good candidates for invite-only events.
-We will also point to this list if we are asked who uses TensorFlow.
-
-We will not use any of the information here for promotions or to send other regular communications. You
-should subscribe to discuss@tensorflow.org for such announcements.

File: CODEOWNERS

@@ -1,16 +1,15 @@
 # Where component owners are known, add them here.
 
-/tensorflow/c/eager @jaingaurav @alextp
+/tensorflow/c/eager @qqfish @kkimdev
-/tensorflow/core/common_runtime/eager @jaingaurav @alextp
+/tensorflow/core/common_runtime/eager @qqfish @kkimdev
 /tenosrflow/core/debug @caisq
 /tensorflow/core/nccl/ @azaks2 @chsigg
-/tensorflow/core/platform/windows/ @mrry
+/tensorflow/core/platform/windows/ @mihaimaruseac
 /tensorflow/lite/experimental/micro @petewarden @advaitjain
 /tensorflow/python/autograph/ @mdanatg @kkimdev
 /tensorflow/python/debug @caisq
-/tensorflow/python/eager @jaingaurav @alextp
+/tensorflow/python/eager @rohan100jain @kkimdev
 /tensorflow/python/tools/api/generator/ @annarev
-/tensorflow/tensorboard/ @jart
 /tensorflow/tools/docs/ @markdaoust
 /third_party/systemlibs/ @perfinion

File: README.md

@@ -103,23 +103,22 @@ open-source software development:
 ### Official Builds
 
 Build Type | Status | Artifacts
------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
+----------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
 **Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
 **Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
 **Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
 **macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
 **Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
 **Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
 **Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
 **Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
 **Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
-**Libtensorflow MacOS CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
+**Libtensorflow MacOS CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
-**Libtensorflow Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
+**Libtensorflow Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
-**Libtensorflow Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
+**Libtensorflow Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
-**Libtensorflow Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
+**Libtensorflow Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
-**Libtensorflow Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
+**Libtensorflow Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
 
 ### Community Supported Builds
@@ -145,19 +144,20 @@ Build Type
 * [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
 * [TensorFlow Official Models](https://github.com/tensorflow/models/tree/master/official)
 * [TensorFlow Examples](https://github.com/tensorflow/examples)
-* [TensorFlow in Practice from Coursera](https://www.coursera.org/specializations/tensorflow-in-practice)
+* [DeepLearning.AI TensorFlow Developer Professional Certificate](https://www.coursera.org/specializations/tensorflow-in-practice)
 * [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment)
 * [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2)
 * [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
 * [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
 * [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp)
+* [TensorFlow Codelabs](https://codelabs.developers.google.com/?cat=TensorFlow)
 * [TensorFlow Chat Room on StackOverflow (not actively monitored by the
   TensorFlow team)](https://chat.stackoverflow.com/rooms/216694/tensorflow)
 * [TensorFlow Blog](https://blog.tensorflow.org)
 * [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
 * [TensorFlow Twitter](https://twitter.com/tensorflow)
 * [TensorFlow YouTube](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
-* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
+* [TensorFlow Roadmap](https://www.tensorflow.org/model_optimization/guide/roadmap)
 * [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
 * [TensorBoard Visualization Toolkit](https://github.com/tensorflow/tensorboard)

File: RELEASE.md

@@ -34,9 +34,33 @@
     shape assumptions (note that you can pass shapes with `None` entries for axes
     that are meant to be dynamic). You can also disable the input checking
     entirely by setting `model.input_spec = None`.
+* TF pip packages now use CUDA11 and cuDNN 8.0.2.
 * XLA:CPU and XLA:GPU devices are no longer registered by default. Use
   `TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be
   removed).
+* `tf.raw_ops.Max` and `tf.raw_ops.Min` no longer accept inputs of type
+  `tf.complex64` or `tf.complex128`, because the behavior of these ops is not
+  well defined for complex types.
+* `tf.data.experimental.service.DispatchServer` now takes a config tuple
+  instead of individual arguments. Usages should be updated to
+  `tf.data.experimental.service.DispatchServer(dispatcher_config)`.
+* `tf.data.experimental.service.WorkerServer` now takes a config tuple
+  instead of individual arguments. Usages should be updated to
+  `tf.data.experimental.service.WorkerServer(worker_config)`.
+* `tf.quantization.quantize_and_dequantize_v2` has been introduced, which
+  updates the gradient definition for quantization that is outside the range
+  to be 0. To simulate the V1 behavior of
+  `tf.quantization.quantize_and_dequantize(...)` use
+  `tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...)`.
+* `tf.distribute.Strategy.experimental_make_numpy_dataset` is removed. Please
+  use `tf.data.Dataset.from_tensor_slices` instead.
+* `experimental_hints` in `tf.distribute.StrategyExtended.reduce_to`,
+  `tf.distribute.StrategyExtended.batch_reduce_to`,
+  `tf.distribute.ReplicaContext.all_reduce` are renamed to `options`.
+  `tf.distribute.experimental.CollectiveHints` is renamed to
+  `tf.distribute.experimental.CommunicationOptions`.
+  `tf.distribute.experimental.CollectiveCommunication` is renamed to
+  `tf.distribute.experimental.CommunicationImplementation`.
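
As a quick orientation to two of the changes above, a minimal sketch (the port number and single-process layout are illustrative assumptions, not from the notes):

```python
import tensorflow as tf

# DispatchServer now takes a config object rather than individual arguments.
dispatcher_config = tf.data.experimental.service.DispatcherConfig(port=5000)
dispatcher = tf.data.experimental.service.DispatchServer(dispatcher_config)

# Recovering the V1 gradient of quantize_and_dequantize via grad_pass_through:
x = tf.Variable([0.1, 0.5, 0.9])
with tf.GradientTape() as tape:
    y = tf.grad_pass_through(
        lambda v: tf.quantization.quantize_and_dequantize_v2(
            v, input_min=0.0, input_max=1.0))(x)
print(tape.gradient(y, x))  # Straight-through: a gradient of ones.
```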
 ## Known Caveats
@@ -46,89 +70,180 @@
 * <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
 * <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
-* A new module named `tf.experimental.numpy` is added, which is a NumPy-compatible API for writing TF programs. This module provides class `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are provided. Their inter-operation with TF facilities is seamless in most cases. See tensorflow/python/ops/numpy_ops/README.md for details of what are supported and what are the differences with NumPy.
+* A new module named `tf.experimental.numpy` is added, which is a NumPy-compatible API for writing TF programs. This module provides class `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are provided. Their inter-operation with TF facilities is seamless in most cases. See [tensorflow/python/ops/numpy_ops/README.md](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/numpy_ops/README.md) for details of what operations are supported and what are the differences from NumPy.
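
A small sketch of the interop described above (values are arbitrary):

```python
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

x = tnp.asarray([[1.0, 2.0], [3.0, 4.0]])  # ndarray wrapping a tf.Tensor
y = tnp.add(x, 1.0)                        # NumPy-style function
z = tf.reduce_sum(y)                       # seamless use with regular TF ops
print(z.numpy())                           # 14.0
```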
 * A major refactoring of the internals of the Keras Functional API has been completed, that should improve the reliability, stability, and performance of constructing Functional models.
+* `tf.distribute`:
+  * Deprecated `experimental_distribute_datasets_from_function` method and renamed it to `distribute_datasets_from_function` as it is no longer experimental.
 
 ## Bug Fixes and Other Changes
 
 * <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
 * <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
 * <NOTES SHOULD BE GROUPED PER AREA>
-* TF Core:
-  * `tf.types.experimental.TensorLike` is a new `Union` type that can be used as
-    type annotation for variables representing a Tensor or a value that can be
-    converted to Tensor by `tf.convert_to_tensor`.
-  * Calling ops with a python constants or numpy values is now consistent with
-    tf.convert_to_tensor behavior. This avoids operations like tf.reshape
-    truncating inputs such as from int64 to int32.
-  * Added `tf.sparse.map_values` to apply a function to the `.value`s of `SparseTensror` arguments.
-  * The Python bitwise operators for `Tensor` (`__and__`, `__or__`, `__xor__`
-    and `__invert__` now support non-`bool` arguments and apply the
-    corresponding bitwise ops. `bool` arguments continue to be supported and
-    dispatch to logical ops. This brings them more in line with Python and NumPy
-    benavior.
-  * Added `tf.SparseTensor.with_values`. This returns a new SparseTensor with
-    the same sparsity pattern, but with new provided values. It is similar to
-    the `with_values` function of `RaggedTensor`.
-  * Added `StatelessCase` op, and uses it if none of case branches has stateful ops.
-* `tf.data`:
-  * Added new `tf.data.experimental.service.register_dataset` and
-    `tf.data.experimental.service.from_dataset_id` APIs to enable one process
-    to register a dataset with the tf.data service, and another process to
-    consume data from the dataset.
-  * Added support for tf.data service dispatcher fault tolerance. To enable
-    fault tolerance, configure a `work_dir` when running your dispatcher
-    server and set `dispatcher_fault_tolerance=True`. The dispatcher will
-    store its state to `work_dir`, so that on restart it can continue from its
-    previous state after restart.
-  * Added tf.data service support for sharing dataset graphs via shared
-    filesystem instead of over RPC. This reduces load on the dispatcher,
-    improving performance of distributing datasets. For this to work, the
-    dispatcher's `work_dir` must be accessible from workers. If the worker
-    fails to read from the `work_dir`, it falls back to using RPC for dataset
-    graph transfer.
-  * Added optional `exclude_cols` parameter to CsvDataset. This parameter is
-    the complement of `select_cols`; at most one of these should be specified.
-  * We have implemented an optimization which reorders data-discarding
-    transformations such as `take` and `shard` to happen earlier in the
-    dataset when it is safe to do so. The optimization can be disabled via
-    the `experimental_optimization.reorder_data_discarding_ops` dataset
-    option.
-  * `tf.data.Options` were previously immutable and can now be overriden.
-* `tf.image`:
-  * Added deterministic `tf.image.stateless_random_*` functions for each
-    `tf.image.random_*` function. Added a new op
-    `stateless_sample_distorted_bounding_box` which is a determinstic
-    version of `sample_distorted_bounding_box` op. Given the same seed, these
-    stateless functions/ops produce the same results independent of how many
-    times the function is called, and independent of global seed settings.
+* Security:
+  * Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
+    ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
+  * Fixes three vulnerabilities in conversion to DLPack format
+    ([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
+    [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
+    [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
+  * Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
+    ([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
+    [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
+  * Fixes several vulnerabilities in `RaggedCountSparseOutput` and
+    `SparseCountSparseOutput` operations
+    ([CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196),
+    [CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197),
+    [CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198),
+    [CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199),
+    [CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200),
+    [CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201))
+  * Fixes an integer truncation vulnerability in code using the work sharder API
+    ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
+  * Fixes a format string vulnerability in `tf.strings.as_string`
+    ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
+  * Fixes segfault raised by calling session-only ops in eager mode
+    ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
+  * Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams`
+    ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
+  * Fixes segfaults caused by incomplete `SavedModel` validation
+    ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
+  * Fixes a data corruption due to a bug in negative indexing support in TFLite
+    ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
+  * Fixes a data corruption due to dimension mismatch in TFLite
+    ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
+  * Fixes several vulnerabilities in TFLite saved model format
+    ([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
+    [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
+    [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
+  * Fixes several vulnerabilities in TFLite implementation of segment sum
+    ([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212),
+    [CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
+    [CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
+* TF Core:
+  * `tf.types.experimental.TensorLike` is a new `Union` type that can be
+    used as type annotation for variables representing a Tensor or a value
+    that can be converted to Tensor by `tf.convert_to_tensor`.
+  * Calling ops with python constants or numpy values is now consistent
+    with tf.convert_to_tensor behavior. This avoids operations like
+    tf.reshape truncating inputs such as from int64 to int32.
+  * Added `tf.sparse.map_values` to apply a function to the `.value`s of
+    `SparseTensor` arguments.
+  * The Python bitwise operators for `Tensor` (`__and__`, `__or__`,
+    `__xor__` and `__invert__`) now support non-`bool` arguments and apply
+    the corresponding bitwise ops. `bool` arguments continue to be supported
+    and dispatch to logical ops. This brings them more in line with Python
+    and NumPy behavior.
+  * Added `tf.SparseTensor.with_values`. This returns a new SparseTensor
+    with the same sparsity pattern, but with new provided values. It is
+    similar to the `with_values` function of `RaggedTensor`.
+  * Added `StatelessCase` op, and uses it if none of the case branches has
+    stateful ops.
+  * Added `tf.config.experimental.get_memory_usage` to return total memory
+    usage of the device.
+  * Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`.
+* `tf.data`:
+  * tf.data service:
+    * Added new `tf.data.experimental.service.register_dataset` and
+      `tf.data.experimental.service.from_dataset_id` APIs to enable one
+      process to register a dataset with the tf.data service, and another
+      process to consume data from the dataset.
+    * Added support for dispatcher fault tolerance. To enable fault tolerance,
+      configure a `work_dir` when running your dispatcher server and set
+      `dispatcher_fault_tolerance=True`. The dispatcher will store its state
+      to `work_dir`, so that on restart it can continue from its previous
+      state after restart.
+    * Added support for sharing dataset graphs via shared filesystem instead
+      of over RPC. This reduces load on the dispatcher, improving performance
+      of distributing datasets. For this to work, the dispatcher's `work_dir`
+      must be accessible from workers. If the worker fails to read from the
+      `work_dir`, it falls back to using RPC for dataset graph transfer.
+    * Added support for a new "distributed_epoch" processing mode. This
+      processing mode distributes a dataset across all tf.data workers,
+      instead of having each worker process the full dataset. See
+      [the tf.data service docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#understand_processing_mode)
+      to learn more.
+  * Added optional `exclude_cols` parameter to CsvDataset. This parameter is
+    the complement of `select_cols`; at most one of these should be specified.
+  * We have implemented an optimization which reorders data-discarding
+    transformations such as `take` and `shard` to happen earlier in the
+    dataset when it is safe to do so. The optimization can be disabled via
+    the `experimental_optimization.reorder_data_discarding_ops` dataset
+    option.
+  * `tf.data.Options` were previously immutable and can now be overridden.
+  * `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors
+    with a new `output_signature` argument, which allows `from_generator` to
+    produce any type describable by a `tf.TypeSpec`.
+  * `tf.data.experimental.AUTOTUNE` is now available in the core API as
+    `tf.data.AUTOTUNE`.
+* `tf.image`:
+  * Added deterministic `tf.image.stateless_random_*` functions for each
+    `tf.image.random_*` function. Added a new op
+    `stateless_sample_distorted_bounding_box` which is a deterministic
+    version of `sample_distorted_bounding_box` op. Given the same seed,
+    these stateless functions/ops produce the same results independent of
+    how many times the function is called, and independent of global seed
+    settings.
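
A brief sketch of the `tf.sparse` additions and the stateless image ops noted above (shapes and seed values are arbitrary):

```python
import tensorflow as tf

st = tf.sparse.from_dense([[1.0, 0.0], [0.0, 3.0]])
doubled = tf.sparse.map_values(tf.multiply, st, 2.0)  # values become [2., 6.]
shifted = st.with_values(st.values + 10.0)            # same sparsity pattern

img = tf.zeros([64, 64, 3])
# Same seed => same flip decision, independent of call count or global seeds.
flipped = tf.image.stateless_random_flip_left_right(img, seed=[1, 2])
```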
 * `tf.distribute`:
   * <ADD RELEASE NOTES HERE>
 * `tf.keras`:
   * Improvements from the functional API refactoring:
-    * Functional model construction does not need to maintain a global workspace graph, removing memory leaks especially when building many models or very large models.
-    * Functional model construction should be ~8-10% faster on average.
-    * Functional models can now contain non-symbolic values in their call inputs inside of the first positional argument.
-    * Several classes of TF ops that were not reliably converted to Keras layers during functional API construction should now work, e.g. `tf.image.ssim_multiscale`
-    * Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
+    * Functional model construction does not need to maintain a global
+      workspace graph, removing memory leaks especially when building many
+      models or very large models.
+    * Functional model construction should be ~8-10% faster on average.
+    * Functional models can now contain non-symbolic values in their call
+      inputs inside of the first positional argument.
+    * Several classes of TF ops that were not reliably converted to Keras
+      layers during functional API construction should now work, e.g.
+      `tf.image.ssim_multiscale`
+    * Error messages when Functional API construction goes wrong (and when
+      ops cannot be converted to Keras layers automatically) should be
+      clearer and easier to understand.
   * `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
     as an alternative to accepting a `callable` loss.
-  * Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
-    to match FTRL paper (https://research.google.com/pubs/archive/41159.pdf).
+  * Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
+    to match FTRL paper
+    (https://research.google.com/pubs/archive/41159.pdf).
   * Added `mobilenet_v3` to keras application model.
   * `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
     customization of how gradients are aggregated across devices, as well as
     `gradients_transformers` to allow for custom gradient transformations
     (such as gradient clipping).
+  * The `steps_per_execution` argument in `compile()` is no longer
+    experimental; if you were passing `experimental_steps_per_execution`,
+    rename it to `steps_per_execution` in your code. This argument controls
+    the number of batches to run during each `tf.function` call when calling
+    `fit()`. Running multiple batches inside a single `tf.function` call can
+    greatly improve performance on TPUs or small models with a large Python
+    overhead. A sketch follows this list.
+  * Improvements to Keras preprocessing layers:
+    * TextVectorization can now accept a vocabulary list or file as an
+      init arg.
+    * Normalization can now accept mean and variance values as init args.
+  * In `Attention` and `AdditiveAttention` layers, the `call()` method now
+    accepts a `return_attention_scores` argument. When set to
+    True, the layer returns the attention scores as an additional output
+    argument.
+  * Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints
+    with the same implementation as their `tf.losses` equivalent.
+  * For Keras model, the individual call of `Model.evaluate` uses no cached
+    data for evaluation, while `Model.fit` uses cached data when
+    `validation_data` arg is provided for better performance.
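
To make the `steps_per_execution` note above concrete, a minimal sketch (model and optimizer choices are arbitrary):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
# No longer experimental: run 50 batches per tf.function call to reduce
# Python dispatch overhead (useful on TPUs or for small models).
model.compile(optimizer="sgd", loss="mse", steps_per_execution=50)
```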
 * `tf.function` / AutoGraph:
   * Added `experimental_follow_type_hints` argument for `tf.function`. When
     True, the function may use type annotations to optimize the tracing
     performance.
   * Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
-  * AutoGraph now allows creating new symbols inside a TensorFLow loop, if
-    the values of these symbols at an iteration does not depend on the previous
-    iteration. These types of loops must run at least one iteration, and will
-    raise a runtime error otherwise.
+  * AutoGraph now allows creating new symbols inside a TensorFlow loop, if
+    the values of these symbols at an iteration does not depend on the
+    previous iteration. These types of loops must run at least one
+    iteration, and will raise a runtime error otherwise.
 
     Example:
@@ -137,45 +252,103 @@
     outputs = train_step(batch)
     tf.print('final outputs', outputs)
     ```
 
   See tensorflow/python/autograph/g3doc/reference/limitations.md for more
   info.
 
 * `tf.lite`:
-  * `DynamicBuffer::AddJoinedString()` will now add a separator if the first
-    string to be joined is empty.
-  * `TFLiteConverter`:
-    * Support optional flags `inference_input_type` and `inference_output_type` for full integer quantized models. This allows users to modify the model input and output type to integer types (`tf.int8`, `tf.uint8`) instead of defaulting to float type (`tf.float32`).
-  * Deprecate `Interpreter::UseNNAPI(bool)` C++ API
-    * Prefer using `NnApiDelegate()` and related delegate configuration methods directly.
-  * Add NNAPI Delegation support for requantization use cases by converting the operation into a dequantize-quantize pair.
-  * <ADD RELEASE NOTES HERE>
+  * `TFLiteConverter`:
+    * Support optional flags `inference_input_type` and
+      `inference_output_type` for full integer quantized models. This
+      allows users to modify the model input and output type to integer
+      types (`tf.int8`, `tf.uint8`) instead of defaulting to float type
+      (`tf.float32`).
+  * TFLite Profiler for Android is available. See the detailed
+    [guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android).
+  * NNAPI
+    * Added NNAPI Delegation support for requantization use cases by
+      converting the operation into a dequantize-quantize pair.
+    * Removed deprecated `Interpreter.setUseNNAPI(boolean)` Java API.
+      * Use `Interpreter.Options.setUseNNAPI` instead.
+    * Deprecate `Interpreter::UseNNAPI(bool)` C++ API.
+      * Use `NnApiDelegate()` and related delegate configuration methods
+        directly.
+    * Deprecate `Interpreter::SetAllowFp16PrecisionForFp32(bool)` C++ API
+      * Prefer controlling this via delegate options, e.g.
+        `tflite::StatefulNnApiDelegate::Options::allow_fp16` or
+        `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
+  * `DynamicBuffer::AddJoinedString()` will now add a separator if the first
+    string to be joined is empty.
+  * <ADD RELEASE NOTES HERE>
 * `tf.random`:
   * <ADD RELEASE NOTES HERE>
 * Math and Linear Algebra:
+  * Add `tf.math.erfcinv`, the inverse to `tf.math.erfc`.
   * <ADD RELEASE NOTES HERE>
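
A tiny sketch of the new op noted above (input values are arbitrary points in the domain of `erfc`):

```python
import tensorflow as tf

x = tf.constant([0.5, 1.0, 1.5])
y = tf.math.erfcinv(x)                        # inverse of tf.math.erfc
tf.debugging.assert_near(tf.math.erfc(y), x)  # round-trips to the input
```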
 * TPU Enhancements:
   * Added support for the `beta` parameter of the FTRL optimizer for TPU
     embeddings. Users of other TensorFlow platforms can implement equivalent
     behavior by adjusting the `l2` parameter.
   * <ADD RELEASE NOTES HERE>
 * XLA Support:
   * xla.experimental.compile is deprecated, use
     `tf.function(experimental_compile=True)` instead
+  * Added `tf.function.experimental_get_compiler_ir` which returns compiler
+    IR (currently 'hlo' and 'optimized_hlo') for given input for given
+    function.
   * <ADD RELEASE NOTES HERE>
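
As a sketch of the two XLA notes above (the exact IR text printed will vary by build):

```python
import tensorflow as tf

@tf.function(experimental_compile=True)  # replaces xla.experimental.compile
def f(x):
  return x * x + 1.0

x = tf.constant(2.0)
print(f.experimental_get_compiler_ir(x)(stage="hlo"))  # HLO for this input
```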
 * Tracing and Debugging:
   * <ADD RELEASE NOTES HERE>
 * `tf.train.Checkpoint`:
-  * Now accepts a `root` argument in the initialization, which generates a
-    checkpoint with a root object. This allows users to create a `Checkpoint`
-    object that is compatible with Keras `model.save_weights()` and
-    `model.load_weights`. The checkpoint is also compatible with the
-    checkpoint saved in the `variables/` folder in the SavedModel.
-  * When restoring, `save_path` can be a path to a SavedModel. The function
-    will automatically find the checkpoint in the SavedModel.
+  * Now accepts a `root` argument in the initialization, which generates a
+    checkpoint with a root object. This allows users to create a
+    `Checkpoint` object that is compatible with Keras `model.save_weights()`
+    and `model.load_weights`. The checkpoint is also compatible with the
+    checkpoint saved in the `variables/` folder in the SavedModel.
+  * When restoring, `save_path` can be a path to a SavedModel. The function
+    will automatically find the checkpoint in the SavedModel.
+* `tf.nn`:
+  * `tf.nn.max_pool2d` now supports explicit padding.
+* `tf.debugging`:
+  * `tf.debugging.assert_shapes()` now works on `SparseTensor`s (#36268).
+* `tf.print`:
+  * Bug fix in `tf.print()` with `OrderedDict` where if an `OrderedDict`
+    didn't have the keys sorted, the keys and values were not being printed
+    in accordance with their correct mapping.
+* `TensorRT`
+  * We now issue a warning when the `session_config` parameter for the TF1
+    converter is used or the `rewrite_config_template` field in the TF2
+    converter parameter object is used.
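
A minimal sketch of the `Checkpoint` `root` argument described above (model shape and path are illustrative):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model.build(input_shape=(None, 4))
ckpt = tf.train.Checkpoint(root=model)  # new `root` argument
path = ckpt.save("/tmp/ckpt_demo")      # compatible with model.save_weights()
ckpt.restore(path)                      # save_path may also be a SavedModel dir
```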
 * Other:
-  * We have replaced uses of "whitelist" and "blacklist" with "allowlist"
-    and "denylist" where possible. Please see
-    https://developers.google.com/style/word-list#blacklist for more context.
-  * <ADD RELEASE NOTES HERE>
+  * We have replaced uses of "whitelist" and "blacklist" with "allowlist"
+    and "denylist" where possible. Please see
+    https://developers.google.com/style/word-list#blacklist for more
+    context.
+  * Add `tf.config.experimental.mlir_bridge_rollout` which will help us
+    rollout the new MLIR TPU bridge.
+  * <ADD RELEASE NOTES HERE>
 ## Thanks to our Contributors
 
@@ -183,45 +356,327 @@ This release contains contributions from many people at Google, as well as:
 
 stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
+# Release 2.3.1
+
+## Bug Fixes and Other Changes
+
+* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
+  ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
+* Fixes three vulnerabilities in conversion to DLPack format
+  ([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
+  [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
+  [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
+* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
+  ([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
+  [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
+* Fixes several vulnerabilities in `RaggedCountSparseOutput` and
+  `SparseCountSparseOutput` operations
+  ([CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196),
+  [CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197),
+  [CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198),
+  [CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199),
+  [CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200),
+  [CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201))
+* Fixes an integer truncation vulnerability in code using the work sharder API
+  ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
+* Fixes a format string vulnerability in `tf.strings.as_string`
+  ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
+* Fixes segfault raised by calling session-only ops in eager mode
+  ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
+* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams`
+  ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
+* Fixes segfaults caused by incomplete `SavedModel` validation
+  ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
+* Fixes a data corruption due to a bug in negative indexing support in TFLite
+  ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
+* Fixes a data corruption due to dimension mismatch in TFLite
+  ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
+* Fixes several vulnerabilities in TFLite saved model format
+  ([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
+  [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
+  [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
+* Fixes several vulnerabilities in TFLite implementation of segment sum
+  ([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212),
+  [CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
+  [CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
+* Updates `sqlite3` to `3.33.00` to handle
+  [CVE-2020-15358](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15358).
+* Fixes deprecated usage of `collections` API
+* Removes `scipy` dependency from `setup.py` since TensorFlow does not need it
+  to install the pip package
+
+# Release 2.2.1
+
+## Bug Fixes and Other Changes
+
+* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
+  ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
+* Fixes three vulnerabilities in conversion to DLPack format
+  ([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
+  [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
+  [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
+* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
+  ([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
+  [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
+* Fixes an integer truncation vulnerability in code using the work sharder API
+  ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
+* Fixes a format string vulnerability in `tf.strings.as_string`
+  ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
+* Fixes segfault raised by calling session-only ops in eager mode
+  ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
+* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams`
+  ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
+* Fixes segfaults caused by incomplete `SavedModel` validation
+  ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
+* Fixes a data corruption due to a bug in negative indexing support in TFLite
+  ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
+* Fixes a data corruption due to dimension mismatch in TFLite
+  ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
+* Fixes several vulnerabilities in TFLite saved model format
+  ([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
+  [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
+  [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
+* Fixes several vulnerabilities in TFLite implementation of segment sum
+  ([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212),
+  [CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
+  [CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
* Updates `sqlite3` to `3.33.00` to handle
[CVE-2020-9327](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-9327),
[CVE-2020-11655](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11655),
[CVE-2020-11656](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11656),
[CVE-2020-13434](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13434),
[CVE-2020-13435](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13435),
[CVE-2020-13630](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13630),
[CVE-2020-13631](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13631),
[CVE-2020-13871](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13871),
and
[CVE-2020-15358](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15358).
* Fixes deprecated usage of `collections` API
* Removes `scipy` dependency from `setup.py` since TensorFlow does not need it
to install the pip package
# Release 2.1.2
## Bug Fixes and Other Changes
* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
* Fixes three vulnerabilities in conversion to DLPack format
([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
[CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
[CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
[CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
* Fixes an integer truncation vulnerability in code using the work sharder API
([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
* Fixes a format string vulnerability in `tf.strings.as_string`
([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
* Fixes segfault raised by calling session-only ops in eager mode
([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams`
([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
* Fixes segfaults caused by incomplete `SavedModel` validation
([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
* Fixes a data corruption due to a bug in negative indexing support in TFLite
([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
* Fixes a data corruption due to dimension mismatch in TFLite
([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
* Fixes several vulnerabilities in TFLite saved model format
([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
[CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
[CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
* Updates `sqlite3` to `3.33.00` to handle
[CVE-2020-9327](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-9327),
[CVE-2020-11655](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11655),
[CVE-2020-11656](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11656),
[CVE-2020-13434](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13434),
[CVE-2020-13435](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13435),
[CVE-2020-13630](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13630),
[CVE-2020-13631](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13631),
[CVE-2020-13871](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13871),
and
[CVE-2020-15358](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15358).
* Removes `scipy` dependency from `setup.py` since TensorFlow does not need it
to install the pip package
* Switches ROCm builds to use ROCm 3.7
# Release 2.0.3
## Bug Fixes and Other Changes
* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
* Fixes three vulnerabilities in conversion to DLPack format
([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
[CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
[CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
[CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
* Fixes an integer truncation vulnerability in code using the work sharder API
([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
* Fixes a format string vulnerability in `tf.strings.as_string`
([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
* Fixes segfault raised by calling session-only ops in eager mode
([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams`
([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
* Fixes segfaults caused by incomplete `SavedModel` validation
([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
* Fixes a data corruption due to a bug in negative indexing support in TFLite
([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
* Fixes a data corruption due to dimension mismatch in TFLite
([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
* Fixes several vulnerabilities in TFLite saved model format
([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
[CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
[CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
* Updates `sqlite3` to `3.33.00` to handle
[CVE-2020-9327](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-9327),
[CVE-2020-11655](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11655),
[CVE-2020-11656](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11656),
[CVE-2020-13434](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13434),
[CVE-2020-13435](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13435),
[CVE-2020-13630](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13630),
[CVE-2020-13631](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13631),
[CVE-2020-13871](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13871),
and
[CVE-2020-15358](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15358).
* Pins `numpy` to 1.18.5 to prevent ABI breakage when compiling code that uses
both NumPy and TensorFlow headers.
# Release 1.15.4
## Bug Fixes and Other Changes
* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
* Fixes three vulnerabilities in conversion to DLPack format
([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
[CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
[CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
[CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
* Fixes an integer truncation vulnerability in code using the work sharder API
([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
* Fixes a format string vulnerability in `tf.strings.as_string`
([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
* Fixes segfault raised by calling session-only ops in eager mode
([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams`
([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
* Fixes segfaults caused by incomplete `SavedModel` validation
([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
* Fixes a data corruption due to a bug in negative indexing support in TFLite
([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
* Fixes a data corruption due to dimension mismatch in TFLite
([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
* Fixes several vulnerabilities in TFLite saved model format
([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
[CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
[CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
* Updates `sqlite3` to `3.33.00` to handle
[CVE-2020-9327](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-9327),
[CVE-2020-11655](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11655),
[CVE-2020-11656](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11656),
[CVE-2020-13434](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13434),
[CVE-2020-13435](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13435),
[CVE-2020-13630](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13630),
[CVE-2020-13631](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13631),
[CVE-2020-13871](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13871),
and
[CVE-2020-15358](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15358).
* Fixes #41630 by including `max_seq_length` in CuDNN descriptor cache key
* Pins `numpy` to 1.18.5 to prevent ABI breakage when compiling code that uses
both NumPy and TensorFlow headers.
# Release 2.3.0

## Major Features and Improvements

*   `tf.data` adds two new mechanisms to solve input pipeline bottlenecks and
    save resources:

    *   [snapshot](https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot)
    *   [tf.data service](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service).

    In addition, check out the detailed
    [guide](https://www.tensorflow.org/guide/data_performance_analysis) for
    analyzing input pipeline performance with TF Profiler.
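
A minimal sketch of both mechanisms; either can be used on its own (the snapshot path and dispatcher address are hypothetical):

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100).map(lambda x: x * 2)

# snapshot: persist the pipeline's output so later runs can reuse it.
dataset = dataset.apply(tf.data.experimental.snapshot("/tmp/tfdata_snapshot"))

# tf.data service: offload processing to a pool of remote workers.
dataset = dataset.apply(tf.data.experimental.service.distribute(
    processing_mode="parallel_epochs",
    service="grpc://dispatcher.example:5000"))  # hypothetical dispatcher
```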
*   [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy)
    is now a stable API and no longer considered experimental for TensorFlow
    (earlier `tf.distribute.experimental.TPUStrategy`).
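
A minimal sketch of the stable API (the TPU name is hypothetical):

```python
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="my-tpu")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

strategy = tf.distribute.TPUStrategy(resolver)  # previously under .experimental
with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse")
```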
*   [TF Profiler](https://www.tensorflow.org/guide/profiler) introduces two new
    tools: a memory profiler to visualize your model's memory usage over time
    and a [python tracer](https://www.tensorflow.org/guide/profiler#events)
    which allows you to trace python function calls in your model. Usability
    improvements include better diagnostic messages and
    [profile options](https://tensorflow.org/guide/profiler#collect_performance_data)
    to customize the host and device trace verbosity level.
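
A hedged sketch of collecting a trace with custom verbosity (the logdir is hypothetical, and the `ProfilerOptions` fields shown are an assumption for this release):

```python
import tensorflow as tf

options = tf.profiler.experimental.ProfilerOptions(
    host_tracer_level=2,    # more detailed host-side events
    python_tracer_level=1,  # enable the python tracer
    device_tracer_level=1)
tf.profiler.experimental.start("/tmp/profile_logdir", options=options)
# ... run a few training steps here ...
tf.profiler.experimental.stop()
```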
*   Introduces experimental support for Keras Preprocessing Layers API
    ([`tf.keras.layers.experimental.preprocessing.*`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing?version=nightly))
    to handle data preprocessing operations, with support for composite tensor
    inputs. Please see below for additional details on these layers.
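
For instance, a minimal sketch with the `Normalization` preprocessing layer:

```python
import numpy as np
import tensorflow as tf

data = np.array([[0.1], [0.2], [0.3]], dtype=np.float32)
norm = tf.keras.layers.experimental.preprocessing.Normalization()
norm.adapt(data)   # learn mean and variance from the data
print(norm(data))  # standardized values
```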
*   TFLite now properly supports dynamic shapes during conversion and
    inference. We've also added opt-in support on Android and iOS for
    [XNNPACK](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack),
    a highly optimized set of CPU kernels, as well as opt-in support for
    [executing quantized models on the GPU](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu_advanced.md#running-quantized-models-experimental).
*   Libtensorflow packages are available in GCS starting this release. We have
    also started to
    [release a nightly version of these packages](https://github.com/tensorflow/tensorflow#official-builds).
*   The experimental Python API
    [`tf.debugging.experimental.enable_dump_debug_info()`](https://www.tensorflow.org/api_docs/python/tf/debugging/experimental/enable_dump_debug_info)
    now allows you to instrument a TensorFlow program and dump debugging
    information to a directory on the file system. The directory can be read
    and visualized by a new interactive dashboard in TensorBoard 2.3 called
    [Debugger V2](https://www.tensorflow.org/tensorboard/debugger_v2), which
    reveals the details of the TensorFlow program including graph structures,
    history of op executions at the Python (eager) and intra-graph levels, the
    runtime dtype, shape, and numerical composition of tensors, as well as
    their code locations.
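
A minimal sketch (the dump directory is hypothetical); point TensorBoard's Debugger V2 plugin at the same directory to inspect the dump:

```python
import tensorflow as tf

tf.debugging.experimental.enable_dump_debug_info(
    "/tmp/tfdbg2_logdir",            # hypothetical dump directory
    tensor_debug_mode="FULL_HEALTH",
    circular_buffer_size=-1)         # -1: keep all events

# Everything executed after this call is instrumented.
x = tf.constant([1.0, 2.0])
print(x * 2.0)
```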
## Breaking Changes

*   Increases the **minimum bazel version** required to build TF to **3.1.0**.
*   `tf.data`
    *   Makes the following (breaking) changes to `tf.data`.
    *   C++ API: `IteratorBase::RestoreInternal`,
        `IteratorBase::SaveInternal`, and `DatasetBase::CheckExternalState`
        become pure-virtual and subclasses are now expected to provide an
        implementation.
    *   The deprecated `DatasetBase::IsStateful` method is removed in favor
        of `DatasetBase::CheckExternalState`.
    *   Deprecated overrides of `DatasetBase::MakeIterator` and
        `MakeIteratorFromInputElement` are removed.
    *   The signature of `tensorflow::data::IteratorBase::SaveInternal` and
        `tensorflow::data::IteratorBase::SaveInput` has been extended with a
        `SerializationContext` argument to enable overriding the default
        policy for handling external state during iterator checkpointing.
        This is not a backwards compatible change and all subclasses of
        `IteratorBase` *need to be updated* accordingly.
*   `tf.keras`
    *   Adds a new `BackupAndRestore` callback for handling distributed
        training failures & restarts. Please take a look at this
        [tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
        for details on how to use the callback.
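
A minimal sketch of the callback (the backup directory is hypothetical; in this release the callback lives under `tf.keras.callbacks.experimental`):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

# If a worker is preempted and restarted, fit() resumes from the last
# completed epoch using the state saved in backup_dir.
backup = tf.keras.callbacks.experimental.BackupAndRestore(
    backup_dir="/tmp/train_backup")  # hypothetical directory
model.fit(tf.zeros([8, 1]), tf.zeros([8, 1]), epochs=3, callbacks=[backup])
```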
*   `tf.image.extract_glimpse` has been updated to correctly process the case
    where `centered=False` and `normalized=False`. This is a breaking change as
    the output is different from (incorrect) previous versions. Note this
    breaking change only impacts `tf.image.extract_glimpse` and
    `tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
    `tf.compat.v1.image.extract_glimpse` does not change. The behavior of
    existing C++ kernel `ExtractGlimpse` does not change either, so saved
    models using `tf.raw_ops.ExtractGlimpse` will not be impacted.
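
A minimal sketch of the corrected behavior (values are arbitrary): with `centered=False` and `normalized=False`, `offsets` are treated as absolute pixel coordinates of the glimpse center.

```python
import tensorflow as tf

x = tf.reshape(tf.range(25, dtype=tf.float32), [1, 5, 5, 1])
# Extract a 3x3 glimpse centered at pixel (2, 2) of the 5x5 image.
g = tf.image.extract_glimpse(
    x, size=[3, 3], offsets=[[2.0, 2.0]],
    centered=False, normalized=False)
print(g[0, :, :, 0])
```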
## Known Caveats

*   `tf.lite`

@@ -791,7 +1246,7 @@ This release contains contributions from many people at Google, as well as:

8bitmp3, Aaron Ma, AbdüLhamit Yilmaz, Abhai Kollara, aflc, Ag Ramesh, Albert Z. Guo, Alex Torres, amoitra, Andrii Prymostka, angeliand, Anshuman Tripathy, Anthony Barbier, Anton Kachatkou, Anubh-V, Anuja Jakhade, Artem Ryabov, autoih, Bairen Yi, Bas Aarts, Basit Ayantunde, Ben Barsdell, Bhavani Subramanian, Brett Koonce, candy.dc, Captain-Pool, caster, cathy, Chong Yan, Choong Yin Thong, Clayne Robison, Colle, Dan Ganea, David Norman, David Refaeli, dengziming, Diego Caballero, Divyanshu, djshen, Douman, Duncan Riach, EFanZh, Elena Zhelezina, Eric Schweitz, Evgenii Zheltonozhskii, Fei Hu, fo40225, Fred Reiss, Frederic Bastien, Fredrik Knutsson, fsx950223, fwcore, George Grzegorz Pawelczak, George Sterpu, Gian Marco Iodice, Giorgio Arena, giuros01, Gomathi Ramamurthy, Guozhong Zhuang, Haifeng Jin, Haoyu Wu, HarikrishnanBalagopal, HJYOO, Huang Chen-Yi, Ilham Firdausi Putra, Imran Salam, Jared Nielsen, Jason Zaman, Jasper Vicenti, Jeff Daily, Jeff Poznanovic, Jens Elofsson, Jerry Shih, jerryyin, Jesper Dramsch, jim.meyer, Jongwon Lee, Jun Wan, Junyuan Xie, Kaixi Hou, kamalkraj, Kan Chen, Karthik Muthuraman, Keiji Ariyama, Kevin Rose, Kevin Wang, Koan-Sin Tan, kstuedem, Kwabena W. Agyeman, Lakshay Tokas, latyas, Leslie-Fang-Intel, Li, Guizi, Luciano Resende, Lukas Folle, Lukas Geiger, Mahmoud Abuzaina, Manuel Freiberger, Mark Ryan, Martin Mlostek, Masaki Kozuki, Matthew Bentham, Matthew Denton, mbhuiyan, mdfaijul, Muhwan Kim, Nagy Mostafa, nammbash, Nathan Luehr, Nathan Wells, Niranjan Hasabnis, Oleksii Volkovskyi, Olivier Moindrot, olramde, Ouyang Jin, OverLordGoldDragon, Pallavi G, Paul Andrey, Paul Wais, pkanwar23, Pooya Davoodi, Prabindh Sundareson, Rajeshwar Reddy T, Ralovich, Kristof, Refraction-Ray, Richard Barnes, richardbrks, Robert Herbig, Romeo Kienzler, Ryan Mccormick, saishruthi, Saket Khandelwal, Sami Kama, Sana Damani, Satoshi Tanaka, Sergey Mironov, Sergii Khomenko, Shahid, Shawn Presser, ShengYang1, Siddhartha Bagaria, Simon Plovyt, skeydan, srinivasan.narayanamoorthy, Stephen Mugisha, sunway513, Takeshi Watanabe, Taylor Jakobson, TengLu, TheMindVirus, ThisIsIsaac, Tim Gates, Timothy Liu, Tomer Gafner, Trent Lo, Trevor Hickey, Trevor Morris, vcarpani, Wei Wang, Wen-Heng (Jack) Chung, wenshuai, Wenshuai-Xiaomi, wenxizhu, william, William D. Irons, Xinan Jiang, Yannic, Yasir Modak, Yasuhiro Matsumoto, Yong Tang, Yongfeng Gu, Youwei Song, Zaccharie Ramzi, Zhang, Zhenyu Guo, 王振华 (Zhenhua Wang), 韩董, 이중건 Isaac Lee
# Release 1.15.0

This is the last 1.x release for TensorFlow. We do not expect to update the 1.x branch with features, although we will issue patch releases to fix vulnerabilities for at least one year.

## Major Features and Improvements

* As [announced](https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0), `tensorflow` pip package will by default include GPU support (same as `tensorflow-gpu` now) for the platforms we currently have GPU support (Linux and Windows). It will work on machines with and without Nvidia GPUs. `tensorflow-gpu` will still be available, and CPU-only packages can be downloaded at `tensorflow-cpu` for users who are concerned about package size.
@@ -801,7 +1256,7 @@ This enables writing forward compatible code: by explicitly importing either `te
* Add toggles `tf.enable_control_flow_v2()` and `tf.disable_control_flow_v2()` for enabling/disabling v2 control flow.
* Enable v2 control flow as part of `tf.enable_v2_behavior()` and `TF2_BEHAVIOR=1`.
* AutoGraph translates Python control flow into TensorFlow expressions, allowing users to write regular Python inside `tf.function`-decorated functions. AutoGraph is also applied in functions used with `tf.data`, `tf.distribute` and `tf.keras` APIs.
* Adds `enable_tensor_equality()`, which switches the behavior such that:
  * Tensors are no longer hashable.
  * Tensors can be compared with `==` and `!=`, yielding a Boolean Tensor with element-wise comparison results. This will be the default behavior in 2.0.
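
A minimal sketch of the opted-in behavior (assuming the toggle is exposed as `tf.enable_tensor_equality` in 1.15):

```python
import tensorflow as tf

tf.enable_tensor_equality()
a = tf.constant([1, 2, 3])
b = tf.constant([1, 0, 3])
print(a == b)  # element-wise comparison: [True, False, True]
```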
@@ -957,12 +1412,12 @@ For information on upgrading your existing TensorFlow 1.x models, please refer t
* TensorFlow 2.0.0 is built using devtoolset7 (GCC7) on Ubuntu 16. This may lead to ABI incompatibilities with extensions built against earlier versions of TensorFlow.
* Tensorflow code now produces 2 different pip packages: tensorflow_core containing all the code (in the future it will contain only the private implementation) and tensorflow which is a virtual pip package doing forwarding to tensorflow_core (and in the future will contain only the public API of tensorflow). We don't expect this to be breaking, unless you were importing directly from the implementation.
* Removed the `freeze_graph` command line tool; `SavedModel` should be used in place of frozen graphs.
* `tf.contrib`:
  * `tf.contrib` has been deprecated, and functionality has been either migrated to the core TensorFlow API, to an ecosystem project such as [tensorflow/addons](https://www.github.com/tensorflow/addons) or [tensorflow/io](https://www.github.com/tensorflow/io), or removed entirely.
  * Remove `tf.contrib.timeseries` dependency on TF distributions.
  * Replace contrib references with `tf.estimator.experimental.*` for apis in `early_stopping.py`.
* `tf.estimator`:
  * Premade estimators in the tf.estimator.DNN/Linear/DNNLinearCombined family have been updated to use `tf.keras.optimizers` instead of the `tf.compat.v1.train.Optimizer`s. If you do not pass in an `optimizer=` arg or if you use a string, the premade estimator will use the Keras optimizer. This is checkpoint breaking, as the optimizers have separate variables. A checkpoint converter tool for converting optimizers is included with the release, but if you want to avoid any change, switch to the v1 version of the estimator: `tf.compat.v1.estimator.DNN/Linear/DNNLinearCombined*`.
  * Default aggregation for canned Estimators is now `SUM_OVER_BATCH_SIZE`. To maintain previous default behavior, please pass `SUM` as the loss aggregation method.
@@ -970,13 +1425,13 @@ For information on upgrading your existing TensorFlow 1.x models, please refer t
  * `Estimator.export_savedmodel` has been renamed to `export_saved_model`.
  * When saving to SavedModel, Estimators will strip default op attributes. This is almost always the correct behavior, as it is more forwards compatible, but if you require that default attributes to be saved with the model, please use `tf.compat.v1.Estimator`.
  * Feature Columns have been upgraded to be more Eager-friendly and to work with Keras. As a result, `tf.feature_column.input_layer` has been deprecated in favor of `tf.keras.layers.DenseFeatures`. v1 feature columns have direct analogues in v2 except for `shared_embedding_columns`, which are not cross-compatible with v1 and v2. Use `tf.feature_column.shared_embeddings` instead.
* `tf.keras`:
  * `OMP_NUM_THREADS` is no longer used by the default Keras config. To configure the number of threads, use `tf.config.threading` APIs.
  * `tf.keras.model.save_model` and `model.save` now default to saving a TensorFlow SavedModel. HDF5 files are still supported.
  * Deprecated `tf.keras.experimental.export_saved_model` and `tf.keras.experimental.function`. Please use `tf.keras.models.save_model(..., save_format='tf')` and `tf.keras.models.load_model` instead.
  * Layers now default to float32, and automatically cast their inputs to the layer's dtype. If you had a model that used float64, it will probably silently use float32 in TensorFlow 2, and a warning will be issued that starts with `Layer <layer-name>` is casting an input tensor from dtype float64 to the layer's dtype of float32. To fix, either set the default dtype to float64 with `tf.keras.backend.set_floatx('float64')`, or pass `dtype='float64'` to each of the Layer constructors. See `tf.keras.layers.Layer` for more information.
* `tf.lite`:
  * Removed `lite.OpHint`, `lite.experimental`, and `lite.constant` from 2.0 API.
* Tensors are no longer hashable, but instead compare element-wise with `==` and `!=`. Use `tf.compat.v1.disable_tensor_equality()` to return to the previous behavior.
@@ -1211,8 +1666,8 @@ If you experience any snags when using TF 2.0, please let us know at the [TF 2.0
    conversion. TensorRT initialization arguments are now passed wrapped in a
    named-tuple, `TrtConversionParams`, rather than as separate arguments as in
    `TrtGraphConverter`.
  * Changed API to optimize TensorRT engines during graph optimization. This is
    now done by calling `converter.build()` where previously
    `is_dynamic_op=False` would be set.
  * `converter.convert()` no longer returns a `tf.function`. Now the function
    must be accessed from the saved model.
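
A hedged sketch of the new conversion flow (paths are hypothetical, and the import path reflects where the TF2 converter lived at the time):

```python
import numpy as np
from tensorflow.python.compiler.tensorrt import trt_convert as trt

def input_fn():
    # Representative input shapes used to build the TensorRT engines.
    yield (np.zeros([1, 224, 224, 3], dtype=np.float32),)

params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(precision_mode="FP16")
converter = trt.TrtGraphConverterV2(
    input_saved_model_dir="/tmp/saved_model",  # hypothetical path
    conversion_params=params)
converter.convert()                  # no longer returns a tf.function
converter.build(input_fn=input_fn)   # replaces setting is_dynamic_op=False
converter.save("/tmp/trt_saved_model")
```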
@@ -2222,7 +2677,7 @@ Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, A
* [`tf.contrib.estimator.RNNEstimator`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/estimator/RNNClassifier) * [`tf.contrib.estimator.RNNEstimator`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/estimator/RNNClassifier)
* The [distributions.Bijector](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/distributions/bijectors/Bijector) * The [distributions.Bijector](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/distributions/bijectors/Bijector)
  API supports broadcasting for Bijectors with new API changes.

## Breaking Changes

* If you're opening empty variable scopes, replace `variable_scope('', ...)` by
  `variable_scope(tf.get_variable_scope(), ...)`.
@@ -2701,7 +3156,7 @@ Samuel He, Sandeep Dcunha, sandipmgiri, Sang Han, scott, Scott Mudge, Se-Won Kim
Simone Cirillo, Steffen Schmitz, Suvojit Manna, Sylvus, Taehoon Lee, Ted Chang, Thomas Deegan,
Till Hoffmann, Tim, Toni Kunic, Toon Verstraelen, Tristan Rice, Urs KöSter, Utkarsh Upadhyay,
Vish (Ishaya) Abrams, Winnie Tsang, Yan Chen, Yan Facai (颜发才), Yi Yang, Yong Tang,
Youssef Hesham, Yuan (Terry) Tang, Zhengsheng Wei, zxcqwe4906, 张志豪, 田传武

We are also grateful to all who filed issues or helped resolve them, asked and
answered questions, and were part of inspiring discussions.
@@ -38,9 +38,6 @@ _DEFAULT_CUDNN_VERSION = '7'
_DEFAULT_TENSORRT_VERSION = '6'
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0'
_TF_OPENCL_VERSION = '1.2'
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16, 17, 18]
_DEFAULT_PROMPT_ASK_ATTEMPTS = 10

@@ -1114,62 +1111,6 @@ def set_host_c_compiler(environ_cp):
  write_action_env_to_bazelrc('HOST_C_COMPILER', host_c_compiler)
def set_computecpp_toolkit_path(environ_cp):
"""Set COMPUTECPP_TOOLKIT_PATH."""
def toolkit_exists(toolkit_path):
"""Check if a computecpp toolkit path is valid."""
if is_linux():
sycl_rt_lib_path = 'lib/libComputeCpp.so'
else:
sycl_rt_lib_path = ''
sycl_rt_lib_path_full = os.path.join(toolkit_path, sycl_rt_lib_path)
exists = os.path.exists(sycl_rt_lib_path_full)
if not exists:
print('Invalid SYCL %s library path. %s cannot be found' %
(_TF_OPENCL_VERSION, sycl_rt_lib_path_full))
return exists
computecpp_toolkit_path = prompt_loop_or_load_from_env(
environ_cp,
var_name='COMPUTECPP_TOOLKIT_PATH',
var_default=_DEFAULT_COMPUTECPP_TOOLKIT_PATH,
ask_for_var=(
'Please specify the location where ComputeCpp for SYCL %s is '
'installed.' % _TF_OPENCL_VERSION),
check_success=toolkit_exists,
error_msg='Invalid SYCL compiler path. %s cannot be found.',
suppress_default_error=True)
write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH',
computecpp_toolkit_path)
def set_trisycl_include_dir(environ_cp):
"""Set TRISYCL_INCLUDE_DIR."""
ask_trisycl_include_dir = ('Please specify the location of the triSYCL '
'include directory. (Use --config=sycl_trisycl '
'when building with Bazel) '
'[Default is %s]: ') % (
_DEFAULT_TRISYCL_INCLUDE_DIR)
while True:
trisycl_include_dir = get_from_env_or_user_or_default(
environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir,
_DEFAULT_TRISYCL_INCLUDE_DIR)
if os.path.exists(trisycl_include_dir):
break
print('Invalid triSYCL include directory, %s cannot be found' %
(trisycl_include_dir))
# Set TRISYCL_INCLUDE_DIR
environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir
write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)
def system_specific_test_config(environ_cp):
  """Add default build and test flags required for TF tests to bazelrc."""
  write_to_bazelrc('test --flaky_test_attempts=3')

@@ -1397,8 +1338,6 @@ def main():
  setup_python(environ_cp)

  if is_windows():
environ_cp['TF_NEED_OPENCL_SYCL'] = '0'
environ_cp['TF_NEED_COMPUTECPP'] = '0'
    environ_cp['TF_NEED_OPENCL'] = '0'
    environ_cp['TF_CUDA_CLANG'] = '0'
    environ_cp['TF_NEED_TENSORRT'] = '0'

@@ -1415,21 +1354,6 @@ def main():
  if environ_cp.get('TF_ENABLE_XLA', '1') == '1':
    write_to_bazelrc('build --config=xla')
set_action_env_var(
environ_cp,
'TF_NEED_OPENCL_SYCL',
'OpenCL SYCL',
False,
bazel_config_name='sycl')
if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1':
set_host_cxx_compiler(environ_cp)
set_host_c_compiler(environ_cp)
set_action_env_var(environ_cp, 'TF_NEED_COMPUTECPP', 'ComputeCPP', True)
if environ_cp.get('TF_NEED_COMPUTECPP') == '1':
set_computecpp_toolkit_path(environ_cp)
else:
set_trisycl_include_dir(environ_cp)
  set_action_env_var(
      environ_cp, 'TF_NEED_ROCM', 'ROCm', False, bazel_config_name='rocm')
  if (environ_cp.get('TF_NEED_ROCM') == '1' and

@@ -1442,6 +1366,11 @@ def main():
    write_action_env_to_bazelrc('ROCM_PATH', environ_cp.get('ROCM_PATH'))
    write_action_env_to_bazelrc('ROCM_ROOT', environ_cp.get('ROCM_PATH'))
if ((environ_cp.get('TF_NEED_ROCM') == '1') and
(environ_cp.get('TF_ENABLE_MLIR_GENERATED_GPU_KERNELS') == '1')):
write_to_bazelrc(
'build:rocm --define tensorflow_enable_mlir_generated_gpu_kernels=1')
  environ_cp['TF_NEED_CUDA'] = str(
      int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)))
  if (environ_cp.get('TF_NEED_CUDA') == '1' and

@@ -1523,17 +1452,15 @@ def main():
  # use it for the CPU build.
  set_tf_download_clang(environ_cp)

  # ROCm / CUDA are mutually exclusive.
  # At most 1 GPU platform can be configured.
  gpu_platform_count = 0
if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1':
gpu_platform_count += 1
  if environ_cp.get('TF_NEED_ROCM') == '1':
    gpu_platform_count += 1
  if environ_cp.get('TF_NEED_CUDA') == '1':
    gpu_platform_count += 1
  if gpu_platform_count >= 2:
    raise UserInputError('CUDA / ROCm are mutually exclusive. '
                         'At most 1 GPU platform can be configured.')

  set_cc_opt_flags(environ_cp)
@@ -1558,6 +1485,7 @@ def main():
      'adding "--config=<>" to your build command. See .bazelrc for more '
      'details.')
  config_info_line('mkl', 'Build with MKL support.')
  config_info_line('mkl_aarch64', 'Build with oneDNN support for Aarch64.')
  config_info_line('monolithic', 'Config for mostly static monolithic build.')
  config_info_line('ngraph', 'Build with Intel nGraph support.')
  config_info_line('numa', 'Build with NUMA support.')
@@ -497,13 +497,20 @@ config_setting(
    visibility = ["//visibility:public"],
)

# This flag forcibly enables experimental MLIR bridge support.
config_setting(
    name = "enable_mlir_bridge",
    values = {"define": "enable_mlir_bridge=true"},
    visibility = ["//visibility:public"],
)
# This flag forcibly disables experimental MLIR bridge support.
config_setting(
name = "disable_mlir_bridge",
values = {"define": "enable_mlir_bridge=false"},
visibility = ["//visibility:public"],
)
# This flag enables experimental TPU support
config_setting(
    name = "with_tpu_support",

@@ -562,33 +569,17 @@ selects.config_setting_group(
package_group(
    name = "internal",
    packages = [
        "//learning/lib/ami/simple_ml/...",
        "//perftools/accelerators/xprof/api/...",
        "//tensorflow/...",
"//tensorflow_estimator/python/estimator/...",
"//tensorflow_models/official/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
"//third_party/swift/tensorflow_apis/...",
    ],
)

package_group(name = "ndarray_tensor_allow_list")
name = "ndarray_tensor_allow_list",
packages = ["//learning/pathways/..."],
)
# Packages that use composite tensors or dispatch.
# TODO(b/154762408) Remove this package group once it's no longer needed.
# If this is modified, then copy.bara.sky must also be modified.
package_group(name = "composite_tensor_whitelist")
# Packages that use private types symbols, until they are exported.
# TODO(b/154650521) Remove.
# If this is modified, then copy.bara.sky must also be modified.
package_group(name = "types_whitelist")
packages = ["//learning/deepmind/tensorflow/replicator/..."],
)
# Packages that use StructuredTensors.
# TODO(b/159007891) Remove this package once StructuredTensor is exported.

@@ -714,8 +705,12 @@ tf_cc_shared_object(
    soversion = VERSION,
    visibility = ["//visibility:public"],
    deps = [
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
"//tensorflow/c:kernels_hdrs",
"//tensorflow/c:ops_hdrs",
"//tensorflow/cc/saved_model:loader_lite_impl", "//tensorflow/cc/saved_model:loader_lite_impl",
"//tensorflow/core:core_cpu_impl", "//tensorflow/core/common_runtime:core_cpu_impl",
"//tensorflow/core:framework_internal_impl", "//tensorflow/core:framework_internal_impl",
"//tensorflow/core/common_runtime/gpu:gpu_runtime_impl", "//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
@@ -138,12 +138,12 @@ if _running_from_pip_package():
  for _s in _site_packages_dirs:
    # Load first party dynamic kernels.
    _main_dir = _os.path.join(_s, 'tensorflow/core/kernels')
    if _os.path.exists(_main_dir):
      _ll.load_library(_main_dir)

    # Load third party dynamic kernels.
    _plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
    if _os.path.exists(_plugin_dir):
      _ll.load_library(_plugin_dir)

# Add module aliases
@@ -148,12 +148,12 @@ if _running_from_pip_package():
  for _s in _site_packages_dirs:
    # Load first party dynamic kernels.
    _main_dir = _os.path.join(_s, 'tensorflow/core/kernels')
    if _os.path.exists(_main_dir):
      _ll.load_library(_main_dir)

    # Load third party dynamic kernels.
    _plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
    if _os.path.exists(_plugin_dir):
      _ll.load_library(_plugin_dir)

# Delete modules that should be hidden from dir().
@@ -1,6 +1,7 @@
# Description:
# C API for TensorFlow, for use by client language bindings.

load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_test",

@@ -9,6 +10,11 @@ load(
    "tf_custom_op_library",
    "tf_kernel_library",
)
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "filegroup")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
package( package(
@@ -211,6 +217,8 @@ tf_cuda_library(
        "//tensorflow/core:lib_internal",
        "//tensorflow/core/distributed_runtime:server_lib",
        "//tensorflow/core/kernels:logging_ops",
"//tensorflow/compiler/mlir/tfr:node_expansion_pass",
"//tensorflow/compiler/mlir/tfr:graph_decompose_pass",
    ],
    }),
    alwayslink = 1,

@@ -248,6 +256,30 @@ tf_cuda_library(
    }),
)
cc_library(
name = "tf_shape",
srcs = ["tf_shape.cc"],
hdrs = ["tf_shape.h"],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":c_api_macros",
":tf_shape_internal",
"//tensorflow/core:framework",
],
)
cc_library(
name = "tf_shape_internal",
hdrs = ["tf_shape_internal.h"],
copts = tf_copts(),
visibility = ["//tensorflow:internal"],
deps = [
":conversion_macros",
"//tensorflow/core:framework",
],
)
cc_library(
    name = "tf_status",
    srcs = ["tf_status.cc"],

@@ -377,6 +409,7 @@ tf_cuda_library(
        "//tensorflow/c/eager:tfe_op_internal",
        "//tensorflow/c/eager:tfe_tensorhandle_internal",
        "//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:get_compiler_ir",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
@ -387,6 +420,7 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform", "//tensorflow/core/platform",
"//tensorflow/core/platform:blocking_counter",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
alwayslink = 1, alwayslink = 1,
@ -477,6 +511,18 @@ tf_cuda_library(
], ],
) )
cc_library(
name = "kernels_hdrs",
hdrs = ["kernels.h"],
visibility = ["//tensorflow:internal"],
deps = [
":c_api_internal",
":tf_datatype",
":tf_status",
":tf_tensor",
],
)
tf_cuda_library(
    name = "kernels",
    srcs = [

@@ -530,6 +576,16 @@ tf_cuda_library(
    alwayslink = 1,
)
cc_library(
name = "ops_hdrs",
hdrs = ["ops.h"],
visibility = ["//tensorflow:internal"],
deps = [
":tf_datatype",
":tf_status",
],
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Tests # Tests
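Since the new tf_shape target is public, a usage sketch may help; the accessor names below (TF_NewShape, TF_ShapeDims, TF_ShapeDimSize, TF_DeleteShape) are assumptions about what tf_shape.h declares, as the header body is not part of this diff:

#include "tensorflow/c/tf_shape.h"  // from the new :tf_shape target

void InspectShape() {
  TF_Shape* shape = TF_NewShape();  // assumed: starts with unknown rank
  if (TF_ShapeDims(shape) == -1) {
    // Unknown rank: per-dimension queries via TF_ShapeDimSize(shape, i)
    // are not meaningful yet.
  }
  TF_DeleteShape(shape);
}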


@@ -2488,6 +2488,48 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
   return ret;
 }
 
+void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
+                   TF_Status* status) {
+  using tensorflow::RecordMutation;
+  mutex_lock l(graph->mu);
+  tensorflow::shape_inference::InferenceContext* ic =
+      graph->refiner.GetContext(&new_src.oper->node);
+
+  if (ic->num_outputs() <= new_src.index) {
+    status->status = tensorflow::errors::OutOfRange(
+        "Cannot update edge. Output index [", new_src.index,
+        "] is greater than the number of total outputs [", ic->num_outputs(),
+        "].");
+    return;
+  }
+  tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
+
+  tensorflow::shape_inference::InferenceContext* ic_dst =
+      graph->refiner.GetContext(&dst.oper->node);
+  if (ic_dst->num_inputs() <= dst.index) {
+    status->status = tensorflow::errors::OutOfRange(
+        "Cannot update edge. Input index [", dst.index,
+        "] is greater than the number of total inputs [", ic_dst->num_inputs(),
+        "].");
+    return;
+  }
+  if (!ic_dst->MergeInput(dst.index, shape)) {
+    status->status = tensorflow::errors::InvalidArgument(
+        "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
+        " and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
+    return;
+  }
+  status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
+                                           &dst.oper->node, dst.index);
+
+  if (TF_GetCode(status) == TF_OK) {
+    // This modification only updates the destination node for
+    // the purposes of running this graph in a session. Thus, we don't
+    // record the source node as being modified.
+    RecordMutation(graph, *dst.oper, "updating input tensor");
+  }
+}
+
 // TF_Server functions ----------------------------------------------
 
 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
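For orientation, a minimal usage sketch of the new entry point; it mirrors the UpdateEdge test added later in this change and borrows the ScalarConst/Add/Neg helpers from tensorflow/c/c_test_util.h:

#include <cstdio>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_test_util.h"

void RewireNegInput() {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();
  TF_Operation* one = ScalarConst(1, graph, s, "one");
  TF_Operation* two = ScalarConst(2, graph, s, "two");
  TF_Operation* add = Add(one, two, graph, s, "add");
  TF_Operation* neg = Neg(add, graph, s, "neg");  // neg initially reads "add".
  // Reconnect neg's input 0 so it reads output 0 of "one" instead.
  TF_UpdateEdge(graph, TF_Output{one, 0}, TF_Input{neg, 0}, s);
  if (TF_GetCode(s) != TF_OK) {
    std::fprintf(stderr, "TF_UpdateEdge: %s\n", TF_Message(s));
  }
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}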


@@ -1524,6 +1524,10 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status);
 TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp(
     const char* name, TF_Status* status);
 
+// Updates `dst` to read from output `new_src` instead of its current input.
+TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src,
+                                         TF_Input dst, TF_Status* status);
+
 // --------------------------------------------------------------------------
 // In-process TensorFlow server functionality, for use in distributed training.
 // A Server instance encapsulates a set of devices and a Session target that


@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/platform/blocking_counter.h"
 #include "tensorflow/core/platform/casts.h"
 #include "tensorflow/core/platform/init_main.h"
 #include "tensorflow/core/platform/net.h"
@@ -560,6 +561,21 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
   collective_executor_handle->get()->StartAbort(status->status);
 }
 
+TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
+                                                            const char* task,
+                                                            TF_Status* status) {
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
+  auto collective_executor_handle = context->GetCollectiveExecutorHandle();
+  tensorflow::Notification done;
+  collective_executor_handle->get()->remote_access()->CheckPeerHealth(
+      task, [&done, status](const Status& s) {
+        status->status = s;
+        done.Notify();
+      });
+  done.WaitForNotification();
+}
+
 TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
   TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
   result->num_items = num_items;
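A hedged usage sketch for the new health check: probe each peer before launching multi-worker collectives and fail fast if one is unreachable (task names follow the usual "/job:worker/replica:0/task:N" form; whether to abort and rebuild the context afterwards is an application choice):

#include "tensorflow/c/eager/c_api_experimental.h"

void ProbePeers(TFE_Context* ctx, const char** tasks, int num_tasks) {
  TF_Status* status = TF_NewStatus();
  for (int i = 0; i < num_tasks; ++i) {
    TFE_CollectiveOpsCheckPeerHealth(ctx, tasks[i], status);
    if (TF_GetCode(status) != TF_OK) {
      // Peer is down: abort pending collectives so they error out instead of
      // hanging; per the header comment, recover by creating a new context.
      TFE_AbortCollectiveOps(ctx, status);
      break;
    }
  }
  TF_DeleteStatus(status);
}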


@@ -231,13 +231,20 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
                                                    TF_Status* status);
 
 // Aborts all ongoing collectives with the specified status. After abortion,
-// subsequent collectives will error with this status immediately.
+// subsequent collectives will error with this status immediately. To reset the
+// collectives, create a new EagerContext.
 //
-// This is intended to be used when a peer failure is detected. There's yet no
-// way to reset the collectives other than restarting the program.
+// This is intended to be used when a peer failure is detected.
 TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
                                                   TF_Status* status);
 
+// Checks the health of collective ops peers. An explicit health check is
+// needed in multi-worker collective ops to detect failures in the cluster. If
+// a peer is down, collective ops may hang.
+TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
+                                                            const char* task,
+                                                            TF_Status* status);
+
 // Information about the shape of a Tensor and its type.
 struct TF_ShapeAndType {
   // Number of dimensions. -1 indicates unknown rank.

@@ -1704,66 +1704,5 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) {
   TF_DeleteFunction(func1);
 }
 
-// This test only works when the TF build includes XLA compiler. One way to set
-// this up is via bazel build option "--define with_xla_support=true".
-//
-// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to
-// something like TENSORFLOW_CAPI_USE_XLA.
-#ifdef TENSORFLOW_EAGER_USE_XLA
-TEST_F(CApiFunctionTest, StatelessIf_XLA) {
-  TF_Function* func;
-  const std::string funcName = "BranchFunc";
-  DefineFunction(funcName.c_str(), &func);
-  TF_GraphCopyFunction(host_graph_, func, nullptr, s_);
-  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
-
-  TF_Operation* feed = Placeholder(host_graph_, s_);
-  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
-
-  TF_Operation* true_cond = ScalarConst(true, host_graph_, s_);
-  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
-
-  TF_OperationDescription* desc =
-      TF_NewOperation(host_graph_, "StatelessIf", "IfNode");
-  TF_AddInput(desc, {true_cond, 0});
-  TF_Output inputs[] = {{feed, 0}};
-  TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs));
-  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
-  TF_SetAttrType(desc, "Tcond", TF_BOOL);
-  TF_DataType inputType = TF_INT32;
-  TF_SetAttrTypeList(desc, "Tin", &inputType, 1);
-  TF_SetAttrTypeList(desc, "Tout", &inputType, 1);
-  TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size());
-  TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size());
-  TF_SetDevice(desc, "/device:XLA_CPU:0");
-  auto op = TF_FinishOperation(desc, s_);
-  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
-  ASSERT_NE(op, nullptr);
-
-  // Create a session for this graph.
-  CSession csession(host_graph_, s_, /*use_XLA*/ true);
-  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
-
-  // Run the graph.
-  csession.SetInputs({{feed, Int32Tensor(17)}});
-  csession.SetOutputs({op});
-  csession.Run(s_);
-  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
-  TF_Tensor* out = csession.output_tensor(0);
-  ASSERT_TRUE(out != nullptr);
-  EXPECT_EQ(TF_INT32, TF_TensorType(out));
-  EXPECT_EQ(0, TF_NumDims(out));  // scalar
-  ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
-  int32* output_contents = static_cast<int32*>(TF_TensorData(out));
-  EXPECT_EQ(-17, *output_contents);
-
-  // Clean up
-  csession.CloseAndDelete(s_);
-  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
-  TF_DeleteFunction(func);
-}
-#endif  // TENSORFLOW_EAGER_USE_XLA
-
 }  // namespace
 }  // namespace tensorflow


@@ -634,6 +634,40 @@ TEST(CAPI, Graph) {
   TF_DeleteStatus(s);
 }
 
+TEST(CAPI, UpdateEdge) {
+  TF_Status* s = TF_NewStatus();
+  TF_Graph* graph = TF_NewGraph();
+
+  // Make two scalar constants.
+  TF_Operation* one = ScalarConst(1, graph, s, "one");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* two = ScalarConst(2, graph, s, "two");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+  // Add oper.
+  TF_Operation* add = Add(one, two, graph, s, "add");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+  // Add another oper to the graph.
+  TF_Operation* neg = Neg(add, graph, s, "neg");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+  NodeDef node_def_neg;
+  ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
+  EXPECT_EQ(string("add"), node_def_neg.input(0));
+
+  // Update edge of neg.
+  TF_UpdateEdge(graph, TF_Output{one, 0}, TF_Input{neg, 0}, s);
+
+  ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
+  EXPECT_EQ(string("one:0"), node_def_neg.input(0));
+
+  // Clean up
+  TF_DeleteGraph(graph);
+  TF_DeleteStatus(s);
+}
+
 /*
 TODO(skyewm): this test currently DCHECKs, change to bad status


@@ -1,13 +1,23 @@
 # Experimental extensions to the C API for eager execution of kernels.
+load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
 load(
     "//tensorflow:tensorflow.bzl",
+    "if_libtpu",
     "tf_cc_test",
     "tf_copts",
     "tf_cuda_cc_test",
     "tf_cuda_library",
-    "tfe_xla_copts",
 )
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "filegroup")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "internal_tfrt_deps")
 load(
     "//tensorflow/core/platform:build_config.bzl",
     "tf_kernel_tests_linkstatic",
@@ -31,7 +41,7 @@ tf_cuda_library(
         "c_api_unified_experimental.h",
     ],
     hdrs = ["c_api.h"],
-    copts = tf_copts() + tfe_xla_copts(),
+    copts = tf_copts(),
    visibility = ["//visibility:public"],
     deps = select({
         "//tensorflow:android": [
@@ -72,13 +82,6 @@ tf_cuda_library(
             "//tensorflow/core:protos_all_cc",
             "//tensorflow/core/profiler/lib:traceme",
         ],
-    }) + select({
-        "//tensorflow:with_xla_support": [
-            "//tensorflow/compiler/tf2xla:xla_compiler",
-            "//tensorflow/compiler/jit",
-            "//tensorflow/compiler/jit:xla_device",
-        ],
-        "//conditions:default": [],
     }) + [
         "@com_google_absl//absl/memory",
         "//tensorflow/core/common_runtime/eager:eager_operation",
@@ -95,7 +98,7 @@ tf_cuda_library(
         "//tensorflow/core/distributed_runtime:server_lib",
         "//tensorflow/core/distributed_runtime:worker_env",
         "//tensorflow/core:gpu_runtime",
-    ],
+    ] + internal_tfrt_deps(),
     alwayslink = 1,
 )
@@ -109,11 +112,16 @@ filegroup(
         "c_api_experimental.h",
         "c_api_internal.h",
         "c_api_unified_experimental.h",
+        "c_api_unified_experimental_internal.h",
         "dlpack.h",
+        "gradients.h",
+        "gradients_internal.h",
         "immediate_execution_context.h",
         "immediate_execution_operation.h",
         "immediate_execution_tensor_handle.h",
+        "tape.h",
         "tfe_cancellation_manager_internal.h",
+        "tfe_context_internal.h",
         "tfe_executor_internal.h",
         "tfe_monitoring_internal.h",
         "tfe_op_attrs_internal.h",
@@ -172,27 +180,20 @@ cc_library(
 )
 
 cc_library(
-    name = "gradients",
-    srcs = [
-        "gradients.cc",
-        "gradients_internal.h",
-    ],
+    name = "tracing_utils",
+    srcs = ["tracing_utils.cc"],
     hdrs = [
-        "gradients.h",
+        "tracing_utils.h",
     ],
     visibility = [
        "//tensorflow:internal",
     ],
     deps = [
-        ":abstract_context",
         ":abstract_operation",
-        ":abstract_tensor_handle",
         ":c_api_unified_internal",
-        ":tape",
-        "//tensorflow/core/common_runtime/eager:attr_builder",
+        "//tensorflow/c/experimental/gradients/tape:tape_operation",
         "//tensorflow/core/lib/llvm_rtti",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/strings",
+        "//tensorflow/core/platform:errors",
     ],
 )
@@ -228,10 +229,10 @@ tf_cuda_cc_test(
         "gradients_test.cc",
     ],
     args = ["--heap_check=local"],
-    extra_copts = tfe_xla_copts(),
     linkstatic = tf_kernel_tests_linkstatic(),
     tags = tf_cuda_tests_tags() + ["nomac"],
     deps = [
+        ":abstract_context",
         ":abstract_tensor_handle",
         ":c_api_experimental",
         ":c_api_test_util",
@@ -242,7 +243,8 @@ tf_cuda_cc_test(
         "//tensorflow/c:tf_status_helper",
         "//tensorflow/c/experimental/gradients:array_grad",
         "//tensorflow/c/experimental/gradients:math_grad",
-        "//tensorflow/c/experimental/ops:array_ops",
+        "//tensorflow/c/experimental/gradients/tape:tape_context",
+        "//tensorflow/c/experimental/ops",
         "//tensorflow/cc/profiler",
         "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
         "//tensorflow/core:lib",
@@ -256,6 +258,46 @@ tf_cuda_cc_test(
     ],
 )
 
+cc_library(
+    name = "gradients_util",
+    srcs = [
+        "gradients_util.cc",
+    ],
+    hdrs = [
+        "gradients_util.h",
+    ],
+    visibility = [
+        "//tensorflow:internal",
+    ],
+    deps = [
+        ":abstract_context",
+        ":abstract_operation",
+        ":abstract_tensor_handle",
+        ":c_api",
+        ":c_api_experimental",
+        ":c_api_unified_internal",
+        ":gradients_internal",
+        ":tape",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "//tensorflow/c:c_api",
+        "//tensorflow/c:tf_status_helper",
+        "//tensorflow/c/experimental/ops:array_ops",
+        "//tensorflow/c/experimental/ops:math_ops",
+        "//tensorflow/c/experimental/ops:nn_ops",
+        "//tensorflow/cc/profiler",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/lib/llvm_rtti",
+    ] + if_libtpu(
+        if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
+        if_true = [],
+    ),
+)
+
 cc_library(
     name = "mnist_gradients_testutil",
     srcs = [
@@ -272,17 +314,93 @@ cc_library(
         ":c_api_experimental",
         ":c_api_unified_internal",
         ":gradients_internal",
-        "//tensorflow/c:tf_status_helper",
-        "//tensorflow/c:tf_tensor",
+        ":gradients_util",
+        ":tape",
+        "//tensorflow/c/experimental/gradients/tape:tape_context",
         "//tensorflow/c/experimental/ops:array_ops",
         "//tensorflow/c/experimental/ops:math_ops",
         "//tensorflow/c/experimental/ops:nn_ops",
         "//tensorflow/core/lib/llvm_rtti",
+        "//tensorflow/core/platform:status",
         "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/types:span",
     ],
 )
 
+cc_library(
+    name = "gradient_checker",
+    srcs = [
+        "gradient_checker.cc",
+    ],
+    hdrs = [
+        "gradient_checker.h",
+    ],
+    visibility = [
+        "//tensorflow:internal",
+    ],
+    deps = [
+        ":abstract_tensor_handle",
+        ":c_api_experimental",
+        ":c_api_unified_internal",
+        ":gradients_internal",
+        ":gradients_util",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "//tensorflow/c:c_api",
+        "//tensorflow/c:tf_status_helper",
+        "//tensorflow/c/experimental/gradients:math_grad",
+        "//tensorflow/c/experimental/gradients:nn_grad",
+        "//tensorflow/c/experimental/ops:array_ops",
+        "//tensorflow/c/experimental/ops:math_ops",
+        "//tensorflow/c/experimental/ops:nn_ops",
+        "//tensorflow/cc/profiler",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/lib/llvm_rtti",
+    ] + if_libtpu(
+        if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
+        if_true = [],
+    ),
+)
+
+tf_cuda_cc_test(
+    name = "gradient_checker_test",
+    size = "small",
+    srcs = [
+        "gradient_checker_test.cc",
+    ],
+    args = ["--heap_check=local"],
+    linkstatic = tf_kernel_tests_linkstatic(),
+    tags = tf_cuda_tests_tags() + ["nomac"],
+    deps = [
+        ":abstract_tensor_handle",
+        ":c_api_experimental",
+        ":c_api_test_util",
+        ":c_api_unified_internal",
+        ":gradient_checker",
+        ":gradients_internal",
+        ":gradients_util",
+        ":mnist_gradients_testutil",
+        "//tensorflow/c:c_api",
+        "//tensorflow/c:c_test_util",
+        "//tensorflow/c:tf_status_helper",
+        "//tensorflow/c/experimental/gradients:math_grad",
+        "//tensorflow/c/experimental/gradients:nn_grad",
+        "//tensorflow/c/experimental/ops:array_ops",
+        "//tensorflow/c/experimental/ops:math_ops",
+        "//tensorflow/c/experimental/ops:nn_ops",
+        "//tensorflow/cc/profiler",
+        "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/lib/llvm_rtti",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
 tf_cuda_cc_test(
     name = "mnist_gradients_test",
     size = "small",
@@ -290,19 +408,16 @@ tf_cuda_cc_test(
         "mnist_gradients_test.cc",
     ],
     args = ["--heap_check=local"],
-    extra_copts = tfe_xla_copts(),
     linkstatic = tf_kernel_tests_linkstatic(),
     tags = tf_cuda_tests_tags() + [
         "nomac",
-        "notap",  # TODO(b/166150182): Enable
-        "no_oss",  # TODO(b/166150182): Enable
     ],
     deps = [
         ":abstract_tensor_handle",
         ":c_api_experimental",
-        ":c_api_test_util",
         ":c_api_unified_internal",
         ":gradients_internal",
+        ":gradients_util",
         ":mnist_gradients_testutil",
         "//tensorflow/c:c_api",
         "//tensorflow/c:c_test_util",
@@ -526,6 +641,19 @@ cc_library(
     ],
 )
 
+cc_header_only_library(
+    name = "tfe_tensorhandle_internal_hdrs_only",
+    extra_deps = [
+        "@com_google_absl//absl/strings",
+    ],
+    visibility = [
+        "//tensorflow:internal",
+    ],
+    deps = [
+        ":tfe_tensorhandle_internal",
+    ],
+)
+
 tf_cuda_library(
     name = "c_api_test_util",
     testonly = 1,
@@ -539,6 +667,8 @@ tf_cuda_library(
         ":c_api",
         ":c_api_experimental",
         "//tensorflow/c:c_test_util",
+        "//tensorflow/c:tf_datatype",
+        "//tensorflow/c:tf_tensor",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
@@ -553,7 +683,6 @@ tf_cuda_cc_test(
         "c_api_debug_test.cc",
         "c_api_test.cc",
     ],
-    extra_copts = tfe_xla_copts(),
     tags = [
         "noguitar",  # TODO(b/155445984): flaky
         #"guitar",
@@ -608,7 +737,6 @@ tf_cuda_cc_test(
     ],
     # TODO(b/136478427): Figure out how to correctly shut the server down
     args = ["--heap_check=local"],
-    extra_copts = tfe_xla_copts(),
     tags = [
         "no_windows",
     ],
@@ -641,7 +769,6 @@ tf_cuda_cc_test(
     ],
     # TODO(b/136478427): Figure out how to correctly shut the server down
     args = ["--heap_check=local"],
-    extra_copts = tfe_xla_copts(),
     tags = [
         "no_windows",
     ],
@@ -660,7 +787,6 @@ tf_cuda_cc_test(
     ],
     # TODO(b/136478427): Figure out how to correctly shut the server down
     args = ["--heap_check=local"],
-    extra_copts = tfe_xla_copts(),
     tags = [
         "no_windows",
         "noasan",  # leaks gRPC server instances
@@ -694,7 +820,6 @@ tf_cuda_cc_test(
     ],
     # TODO(b/136478427): Figure out how to correctly shut the server down
     args = ["--heap_check=local"],
-    extra_copts = tfe_xla_copts(),
     tags = [
         "no_windows",
     ],
@@ -729,7 +854,7 @@ tf_cuda_library(
         "c_api_experimental.h",
         "c_api_unified_experimental.h",
     ],
-    copts = tf_copts() + tfe_xla_copts(),
+    copts = tf_copts(),
     visibility = ["//visibility:public"],
     deps = select({
         "//tensorflow:android": [
@@ -801,7 +926,6 @@ tf_cuda_cc_test(
         "c_api_experimental_test.cc",
     ],
     args = ["--heap_check=local"],
-    extra_copts = tfe_xla_copts(),
     linkstatic = tf_kernel_tests_linkstatic(),
     tags = tf_cuda_tests_tags() + ["nomac"],
     deps = [
@@ -814,6 +938,7 @@ tf_cuda_cc_test(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:status",
        "@com_google_absl//absl/strings",
     ],
 )
@@ -825,7 +950,6 @@ tf_cuda_cc_test(
         "c_api_unified_experimental_test.cc",
     ],
     args = ["--heap_check=local"],
-    extra_copts = tfe_xla_copts(),
     linkstatic = tf_kernel_tests_linkstatic(),
     tags = tf_cuda_tests_tags() + ["nomac"],
     deps = [
@@ -834,6 +958,7 @@ tf_cuda_cc_test(
         ":c_api_test_util",
         "//tensorflow/c:c_api",
         "//tensorflow/c:c_test_util",
+        "//tensorflow/c:tf_status_helper",
         "//tensorflow/cc/profiler",
         "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
         "//tensorflow/core:lib",
@@ -943,7 +1068,13 @@ filegroup(
         "c_api_unified_experimental_eager.cc",
         "c_api_unified_experimental_graph.cc",
         "c_api_unified_experimental_internal.h",
+        "gradient_checker.cc",
+        "gradient_checker.h",
         "gradients.cc",  # Uses RTTI.
+        "gradients_util.cc",
+        "gradients_util.h",
+        "tracing_utils.h",
+        "tracing_utils.cc",
         "*test*",
         "*dlpack*",
     ],


@@ -32,7 +32,7 @@ namespace tensorflow {
 // environment, a traced representation etc.
 class AbstractContext {
  protected:
-  enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt };
+  enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape };
   explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
   virtual ~AbstractContext() {}
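For context on the new kTape enumerator (here and in AbstractOperation below): these kind tags drive LLVM-style RTTI, matching the //tensorflow/core/lib/llvm_rtti dependency seen in the BUILD changes above. A self-contained sketch of the pattern, with illustrative names rather than the real TensorFlow classes:

#include <cassert>

class Base {
 protected:
  enum Kind { kGraph, kMlir, kEager, kTfrt, kTape };
  explicit Base(Kind kind) : kind_(kind) {}

 public:
  virtual ~Base() = default;
  Kind getKind() const { return kind_; }

 private:
  const Kind kind_;
};

// A tape-wrapping subclass tags itself with the new kTape kind; classof()
// then lets isa<>/dyn_cast<>-style helpers identify it without dynamic_cast.
class TapeLike : public Base {
 public:
  TapeLike() : Base(kTape) {}
  static bool classof(const Base* b) { return b->getKind() == kTape; }
};

int main() {
  TapeLike tape;
  Base* base = &tape;
  assert(TapeLike::classof(base));  // dispatch on the kind tag
  return 0;
}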


@@ -30,7 +30,7 @@ namespace tensorflow {
 // tracing or immediate execution mode.
 class AbstractOperation {
  protected:
-  enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt };
+  enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt, kTape };
   explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
   virtual ~AbstractOperation() {}


@@ -39,7 +39,7 @@ limitations under the License.
 #include "tensorflow/c/eager/tfe_op_internal.h"
 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
 #include "tensorflow/c/tf_tensor_internal.h"
-#ifdef PLATFORM_GOOGLE
+#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
 #include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
 #endif
 #include "tensorflow/core/common_runtime/device.h"
@@ -51,9 +51,6 @@ limitations under the License.
 #include "tensorflow/core/protobuf/device_filters.pb.h"
 #include "tensorflow/core/protobuf/error_codes.pb.h"
 #include "tensorflow/core/util/device_name_utils.h"
-#ifdef TENSORFLOW_EAGER_USE_XLA
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#endif  // TENSORFLOW_EAGER_USE_XLA
 #include "tensorflow/core/common_runtime/copy_tensor.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
@@ -629,21 +626,30 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
                         "targets will fail.";
       }
     } else {
-      // The master's context_view_id will be incremented by one
-      // the UpdateRemoteMaster call later. We want all new workers and
-      // existing workers to also have the updated context_view_id, so
-      // we must set their context_view_id to the existing master's
-      // context_view_id + 1.
-      sg.Update(CreateRemoteContexts(
-          ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
-          server_def, remote_eager_workers.get(), context->Executor().Async(),
-          context->LazyCopyFunctionRemoteInputs(), base_request));
+      if (sg.ok()) {
+        // Create remote contexts on the newly added workers only if the master
+        // has collected all device information from them (i.e., the
+        // GetAllRemoteDevices call returns successfully). Note that in rare
+        // cases GetAllRemoteDevices can still fail even with RPCs configured
+        // to wait until the remote workers become alive. If the master creates
+        // remote contexts on the workers whose devices are still not
+        // collected, those workers will be treated as existing workers
+        // subsequently, so the master will never get devices from them even
+        // with retrying UpdateServerDef.
+        sg.Update(CreateRemoteContexts(
+            ctx, added_workers, context_id, context_view_id + 1,
+            keep_alive_secs, server_def, remote_eager_workers.get(),
+            context->Executor().Async(),
+            context->LazyCopyFunctionRemoteInputs(), base_request));
+      }
       if (!existing_workers.empty()) {
         if (VLOG_IS_ON(1)) {
           for (const string& w : existing_workers) {
             VLOG(1) << "Updating cluster with existing worker " << w;
           }
         }
+        // The master's context_view_id will be incremented by one in the
+        // UpdateRemoteMaster call later. We want existing workers to also have
+        // the updated context_view_id, so we must set their context_view_id to
+        // the master's current context_view_id + 1.
         sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers,
                                        removed_workers, context_id,
                                        context_view_id + 1, server_def,
@@ -723,7 +729,7 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
 TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
   if (opts->use_tfrt) {
-#ifdef PLATFORM_GOOGLE
+#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
     return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
 #else
     status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
@@ -745,10 +751,8 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
       opts->session_options.options,
       static_cast<tensorflow::ContextDevicePlacementPolicy>(
           opts->device_placement_policy),
-      static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
       opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
-      /*device_mgr_owned*/ true, r,
-      tensorflow::GetDefaultCustomKernelCreator()));
+      /*device_mgr_owned*/ true, r));
 }
 
 void TFE_DeleteContext(TFE_Context* ctx) {
@@ -851,20 +855,9 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
 #else   // !defined(IS_MOBILE_PLATFORM)
   tensorflow::EagerContext* context =
       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  tensorflow::GrpcServer* grpc_server =
-      static_cast<tensorflow::GrpcServer*>(context->GetServer());
-  std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
-  status->status = grpc_server->master_env()->worker_cache->GetEagerClientCache(
-      &remote_eager_workers);
-  if (!status->status.ok()) {
-    LOG(ERROR) << "Failed to get client cache for remote workers.";
-    return false;
-  }
-
   // TODO(yuefengz): support partially specified `worker_name`.
   tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
-  status->status = remote_eager_workers->GetClient(worker_name, &eager_client);
+  status->status = context->GetClient(worker_name, &eager_client);
   if (!status->status.ok()) {
     return false;
   }
@@ -911,9 +904,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
 void TFE_ContextSetThreadLocalDevicePlacementPolicy(
     TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  context->SetThreadLocalDevicePlacementPolicy(
+  tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy(
       static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
 }
@@ -922,10 +913,8 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
 // safe to call this function from the async EagerExecutor threads.
 extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
     TFE_Context* ctx) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
   return static_cast<TFE_ContextDevicePlacementPolicy>(
-      context->GetDevicePlacementPolicy());
+      tensorflow::unwrap(ctx)->GetDevicePlacementPolicy());
 }
 
 TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
@@ -1149,26 +1138,23 @@ void TFE_DeleteOp(TFE_Op* op) {
   tensorflow::unwrap(op)->Release();
 }
 
+const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) {
+  return tensorflow::unwrap(op)->Name().c_str();
+}
+
+TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) {
+  return tensorflow::wrap(
+      &(OperationFromInterface(tensorflow::unwrap(op))->EagerContext()));
+}
+
 void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
   status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
 }
 
-const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
+const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status) {
   return tensorflow::unwrap(op)->DeviceName().c_str();
 }
 
-void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
-#ifdef TENSORFLOW_EAGER_USE_XLA
-  tensorflow::Status s = tensorflow::unwrap(op)->SetUseXla(enable);
-  if (!s.ok()) {
-    LOG(ERROR) << "Could not enable XLA compilation for op: " << s;
-  }
-#else
-  LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
-                  "built with XLA support.";
-#endif  // TENSORFLOW_EAGER_USE_XLA
-}
-
 void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
   status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
 }
@@ -1181,6 +1167,15 @@ void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
                                    static_cast<size_t>(num_inputs)});
 }
 
+extern int TFE_OpGetFlatInputCount(const TFE_Op* op, TF_Status* status) {
+  return tensorflow::unwrap(op)->GetInputs().size();
+}
+
+extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, int index,
+                                            TF_Status* status) {
+  return tensorflow::wrap(tensorflow::unwrap(op)->GetInputs()[index]);
+}
+
 TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
                               unsigned char* is_list, TF_Status* status) {
   TF_AttrType ret = TF_ATTR_INT;
@@ -1430,21 +1425,15 @@ void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
 }
 
 unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  return context->FindFunctionDef(name) != nullptr;
+  return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr;
 }
 
 void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  context->SetShouldStoreGraphs(true);
+  tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
 }
 
 void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  context->SetShouldStoreGraphs(false);
+  tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
 }
 
 }  // extern "C"
@@ -1486,7 +1475,7 @@ void TFE_ContextEndStep(TFE_Context* ctx) {
   tensorflow::unwrap(ctx)->EndStep();
 }
 
-const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
+const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op) {
   return tensorflow::wrap(
       &OperationFromInterface(tensorflow::unwrap(op))->Attrs());
 }
@@ -1551,8 +1540,67 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
       TFE_OpSetAttrFunction(op, attr_name, func_op);
       TFE_DeleteOp(func_op);
     } break;
-    case tensorflow::AttrValue::kList:
-      TF_FALLTHROUGH_INTENDED;
+    case tensorflow::AttrValue::kList: {
+      // String
+      if (const int s_size = default_value.list().s_size()) {
+        absl::InlinedVector<const void*, 4> values_vector;
+        absl::InlinedVector<size_t, 4> lengths_vector;
+        for (int i = 0; i < s_size; ++i) {
+          const string& v = default_value.list().s(i);
+          values_vector.push_back(v.data());
+          lengths_vector.push_back(v.size());
+        }
+        TFE_OpSetAttrStringList(op, attr_name, values_vector.data(),
+                                lengths_vector.data(), s_size);
+      }
+
+      // Int
+      if (const int i_size = default_value.list().i_size()) {
+        absl::InlinedVector<int64_t, 4> i_vector;
+        for (int i = 0; i < i_size; ++i) {
+          i_vector.push_back(default_value.list().i(i));
+        }
+        TFE_OpSetAttrIntList(op, attr_name, i_vector.data(), i_size);
+      }
+
+      // Float
+      if (const int f_size = default_value.list().f_size()) {
+        absl::InlinedVector<float, 4> f_vector;
+        for (int i = 0; i < f_size; ++i) {
+          f_vector.push_back(default_value.list().f(i));
+        }
+        TFE_OpSetAttrFloatList(op, attr_name, f_vector.data(), f_size);
+      }
+
+      // Bool
+      if (const int b_size = default_value.list().b_size()) {
+        absl::InlinedVector<unsigned char, 4> b_vector;
+        for (int i = 0; i < b_size; i++) {
+          b_vector.push_back(default_value.list().b(i));
+        }
+        TFE_OpSetAttrBoolList(op, attr_name, b_vector.data(), b_size);
+      }
+
+      // Type
+      if (const int type_size = default_value.list().type_size()) {
+        absl::InlinedVector<unsigned int, 4> type_vector;
+        for (int i = 0; i < type_size; ++i) {
+          type_vector.push_back(default_value.list().type(i));
+        }
+        TFE_OpSetAttrTypeList(
+            op, attr_name,
+            reinterpret_cast<const TF_DataType*>(type_vector.data()),
+            type_size);
+      }
+
+      // The remaining list element types are not supported.
+      if (default_value.list().shape_size() > 0 ||
+          default_value.list().func_size() > 0 ||
+          default_value.list().tensor_size() > 0) {
+        TF_SetStatus(
+            status, TF_UNIMPLEMENTED,
+            tensorflow::strings::StrCat(
+                "Unable to set attribute from the default value: ",
+                default_value.DebugString())
+                .data());
+      }
+    } break;
     case tensorflow::AttrValue::kTensor:
       TF_FALLTHROUGH_INTENDED;
     case tensorflow::AttrValue::kPlaceholder:
@@ -1612,19 +1660,12 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
     return status.status;
   }
 
-  tensorflow::Status Execute(tensorflow::EagerOperation* op,
+  tensorflow::Status Execute(const tensorflow::EagerOperation* op,
                              tensorflow::TensorHandle** retvals,
                              int* num_retvals) override {
-    std::vector<TFE_TensorHandle*> inputs;
-    inputs.reserve(op->Inputs().size());
-    for (int i = 0; i < op->Inputs().size(); ++i) {
-      op->Inputs()[i]->Ref();
-      inputs.push_back(tensorflow::wrap(op->Inputs()[i]));
-    }
     std::vector<TFE_TensorHandle*> outputs(*num_retvals);
     TF_Status status;
-    device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
-                    wrap(&op->Attrs()), num_retvals, outputs.data(), &status,
-                    info_);
+    device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
+                    info_);
     if (status.status.ok()) {
       for (int i = 0; i < *num_retvals; ++i) {
@@ -1634,10 +1675,6 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
         TFE_DeleteTensorHandle(outputs[i]);
       }
     }
-    for (auto inp : inputs) {
-      TFE_DeleteTensorHandle(inp);
-    }
     return status.status;
   }
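The flat-input accessors above pair with this change: the custom-device execute callback now receives the (wrapped) op itself rather than pre-unpacked inputs, name, and attributes. A hedged sketch of a callback matching the new device_.execute call site; the exact TFE_CustomDevice struct layout is not part of this hunk, so treat the signature as an assumption:

#include <cstdio>

#include "tensorflow/c/eager/c_api.h"

// Assumed callback shape, mirroring device_.execute(wrap(op), num_retvals,
// outputs.data(), &status, info_) above.
void MyDeviceExecute(const TFE_Op* op, int* num_outputs,
                     TFE_TensorHandle** outputs, TF_Status* s,
                     void* device_info) {
  const char* op_name = TFE_OpGetName(op, s);
  const int num_inputs = TFE_OpGetFlatInputCount(op, s);
  std::printf("custom device executing %s with %d input(s)\n", op_name,
              num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    TFE_TensorHandle* input = TFE_OpGetFlatInput(op, i, s);
    (void)input;  // A real device would forward this to the wrapped device.
  }
  *num_outputs = 0;  // Placeholder: a real device populates `outputs`.
  (void)outputs;
  (void)device_info;
}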


@@ -74,7 +74,7 @@ typedef enum TFE_ContextDevicePlacementPolicy {
   // Placement policy which silently copies int32 tensors but not other dtypes.
   TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
 } TFE_ContextDevicePlacementPolicy;
-// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h)
+// LINT.ThenChange(//tensorflow/c/eager/immediate_execution_context.h)
 
 // Sets the default execution mode (sync/async). Note that this can be
 // overridden per thread using TFE_ContextSetExecutorForThread.
@@ -248,22 +248,22 @@ typedef struct TFE_Op TFE_Op;
 TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx,
                                         const char* op_or_function_name,
                                         TF_Status* status);
 TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op);
 
+// Returns the op or function name `op` will execute.
+//
+// The returned string remains valid throughout the lifetime of 'op'.
+TF_CAPI_EXPORT extern const char* TFE_OpGetName(const TFE_Op* op,
+                                                TF_Status* status);
+TF_CAPI_EXPORT extern TFE_Context* TFE_OpGetContext(const TFE_Op* op,
+                                                    TF_Status* status);
+
 TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name,
                                            TF_Status* status);
 // The returned string remains valid throughout the lifetime of 'op'.
-TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(TFE_Op* op,
+TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(const TFE_Op* op,
                                                   TF_Status* status);
 
-// When 'enable' is set to 1, and if TensorFlow library is built with XLA
-// support, a subsequent TFE_Execute() call on `op` will run the op via XLA.
-//
-// If the library is not built with XLA support, this call would be a no-op.
-TF_CAPI_EXPORT extern void TFE_OpSetXLACompilation(TFE_Op* op,
-                                                   unsigned char enable);
-
 TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input,
                                           TF_Status* status);
@@ -272,6 +272,23 @@ TF_CAPI_EXPORT extern void TFE_OpAddInputList(TFE_Op* op,
                                               int num_inputs,
                                               TF_Status* status);
 
+// Fetches the current number of inputs attached to `op`.
+//
+// Does not use the operation's definition to determine how many inputs should
+// be attached. It is intended for use with TFE_OpGetFlatInput to inspect an
+// already-finalized operation.
+//
+// Note that TFE_OpGetFlatInputCount and TFE_OpGetFlatInput operate on a flat
+// sequence of inputs, unlike TFE_OpGetInputLength (for getting the length of a
+// particular named input list, which may only be part of the op's inputs).
+TF_CAPI_EXPORT extern int TFE_OpGetFlatInputCount(const TFE_Op* op,
+                                                  TF_Status* status);
+// Returns a borrowed reference to one of `op`'s inputs. Use
+// `TFE_TensorHandleCopySharingTensor` to make a new reference.
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op,
+                                                           int index,
+                                                           TF_Status* status);
+
 TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op,
                                                     const char* attr_name,
                                                     unsigned char* is_list,


@@ -22,9 +22,6 @@ limitations under the License.
 #include "tensorflow/c/tf_status_internal.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/platform/status.h"
-#ifdef TENSORFLOW_EAGER_USE_XLA
-#include "tensorflow/compiler/jit/xla_device.h"
-#endif  // TENSORFLOW_EAGER_USE_XLA
 
 using tensorflow::string;
@@ -64,87 +61,6 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
     return nullptr;
   }
 
-#ifdef TENSORFLOW_EAGER_USE_XLA
-  auto* device = absl::get<tensorflow::Device*>(handle->device());
-
-  // If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
-  auto* xla_device = dynamic_cast<tensorflow::XlaDevice*>(device);
-  if (xla_device != nullptr) {
-    tensorflow::XlaDevice::PaddedShapeFn shape_fn =
-        xla_device->metadata().padded_shape_fn();
-    xla::Shape padded_shape;
-    status->status = shape_fn(*tensor, &padded_shape);
-    if (!status->status.ok()) {
-      return nullptr;
-    }
-    if (VLOG_IS_ON(3)) {
-      std::vector<tensorflow::int64> shape_to_log =
-          TensorShapeAsVector(*handle, &status->status);
-      if (!status->status.ok()) {
-        // Ignore the status here as we are simply logging.
-        status->status = tensorflow::Status::OK();
-      } else {
-        VLOG(3) << "Fully padded shape of ["
-                << absl::StrJoin(shape_to_log, ", ") << "] is "
-                << padded_shape.DebugString();
-      }
-    }
-
-    if (padded_shape.IsTuple()) {
-      if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) {
-        // Currently, the only case of XlaTensor containing a tuple shape is to
-        // represent 64 bit ints, doubles, and complex numbers (we don't support
-        // 64bit complex numbers).
-        status->status = tensorflow::errors::InvalidArgument(
-            "XlaTensors should only contain tuples of size 2. Shape: ",
-            padded_shape.DebugString());
-        return nullptr;
-      }
-
-      // shape0 is not a const& because we will assign it to padded_shape below.
-      // It is illegal to assign a part of a message to itself.
-      xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0);
-      const xla::Shape& shape1 =
-          xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
-      if (shape0.IsTuple() || shape1.IsTuple()) {
-        status->status = tensorflow::errors::InvalidArgument(
-            "XlaTensors should not contain nested tuples. Shape: ",
-            padded_shape.DebugString());
-        return nullptr;
-      }
-      if (!xla::ShapeUtil::Equal(shape0, shape1)) {
-        status->status = tensorflow::errors::InvalidArgument(
-            "Subshapes of XlaTensors should be the same. Shape: ",
-            padded_shape.DebugString());
-        return nullptr;
-      }
-
-      // Since the only case we handle here are two equal subshapes, we
-      // simply return one of them. The caller will interpret it as this
-      // shape directly storing the 64bit types. This approximation is good
-      // enough for this API's debugging use case.
-      padded_shape = shape0;
-    }
-
-    int rank = padded_shape.dimensions_size();
-    std::vector<tensorflow::int64> dev_dims;
-    dev_dims.reserve(rank);
-    if (rank == 1) {
-      // Rank 1 tensors might not have padded_shape.layout.minor_to_major set,
-      dev_dims.push_back(padded_shape.dimensions(0));
-    } else {
-      for (int i = rank - 1; i >= 0; --i) {
-        tensorflow::int64 dim_index = padded_shape.layout().minor_to_major(i);
-        dev_dims.push_back(padded_shape.dimensions(dim_index));
-      }
-    }
-    status->status = tensorflow::Status::OK();
-    return new TFE_TensorDebugInfo(dev_dims);
-  }
-#endif  // TENSORFLOW_EAGER_USE_XLA
-
-  // If the tensor is not an XLA tensor, the device shape is
-  // the same as regular tensor shape.
   std::vector<tensorflow::int64> dev_dims =
       TensorShapeAsVector(*handle, &status->status);
   if (!status->status.ok()) {


@@ -121,25 +121,6 @@ string AddVariablesFunction() {
   return def.SerializeAsString();
 }
 
-void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
-  TF_Status* status = TF_NewStatus();
-  TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
-  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
-  TFE_OpAddInput(op, var_handle, status);
-  TFE_TensorHandle* is_initialized[1] = {nullptr};
-  int num_retvals = 1;
-  TFE_Execute(op, &is_initialized[0], &num_retvals, status);
-  CHECK_EQ(1, num_retvals);
-  TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
-  bool initialized = false;
-  memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t));
-  EXPECT_EQ(initialized, true);
-  TF_DeleteTensor(t);
-  TFE_DeleteTensorHandle(is_initialized[0]);
-  TFE_DeleteOp(op);
-  delete status;
-}
-
 void TestFunctionWithPackedInput(const bool remote) {
   tensorflow::ServerDef server_def = GetServerDef(3);
@@ -182,9 +163,8 @@ void TestFunctionWithPackedInput(const bool remote) {
   // Add a sync point in order to make sure that variables have been initialized
   // before the function execution starts.
-  // TODO(b/155789951): Remove once b/155789951 is fixed.
-  VarIsInitialized(ctx, h1);
-  VarIsInitialized(ctx, h2);
+  TFE_ContextAsyncWait(ctx, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
 
   // Pack 3 variable handles into one TFE_TensorHandle.
   // When remote is false, function device is placed on task0. Handle types are
@@ -396,6 +376,8 @@ TEST(CAPI, DistributedFunctionGraphPassOnlyOnce) {
   TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
   EXPECT_NE(var_handle, nullptr);
+  TFE_ContextAsyncWait(ctx, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
 
   const string function_def = VariableAddFunction();
   TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
@@ -517,6 +499,8 @@ void TestDistributedFunctionCancellation(bool inject_error) {
   TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
   EXPECT_NE(var_handle, nullptr);
+  TFE_ContextAsyncWait(ctx, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
 
   const string function_def = inject_error ? VariableAddFunctionWithGraphError()
                                            : VariableAddFunction();
@@ -561,7 +545,9 @@ TEST(CAPI, DistributedFunctionNoError) {
   TestDistributedFunctionCancellation(false);
 }
 
-TEST(CAPI, DistributedFunctionCancelledOnError) {
+// TODO(b/170399182): Update test once an alternative to using the function
+// optimization hook is in place.
+TEST(CAPI, DISABLED_DistributedFunctionCancelledOnError) {
   TestDistributedFunctionCancellation(true);
 }


@ -49,15 +49,11 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
} }
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
tensorflow::EagerContext* context = tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(true);
} }
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
tensorflow::EagerContext* context = tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(false);
} }
uint64_t TFE_GetContextId(TFE_Context* ctx) { uint64_t TFE_GetContextId(TFE_Context* ctx) {
@ -486,29 +482,6 @@ TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
static_cast<void*>(sampler->sampler->GetCell(label1, label2))); static_cast<void*>(sampler->sampler->GetCell(label1, label2)));
} }
void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
TFE_ContextMirroringPolicy policy) {
options->mirroring_policy = policy;
}
void TFE_ContextSetThreadLocalMirroringPolicy(
TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetThreadLocalMirroringPolicy(
static_cast<tensorflow::ContextMirroringPolicy>(policy));
}
// Note: this function looks up a thread local policy. So it should be called in
// the appropriate client thread. In particular, in async mode, it may not be
// safe to call this function from the async EagerExecutor threads.
extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return static_cast<TFE_ContextMirroringPolicy>(context->GetMirroringPolicy());
}
void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options, void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options,
bool lazy_copy) { bool lazy_copy) {
options->lazy_remote_inputs_copy = lazy_copy; options->lazy_remote_inputs_copy = lazy_copy;
@ -567,22 +540,16 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
} }
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) { void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
tensorflow::EagerContext* context = tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor());
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetExecutorForThread(executor->executor());
} }
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) { TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
tensorflow::EagerContext* context = return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor());
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return new TFE_Executor(&context->Executor());
} }
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) { void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto address_space = tensorflow::DeviceNameUtils::AddressSpace( auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
context->HostCPU()->parsed_name()); tensorflow::unwrap(ctx)->HostCPUParsedName());
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space); auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
void* data = tensorflow::port::Malloc(str.length()); void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0); str.copy(static_cast<char*>(data), str.length(), 0);
@@ -595,9 +562,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {

void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
                               TF_Buffer* buf, TF_Status* status) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  auto* function_def = context->FindFunctionDef(function_name);
+  auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name);
  if (function_def == nullptr) {
    status->status = tensorflow::errors::NotFound(
        "Unable to find FunctionDef with name: ", function_name);
@@ -666,14 +631,26 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,

void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
                                       TF_Status* status) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  context->SetAllowSoftPlacement(enable);
+  tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable);
}

void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
                                      TF_Status* status) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  context->SetLogDevicePlacement(enable);
+  tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
}
const char* TFE_TensorHandleDeviceType(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
return tensorflow::unwrap(h)->DeviceType(&status->status);
}
int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return -1;
}
return tensorflow::unwrap(h)->DeviceId(&status->status);
}
@@ -265,33 +265,6 @@ TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2(
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
    TFE_MonitoringSampler2* sampler, const char* label1, const char* label2);
// LINT.IfChange
// Note: Keep in sync with internal copy of enum in eager/context.h.
typedef enum TFE_ContextMirroringPolicy {
// Do not maintain mirrors in a TensorHandle, instead make new TensorHandle
// copies with their own lifetime.
TFE_MIRRORING_NONE = 0,
// Mirroring any remote tensor handles, associating them with the lifetime of
// the local TensorHandle.
TFE_MIRRORING_ALL = 1,
} TFE_ContextMirroringPolicy;
// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h)
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetMirroringPolicy(
TFE_ContextOptions*, TFE_ContextMirroringPolicy);
// Sets a thread-local mirroring policy. After this call, other calls to
// TFE_Execute in the same thread will use the mirroring policy specified here
// instead of the mirroring policy used to construct the context. This has no
// effect on the mirroring policy used by other program threads.
TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalMirroringPolicy(
TFE_Context*, TFE_ContextMirroringPolicy);
// Returns the mirroring policy to be used by this context in the current
// thread.
TF_CAPI_EXPORT extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
TFE_Context*);
// Sets whether to copy the remote inputs of a function lazily.
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy(
    TFE_ContextOptions*, bool lazy_copy);
@@ -441,7 +414,7 @@ typedef struct TFE_OpAttrs TFE_OpAttrs;

// Fetch a reference to `op`'s attributes. The returned reference is only valid
// while `op` is alive.
-const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op);
+TF_CAPI_EXPORT extern const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op);
// Add attributes in `attrs` to `op`.
//
// Does not overwrite or update existing attributes, but adds new ones.
@@ -462,7 +435,11 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
                                                   size_t proto_len,
                                                   TF_Status* status);

-#define TFE_CUSTOM_DEVICE_VERSION 2
+// TODO(b/166642410): It would be nice, for custom devices and for other users,
+// to have a non-string representation of devices (TF_Device) extracted from
+// tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc.
+#define TFE_CUSTOM_DEVICE_VERSION 3
// Struct to be filled in
typedef struct TFE_CustomDevice {
@@ -481,9 +458,16 @@ typedef struct TFE_CustomDevice {
                                     void* device_info);

  // Method to execute an operation.
-  void (*execute)(TFE_Context* context, int num_inputs,
-                  TFE_TensorHandle** inputs, const char* operation_name,
-                  const TFE_OpAttrs* attributes, int* num_outputs,
+  //
+  // Arguments provide enough information to reconstruct the original `TFE_Op`,
+  // or construct a transformed version, by inspecting the passed `op`.
+  //
+  // TFE_OpGetDevice(op) records the original placement of the operation. It may
+  // be an empty string if no device was explicitly requested, but will
+  // otherwise be the name of this custom device. Ops are placed onto a custom
+  // device if any of their inputs are on that custom device, but custom devices
+  // are free to set a bad status in order to require explicit placement.
+  void (*execute)(const TFE_Op* op, int* num_outputs,
                  TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
  // Method to delete a device.
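[Editorial note: to make the version-3 execute contract above concrete, here is a rough pass-through sketch, not part of this change. It rebuilds the op on an underlying device using only the accessors this diff introduces; the device string is an assumption, and error handling is abbreviated. The logging device further down in this diff follows the same pattern.]

void PassThroughExecute(const TFE_Op* original_op, int* num_outputs,
                        TFE_TensorHandle** outputs, TF_Status* s,
                        void* device_info) {
  // Assumed underlying device; a real custom device would carry this in
  // `device_info`.
  static const char* const kUnderlying =
      "/job:localhost/replica:0/task:0/device:CPU:0";
  TFE_Context* context = TFE_OpGetContext(original_op, s);
  if (TF_GetCode(s) != TF_OK) return;
  const char* operation_name = TFE_OpGetName(original_op, s);
  if (TF_GetCode(s) != TF_OK) return;
  // Recreate the op, copy its attributes, and pin it to the real device.
  TFE_Op* op = TFE_NewOp(context, operation_name, s);
  if (TF_GetCode(s) != TF_OK) return;
  TFE_OpAddAttrs(op, TFE_OpGetAttrs(original_op));
  TFE_OpSetDevice(op, kUnderlying, s);
  // Forward the flattened inputs one by one.
  int num_inputs = TFE_OpGetFlatInputCount(original_op, s);
  for (int j = 0; j < num_inputs && TF_GetCode(s) == TF_OK; ++j) {
    TFE_OpAddInput(op, TFE_OpGetFlatInput(original_op, j, s), s);
  }
  if (TF_GetCode(s) == TF_OK) TFE_Execute(op, outputs, num_outputs, s);
  TFE_DeleteOp(op);
}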
@@ -569,6 +553,14 @@ TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
                                                     unsigned char enable,
                                                     TF_Status* status);
// Returns the device type of the operation that produced `h`.
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType(
TFE_TensorHandle* h, TF_Status* status);
// Returns the device ID of the operation that produced `h`.
TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h,
TF_Status* status);
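[Editorial note: a minimal usage sketch for these two accessors, illustrative only; `h` stands for any live TFE_TensorHandle*.]

TF_Status* status = TF_NewStatus();
const char* device_type = TFE_TensorHandleDeviceType(h, status);  // e.g. "CPU"
int device_id = TFE_TensorHandleDeviceID(h, status);              // e.g. 0
if (TF_GetCode(status) != TF_OK) {
  // A null handle yields TF_INVALID_ARGUMENT, as the tests below verify.
}
TF_DeleteStatus(status);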
#ifdef __cplusplus
} /* end extern "C" */
#endif
@@ -316,86 +316,6 @@ TEST(CAPI, Function_ident_CPU) {
  TF_DeleteStatus(status);
}
#ifdef TENSORFLOW_EAGER_USE_XLA
TEST(CAPI, Function_ident_XLA_CPU) {
// First create a simple identity function.
TF_Graph* function_graph = TF_NewGraph();
TF_OperationDescription* arg_descr =
TF_NewOperation(function_graph, "Placeholder", "arg");
TF_SetAttrType(arg_descr, "dtype", TF_INT32);
TF_Status* status = TF_NewStatus();
TF_Operation* arg = TF_FinishOperation(arg_descr, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_OperationDescription* id_descr =
TF_NewOperation(function_graph, "Identity", "id");
TF_SetAttrType(id_descr, "T", TF_INT32);
TF_AddInput(id_descr, {arg, 0});
TF_Operation* id = TF_FinishOperation(id_descr, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_Output input{arg, 0};
TF_Output output{id, 0};
TF_Function* fn =
TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
&output, nullptr, nullptr, "test", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteGraph(function_graph);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextAddFunction(ctx, fn, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteFunction(fn);
for (bool async : {false, true, false}) {
TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
TFE_Executor* executor = TFE_NewExecutor(async);
TFE_ContextSetExecutorForThread(ctx, executor);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK);
TF_Tensor* t =
TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
*reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteTensor(t);
TFE_Op* op = TFE_NewOp(ctx, "ident", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_OpAddInput(op, h, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
// Now run it via XLA.
TFE_OpSetXLACompilation(op, true);
std::vector<TFE_TensorHandle*> result;
result.push_back(nullptr);
int num_retvals = 1;
TFE_Execute(op, result.data(), &num_retvals, status);
TFE_DeleteOp(op);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
ASSERT_EQ(num_retvals, 1);
TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
TFE_ContextSetExecutorForThread(ctx, old_executor);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
TFE_DeleteExecutor(old_executor);
TFE_DeleteTensorHandle(h);
TF_DeleteTensor(r);
TFE_DeleteTensorHandle(result[0]);
}
TFE_ContextRemoveFunction(ctx, "ident", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContext(ctx);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteStatus(status);
}
#endif // TENSORFLOW_EAGER_USE_XLA
void Executor_MatMul_CPU(bool async) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
@@ -491,5 +411,109 @@ TEST(CAPI, TensorHandleOnDeviceMemory) {
  TF_DeleteStatus(status);
}
TEST(CAPI, TensorHandleNullptr) {
TFE_TensorHandle* h = nullptr;
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
const char* device_type = TFE_TensorHandleDeviceType(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(device_type, nullptr);
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
TF_SetStatus(status.get(), TF_OK, "");
int device_id = TFE_TensorHandleDeviceID(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(device_id, -1);
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
}
TEST(CAPI, TensorHandleDevices) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
const char* device_type = TFE_TensorHandleDeviceType(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
int device_id = TFE_TensorHandleDeviceID(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(0, device_id) << device_id;
// Disable the test if no GPU is present.
string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_Op* shape_op = ShapeOp(ctx, hgpu);
TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
device_type = TFE_TensorHandleDeviceType(retvals[0], status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type;
device_id = TFE_TensorHandleDeviceID(retvals[0], status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(0, device_id) << device_id;
TFE_DeleteOp(shape_op);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(hgpu);
}
TFE_DeleteTensorHandle(hcpu);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
}
TEST(CAPI, TensorHandleDefaults) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* h_default = TestMatrixTensorHandle(ctx);
const char* device_type = TFE_TensorHandleDeviceType(h_default, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
int device_id = TFE_TensorHandleDeviceID(h_default, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(0, device_id) << device_id;
TFE_TensorHandle* h_cpu = TFE_TensorHandleCopyToDevice(
h_default, ctx, "/device:CPU:0", status.get());
const char* device_type_cpu = TFE_TensorHandleDeviceType(h_cpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_type_cpu, "CPU")) << device_type_cpu;
int device_id_cpu = TFE_TensorHandleDeviceID(h_cpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(0, device_id_cpu) << device_id_cpu;
TFE_DeleteTensorHandle(h_default);
TFE_DeleteTensorHandle(h_cpu);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
}
}  // namespace
}  // namespace tensorflow
@@ -32,7 +32,6 @@ struct TFE_ContextOptions {
  bool async = false;
  TFE_ContextDevicePlacementPolicy device_placement_policy{
      TFE_DEVICE_PLACEMENT_SILENT};
TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE};
  // If true, lazily copy the remote inputs of a function to the target devices.
  bool lazy_remote_inputs_copy = true;
  // If true, use TFRT backend
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>

// clang-format off
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/platform.h"
// clang-format on // clang-format on
@@ -876,89 +877,6 @@ TEST(CAPI, Execute_Min_CPU) {
  TF_DeleteStatus(status);
}
#ifdef TENSORFLOW_EAGER_USE_XLA
void Execute_MatMul_XLA_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_OpSetXLACompilation(matmul, true);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
// Running a primitive TF operator via XLA is not yet supported.
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
TFE_DeleteTensorHandle(retvals[0]);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); }
TEST(CAPI, Execute_MatMul_XLA_CPUAsync) { Execute_MatMul_XLA_CPU(true); }
void Execute_Min_XLA_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
TFE_Op* minOp = MinOp(ctx, input, axis);
TFE_OpSetXLACompilation(minOp, true);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(minOp, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(minOp);
TFE_DeleteTensorHandle(input);
TFE_DeleteTensorHandle(axis);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
TFE_DeleteTensorHandle(retvals[0]);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float output[2] = {0};
EXPECT_EQ(sizeof(output), TF_TensorByteSize(t));
memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(1, output[0]);
EXPECT_EQ(3, output[1]);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); }
TEST(CAPI, Execute_Min_XLA_CPUAsync) { Execute_Min_XLA_CPU(true); }
#endif // TENSORFLOW_EAGER_USE_XLA
void ExecuteWithTracing(bool async) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
@@ -1274,6 +1192,68 @@ TEST(CAPI, StringAttributes) {
  TF_DeleteStatus(status);
}
// Same test as above, except it uses SetOpAttrValueScalar to set attrs.
TEST(CAPI, TestTFE_SetOpAttrs) {
// Test that TFE_OpSetAttrString doesn't hold on to the value after it
// returns.
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
std::vector<int64_t> dims(4, 1);
TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Tensor* tensor =
TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float));
float tensor_data[] = {1};
memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor));
TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, tensor_handle, status);
TF_DeleteTensor(tensor);
TFE_DeleteTensorHandle(tensor_handle);
tensorflow::AttrValue i_list_values;
for (int i = 0; i < 4; ++i) {
i_list_values.mutable_list()->add_i(1);
}
SetOpAttrValueScalar(ctx, op, i_list_values, "ksize", status);
SetOpAttrValueScalar(ctx, op, i_list_values, "strides", status);
tensorflow::AttrValue padding_value;
*padding_value.mutable_s() = "VALID";
tensorflow::SetOpAttrValueScalar(ctx, op, padding_value, "padding", status);
tensorflow::AttrValue data_format_value;
*data_format_value.mutable_s() = "NHWC";
tensorflow::SetOpAttrValueScalar(ctx, op, data_format_value, "data_format",
status);
TFE_OpSetAttrType(op, "T", TF_FLOAT);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(op, &retvals[0], &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
tensor = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(4, TF_TensorByteSize(tensor));
TF_DeleteTensor(tensor);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteOp(op);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
@@ -1620,4 +1600,91 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
  TFE_DeleteContext(ctx);
}
// Needs to work with a const TFE_Op since custom devices should not modify the
// op they are called with.
TFE_Op* CloneOp(const TFE_Op* other) {
TF_Status* status = TF_NewStatus();
TFE_Context* context = TFE_OpGetContext(other, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char* op_name = TFE_OpGetName(other, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* ret = TFE_NewOp(context, op_name, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char* device = TFE_OpGetDevice(other, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetDevice(ret, device, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddAttrs(ret, TFE_OpGetAttrs(other));
int num_inputs = TFE_OpGetFlatInputCount(other, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
for (int input_index = 0; input_index < num_inputs; ++input_index) {
TFE_TensorHandle* input = TFE_OpGetFlatInput(other, input_index, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(ret, input, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
}
TF_DeleteStatus(status);
return ret;
}
TEST(CAPI, TestTFE_OpRecreation) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
// Clone an op with attributes and a device set.
TFE_Op* original_var_op = TFE_NewOp(ctx, "VarHandleOp", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(original_var_op, "dtype", TF_INT64);
TFE_OpSetAttrShape(original_var_op, "shape", {}, 0, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ("", std::string(TFE_OpGetDevice(original_var_op, status)));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetDevice(original_var_op,
"/job:localhost/replica:0/task:0/device:CPU:0", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* cloned = CloneOp(original_var_op);
EXPECT_EQ("/job:localhost/replica:0/task:0/device:CPU:0",
std::string(TFE_OpGetDevice(cloned, status)));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ("VarHandleOp", std::string(TFE_OpGetName(cloned, status)));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
int num_retvals = 1;
TFE_TensorHandle* ret;
TFE_Execute(cloned, &ret, &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(ret);
// Clone an op with inputs and no device set.
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
TFE_Op* original_identity = TFE_NewOp(ctx, "IdentityN", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* inputs[] = {input1, input2};
TFE_OpAddInputList(original_identity, inputs, 2, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* cloned_identity = CloneOp(original_identity);
EXPECT_EQ("", std::string(TFE_OpGetDevice(cloned_identity, status)));
TFE_TensorHandle* identity_ret[] = {nullptr, nullptr};
num_retvals = 2;
TFE_Execute(cloned_identity, identity_ret, &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(input1);
TFE_DeleteTensorHandle(input2);
TFE_DeleteTensorHandle(identity_ret[0]);
TFE_DeleteTensorHandle(identity_ret[1]);
TFE_DeleteOp(cloned_identity);
TFE_DeleteOp(original_identity);
TFE_DeleteOp(original_var_op);
TFE_DeleteOp(cloned);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
}
}  // namespace
@@ -17,12 +17,16 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/cluster.pb.h"
using tensorflow::string; using tensorflow::string;
using tensorflow::tstring;
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) {
  float data[] = {value};
@@ -36,6 +40,19 @@ TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) {
  return th;
}
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx,
const tensorflow::tstring& value) {
TF_Status* status = TF_NewStatus();
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_STRING, nullptr, 0, status);
tstring* data = static_cast<tstring*>(TF_TensorData(t));
*data = value;
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value) {
  int data[] = {value};
  TF_Status* status = TF_NewStatus();
@@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
@@ -28,6 +29,10 @@ TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value);
// Return a tensor handle containing a bool scalar
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value);
// Return a tensor handle containing a tstring scalar
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx,
const tensorflow::tstring& value);
// Return a tensor handle containing a 2x2 matrix of doubles
TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx);
@@ -39,7 +39,7 @@ static FactoriesMap& GetFactories() {
  return *factories;
}

-static const char* default_factory = "<unset>";
+static tracing::FactoryFunction default_factory;
void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
  assert((!GetFactories().count(name)) ||
@@ -48,15 +48,15 @@ void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
  GetFactories()[name] = factory;
}
-void SetDefaultTracingEngine(const char* name) { default_factory = name; }
-
-static TracingContext* CreateTracingExecutionContext(const char* fn_name,
-                                                     TF_Status* s) {
-  auto entry = GetFactories().find(default_factory);
-  if (entry != GetFactories().end()) return entry->second(fn_name, s);
+Status SetDefaultTracingEngine(const char* name) {
+  auto entry = GetFactories().find(name);
+  if (entry != GetFactories().end()) {
+    default_factory = GetFactories().find(name)->second;
+    return Status::OK();
+  }
  string msg = absl::StrCat(
-      "No tracing engine factory has been registered with the key '",
-      default_factory, "' (available: ");
+      "No tracing engine factory has been registered with the key '", name,
+      "' (available: ");
  // Ensure deterministic (sorted) order in the error message
  std::set<string> factories_sorted;
  for (const auto& factory : GetFactories())

@@ -68,7 +68,16 @@ static TracingContext* CreateTracingExecutionContext(const char* fn_name,
  }
  msg += ")";
-  TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
+  return errors::InvalidArgument(msg.c_str());
+}
+
+static TracingContext* CreateTracingExecutionContext(const char* fn_name,
+                                                     TF_Status* s) {
+  if (default_factory) {
+    return default_factory(fn_name, s);
+  }
+  Set_TF_Status_from_Status(
+      s, errors::FailedPrecondition("default_factory is nullptr"));
  return nullptr;
}
@@ -99,8 +108,8 @@ using tensorflow::tracing::TracingContext;
using tensorflow::tracing::TracingOperation;
using tensorflow::tracing::TracingTensorHandle;
-void TF_SetTracingImplementation(const char* name) {
-  SetDefaultTracingEngine(name);
+void TF_SetTracingImplementation(const char* name, TF_Status* s) {
+  Set_TF_Status_from_Status(s, SetDefaultTracingEngine(name));
}
// Creates a new TensorFlow function, it is an execution context attached to a
@@ -52,7 +52,7 @@ typedef struct TF_AbstractFunction TF_AbstractFunction;
// This allows the client to swap the implementation of the tracing engine.
// Any future call to TF_CreateFunction will use the implementation defined
// here.
-void TF_SetTracingImplementation(const char* name);
+void TF_SetTracingImplementation(const char* name, TF_Status*);
// Creates a new TensorFlow function. A Function is an execution context, and as
// such it can trace operations through TF_ExecuteOperation. After completing
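[Editorial note: callers now check the returned status instead of failing silently. A minimal usage sketch, assuming the "graphdef" engine registered later in this diff:]

TF_Status* s = TF_NewStatus();
TF_SetTracingImplementation("graphdef", s);
if (TF_GetCode(s) != TF_OK) {
  // No factory named "graphdef" was registered; the error message lists the
  // available factories in sorted order.
}
TF_DeleteStatus(s);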
@@ -365,9 +365,10 @@ class GraphContext : public TracingContext {
    }
    auto s = TF_NewStatus();
-    func->func = TF_GraphToFunction(
-        graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
-        graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
+    func->func = TF_GraphToFunction(graph_.get(), name_.data(), 0, -1, nullptr,
+                                    inputs_.size(), inputs_.data(),
+                                    graph_outputs.size(), graph_outputs.data(),
+                                    nullptr, nullptr, name_.data(), s);
    TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
    TF_DeleteStatus(s);
    *f = func.release();
@@ -391,7 +392,7 @@ class GraphContext : public TracingContext {
 private:
  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
  std::vector<TF_Output> inputs_;
-  const char* name_;
+  string name_;
};
static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
@@ -401,7 +402,7 @@ static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
// Register the tracing implemented in this file as the default tracing engine.
static bool register_tracing = [] {
  RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
-  SetDefaultTracingEngine("graphdef");
+  SetDefaultTracingEngine("graphdef").IgnoreError();
  return true;
}();
@@ -120,7 +120,7 @@ class TracingContext : public AbstractContext {
};

typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
-void SetDefaultTracingEngine(const char* name);
+Status SetDefaultTracingEngine(const char* name);
void RegisterTracingEngineFactory(const ::tensorflow::string& name,
                                  FactoryFunction factory);
}  // namespace tracing
@@ -22,10 +22,15 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
using tensorflow::Status;
using tensorflow::string;
using tensorflow::TF_StatusPtr;
namespace tensorflow {
namespace {
@@ -37,7 +42,10 @@ class UnifiedCAPI
    : public ::testing::TestWithParam<std::tuple<const char*, bool>> {
 protected:
  void SetUp() override {
-    TF_SetTracingImplementation(std::get<0>(GetParam()));
+    TF_StatusPtr status(TF_NewStatus());
+    TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
+    Status s = StatusFromTF_Status(status.get());
+    CHECK_EQ(errors::OK, s.code()) << s.error_message();
  }
};
@@ -36,7 +36,8 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
  bool arrived = false;
  bool executed = false;
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
+  RegisterLoggingDevice(context, name, /*strict_scope_placement=*/true,
+                        &arrived, &executed, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_TensorHandle* hcpu = TestMatrixTensorHandle(context);
  ASSERT_FALSE(arrived);
@@ -73,7 +74,8 @@ TEST(CUSTOM_DEVICE, ResetOperation) {
  bool executed = false;
  const char* custom_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed,
-                        status.get());
+  RegisterLoggingDevice(context.get(), custom_device_name,
+                        /*strict_scope_placement=*/true, &arrived, &executed,
+                        status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
@@ -103,7 +105,8 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
  bool arrived = false;
  bool executed = false;
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
+  RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
+                        &arrived, &executed, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a variable handle placed on the custom device.
@@ -187,7 +190,8 @@ TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) {
  bool arrived = false;
  bool executed = false;
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
+  RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/false,
+                        &arrived, &executed, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a variable handle placed on the custom device.
@@ -264,10 +268,12 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) {
  const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
  bool arrived = false;
  bool executed = false;
-  RegisterLoggingDevice(context.get(), custom0, &arrived, &executed,
-                        status.get());
+  RegisterLoggingDevice(context.get(), custom0,
+                        /*strict_scope_placement=*/false, &arrived, &executed,
+                        status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-  RegisterLoggingDevice(context.get(), custom1, &arrived, &executed,
-                        status.get());
+  RegisterLoggingDevice(context.get(), custom1,
+                        /*strict_scope_placement=*/true, &arrived, &executed,
+                        status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
@@ -314,14 +320,34 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) {
  ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
  ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));
-  // Custom device: mix of custom/physical fails.
+  // Custom device: mix of custom/physical places the op on the custom device.
  matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
  num_retvals = 1;
+  executed = false;
  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
-  ASSERT_NE(TF_OK, TF_GetCode(status.get()));
-  ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
-  ASSERT_TRUE(
-      absl::StrContains(TF_Message(status.get()), "[]"));  // kVariantDeviceNull
+  EXPECT_TRUE(executed);
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  TFE_DeleteTensorHandle(retval);
+
+  // Explicit placement still forces the op onto the requested device
+  matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
+  TFE_OpSetDevice(matmul.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
+                  status.get());
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  num_retvals = 1;
+  executed = false;
+  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
+  EXPECT_FALSE(executed);
+  ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
+
+  // Custom devices can refuse to do type-based dispatch (as hcustom1 is
+  // configured to do)
+  matmul.reset(MatMulOp(context.get(), hcustom1.get(), hcpu.get()));
+  num_retvals = 1;
+  executed = false;
+  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
+  EXPECT_FALSE(executed);
+  ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
}
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
@@ -334,21 +360,24 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  bool arrived = false;
  bool executed = false;
RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", &arrived, &executed, RegisterLoggingDevice(context.get(), "/device:CUSTOM:0",
/*strict_scope_placement=*/true, &arrived, &executed,
status.get()); status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT) ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get()); << TF_Message(status.get());
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get()); RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get()); RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
RegisterLoggingDevice(context.get(),
"/job:localhost/replica:0/task:0/device:CPU:0",
&arrived, &executed, status.get()); &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS) ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get()); << TF_Message(status.get());
RegisterLoggingDevice(
context.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
/*strict_scope_placement=*/true, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
} }
@@ -33,6 +33,9 @@ struct LoggingDevice {
  bool* arrived_flag;
  // Set to true whenever an operation is executed
  bool* executed_flag;
// If true, only explicit op placements are accepted. If false, uses
// type-based dispatch.
bool strict_scope_placement;
};

struct LoggedTensor {
@@ -84,18 +87,35 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
  return nullptr;
}
-void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
-                          TFE_TensorHandle** inputs, const char* operation_name,
-                          const TFE_OpAttrs* attributes, int* num_outputs,
+void LoggingDeviceExecute(const TFE_Op* original_op, int* num_outputs,
                          TFE_TensorHandle** outputs, TF_Status* s,
                          void* device_info) {
+  const char* requested_placement = TFE_OpGetDevice(original_op, s);
+  if (TF_GetCode(s) != TF_OK) return;
+
  LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
+  if (dev->strict_scope_placement && *requested_placement == '\0') {
+    TF_SetStatus(s, TF_INTERNAL,
+                 "Ops must be placed on the device explicitly, or their inputs "
+                 "first copied to other devices.");
+    return;
+  }
+  TFE_Context* context = TFE_OpGetContext(original_op, s);
+  if (TF_GetCode(s) != TF_OK) return;
+  const char* operation_name = TFE_OpGetName(original_op, s);
+  if (TF_GetCode(s) != TF_OK) return;
+  const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
+
  TFE_Op* op(TFE_NewOp(context, operation_name, s));
  if (TF_GetCode(s) != TF_OK) return;
  TFE_OpAddAttrs(op, attributes);
  TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
+  if (TF_GetCode(s) != TF_OK) return;
+  int num_inputs = TFE_OpGetFlatInputCount(original_op, s);
+  if (TF_GetCode(s) != TF_OK) return;
  for (int j = 0; j < num_inputs; ++j) {
-    TFE_TensorHandle* input = inputs[j];
+    TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, j, s);
+    if (TF_GetCode(s) != TF_OK) return;
    const char* input_device = TFE_TensorHandleDeviceName(input, s);
    if (TF_GetCode(s) != TF_OK) return;
    if (dev->device_name == input_device) {
@@ -131,8 +151,8 @@ void DeleteLoggingDevice(void* device_info) {
}  // namespace

void RegisterLoggingDevice(TFE_Context* context, const char* name,
-                           bool* arrived_flag, bool* executed_flag,
-                           TF_Status* status) {
+                           bool strict_scope_placement, bool* arrived_flag,
+                           bool* executed_flag, TF_Status* status) {
  TFE_CustomDevice custom_device;
  custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
  custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
@@ -143,6 +163,7 @@ void RegisterLoggingDevice(TFE_Context* context, const char* name,
  device->executed_flag = executed_flag;
  device->device_name = name;
  device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
device->strict_scope_placement = strict_scope_placement;
  TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}
@@ -168,5 +189,6 @@ void AllocateLoggingDevice(const char* name, bool* arrived_flag,
  logging_device->device_name = name;
  logging_device->underlying_device =
      "/job:localhost/replica:0/task:0/device:CPU:0";
logging_device->strict_scope_placement = true;
  *device_info = reinterpret_cast<void*>(logging_device);
}
@@ -25,8 +25,8 @@ limitations under the License.
#include "tensorflow/c/tf_status.h"

void RegisterLoggingDevice(TFE_Context* context, const char* name,
-                           bool* arrived_flag, bool* executed_flag,
-                           TF_Status* status);
+                           bool strict_scope_placement, bool* arrived_flag,
+                           bool* executed_flag, TF_Status* status);
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
                           bool* executed_flag, TFE_CustomDevice** device,
                           void** device_info);
@@ -109,7 +109,8 @@ DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {

// Gets DLPack's DLContext from eager tensor handle.
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
  DLContext ctx;
-  const char* device_name = tensorflow::unwrap(h)->DeviceName(&status->status);
+  const char* device_name =
+      tensorflow::unwrap(h)->BackingDeviceName(&status->status);
  DeviceNameUtils::ParsedName parsed_name;
  tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
  std::string device_type = parsed_name.type;
@@ -248,21 +249,36 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
}

void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
+  auto tf_dlm_context = GetDlContext(h, status);
+  if (!status->status.ok()) {
+    return nullptr;
+  }
+
+  auto* tf_dlm_data = TFE_TensorHandleDevicePointer(h, status);
+  if (!status->status.ok()) {
+    return nullptr;
+  }
+
  const Tensor* tensor = GetTensorFromHandle(h, status);
  TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
-  TensorReference tensor_ref(*tensor);  // This will call buf_->Ref()
+
+  auto tf_dlm_type = GetDlDataType(data_type, status);
+  if (!status->status.ok()) {
+    return nullptr;
+  }
+
+  TensorReference tensor_ref(*tensor);  // This will call buf_->Ref()
  auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
  tf_dlm_tensor_ctx->reference = tensor_ref;

  DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
  dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
  dlm_tensor->deleter = &DLManagedTensorDeleter;
-  dlm_tensor->dl_tensor.ctx = GetDlContext(h, status);
+  dlm_tensor->dl_tensor.ctx = tf_dlm_context;
  int ndim = tensor->dims();
  dlm_tensor->dl_tensor.ndim = ndim;
-  dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
-  dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status);
+  dlm_tensor->dl_tensor.data = tf_dlm_data;
+  dlm_tensor->dl_tensor.dtype = tf_dlm_type;

  std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
  std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
@@ -275,13 +291,14 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
    (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
  }

-  dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
+  dlm_tensor->dl_tensor.shape = shape_arr->data();
  // There are two ways to represent compact row-major data
  // 1) nullptr indicates tensor is compact and row-majored.
  // 2) fill in the strides array as the real case for compact row-major data.
  // Here we choose option 2, since some frameworks didn't handle the strides
  // argument properly.
-  dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
+  dlm_tensor->dl_tensor.strides = stride_arr->data();
  dlm_tensor->dl_tensor.byte_offset =
      0;  // TF doesn't handle the strides and byte_offsets here
  return static_cast<void*>(dlm_tensor);
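[Editorial note: a worked example of the stride computation above, a restatement rather than part of the committed file. The loop fills strides back to front; for a tensor of shape {2, 3, 4}:]

  // stride[2] = 1
  // stride[1] = shape[2] * stride[2] = 4
  // stride[0] = shape[1] * stride[1] = 12

[So element (i, j, k) lives at flat offset i*12 + j*4 + k, the usual compact row-major layout that option 2 spells out explicitly.]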
@@ -0,0 +1,201 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradient_checker.h"
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace gradients {
using namespace std;
// ================== Helper functions =================
// Fills data with values [start,end) with given step size.
void Range(vector<int>* data, int start, int end, int step = 1) {
for (int i = start; i < end; i += step) {
(*data)[i] = i;
}
}
// Returns AbstractTensorHandlePtr containing [0, ..., n-1].
AbstractTensorHandlePtr GetRangeTensorHandleUtil(AbstractContext* ctx, int n) {
vector<int> vals(n);
int64_t vals_shape[] = {n};
Range(&vals, 0, n);
AbstractTensorHandlePtr r =
GetTensorHandleUtilInt(ctx, vals.data(), vals_shape, 1);
return r;
}
// Fills out_dims with the dimensions of the given tensor.
void GetDims(const TF_Tensor* t, int64_t* out_dims) {
int num_dims = TF_NumDims(t);
for (int i = 0; i < num_dims; i++) {
out_dims[i] = TF_Dim(t, i);
}
}
// Runs model as is if output is a scalar,
// else sums the output tensor before returning.
Status RunAndMaybeSum(AbstractContext* ctx, Model forward,
absl::Span<AbstractTensorHandle*> inputs,
absl::Span<AbstractTensorHandle*> outputs,
bool use_function) {
GradientRegistry registry;
std::vector<AbstractTensorHandle*> model_outputs(1);
// Run the model.
TF_RETURN_IF_ERROR(RunModel(forward, ctx, inputs,
absl::MakeSpan(model_outputs), use_function,
registry));
AbstractTensorHandle* model_out = model_outputs[0];
TF_Tensor* model_out_tensor;
TF_RETURN_IF_ERROR(GetValue(model_out, &model_out_tensor));
int num_dims_out = TF_NumDims(model_out_tensor);
// If the output is a scalar, then return the scalar output
if (num_dims_out == 0) {
outputs[0] = model_out;
return Status::OK();
}
// Else, reduce sum the output to get a scalar
// Will sum all dimensions, so get a Tensor containing [0,...,num_dims_out-1].
AbstractTensorHandlePtr sum_dims =
GetRangeTensorHandleUtil(ctx, num_dims_out);
// Reduce sum the output on all dimensions.
std::vector<AbstractTensorHandle*> sum_inputs(2);
sum_inputs[0] = model_out;
sum_inputs[1] = sum_dims.get();
TF_RETURN_IF_ERROR(ops::Sum(ctx, absl::MakeSpan(sum_inputs),
absl::MakeSpan(model_outputs), "sum_output"));
outputs[0] = model_outputs[0];
return Status::OK();
}
// ========================= End Helper Functions==============================
Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
absl::Span<AbstractTensorHandle*> inputs,
int input_index, bool use_function,
AbstractTensorHandle** numerical_grad) {
AbstractTensorHandle* theta =
inputs[input_index]; // parameter we are grad checking
// Convert from AbstractTensor to TF_Tensor.
TF_Tensor* theta_tensor;
TF_RETURN_IF_ERROR(GetValue(theta, &theta_tensor));
// Get number of elements and fill data.
int num_elems = TF_TensorElementCount(theta_tensor);
vector<float> theta_data(num_elems);
memcpy(theta_data.data(), TF_TensorData(theta_tensor),
TF_TensorByteSize(theta_tensor));
// Initialize space for the numerical gradient.
vector<float> dtheta_approx(num_elems);
// Get theta shape and store in theta_dims.
int num_dims = TF_NumDims(theta_tensor);
vector<int64_t> theta_dims(num_dims);
GetDims(theta_tensor, theta_dims.data());
// Initialize auxiliary data structures.
vector<float> thetaPlus_data(num_elems);
vector<float> thetaMinus_data(num_elems);
std::vector<AbstractTensorHandle*> f_outputs(1);
// Numerical Grad Check
for (int i = 0; i < num_elems; i++) {
// Get a relative epsilon value: scaling by |theta_data[i]| keeps the
// perturbation proportional to the parameter's magnitude, and the additive
// 1e-4 keeps epsilon nonzero when theta_data[i] == 0 (preventing a zero
// denominator in the difference quotient below).
float epsilon = std::abs(theta_data[i] * 1e-4 + 1e-4);
AbstractTensorHandlePtr two_eps =
GetScalarTensorHandleUtil(ctx, 2 * epsilon);
// Initialize theta[i] + epsilon.
memcpy(thetaPlus_data.data(), TF_TensorData(theta_tensor),
TF_TensorByteSize(theta_tensor));
thetaPlus_data[i] += epsilon;
AbstractTensorHandlePtr thetaPlus = GetTensorHandleUtilFloat(
ctx, thetaPlus_data.data(), theta_dims.data(), num_dims);
// Initialize theta[i] - epsilon.
memcpy(&thetaMinus_data[0], TF_TensorData(theta_tensor),
TF_TensorByteSize(theta_tensor));
thetaMinus_data[i] -= epsilon;
AbstractTensorHandlePtr thetaMinus = GetTensorHandleUtilFloat(
ctx, thetaMinus_data.data(), theta_dims.data(), num_dims);
// Get f(theta + eps):
inputs[input_index] = thetaPlus.get();
TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, inputs,
absl::MakeSpan(f_outputs), use_function));
AbstractTensorHandle* fPlus = f_outputs[0];
// Get f(theta - eps):
inputs[input_index] = thetaMinus.get();
TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, inputs,
absl::MakeSpan(f_outputs), use_function));
AbstractTensorHandle* fMinus = f_outputs[0];
// Take Difference of both estimates: (f(theta + eps) - f(theta - eps)).
TF_RETURN_IF_ERROR(
ops::Sub(ctx, {fPlus, fMinus}, absl::MakeSpan(f_outputs), "sub_top"));
AbstractTensorHandle* fDiff = f_outputs[0];
// Calculate using the difference quotient definition:
// (f(theta + eps) - f(theta - eps)) / (2 * eps).
TF_RETURN_IF_ERROR(ops::DivNoNan(ctx, {fDiff, two_eps.get()},
absl::MakeSpan(f_outputs),
"diff_quotient"));
AbstractTensorHandle* diff_quotient = f_outputs[0];
TF_Tensor* grad_tensor;
TF_RETURN_IF_ERROR(GetValue(diff_quotient, &grad_tensor));
float grad_data[1];
memcpy(&grad_data[0], TF_TensorData(grad_tensor),
TF_TensorByteSize(grad_tensor));
dtheta_approx[i] = grad_data[0];
}
// Populate *numerical_grad with the data from dtheta_approx.
TF_RETURN_IF_ERROR(TensorHandleWithDimsFloat(
ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad));
return Status::OK();
}
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace gradients {
/* Returns the numerical gradient of the `forward` model with respect to the
* parameter specified by `input_index`.
*
* I.e. if y = <output of the forward model> and w = inputs[input_index],
* this will calculate dy/dw numerically.
*
* `use_function` indicates whether to use graph mode (true) or eager mode
* (false).
*
* `numerical_grad` is the pointer to the AbstractTensorHandle* which will
* hold the numerical gradient data at the end of the function.
*/
Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
absl::Span<AbstractTensorHandle*> inputs,
int input_index, bool use_function,
AbstractTensorHandle** numerical_grad);
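//
// For each element theta[i], the checker evaluates the central difference
// quotient
//
//   df/dtheta[i] ~= (f(theta + eps * e_i) - f(theta - eps * e_i)) / (2 * eps),
//
// whose truncation error is O(eps^2).
//
// Minimal usage sketch (mirroring the tests; `MulModel`, `ctx`, `x`, and `y`
// come from the accompanying test utilities and are illustrative only):
//
//   std::vector<AbstractTensorHandle*> inputs = {x.get(), y.get()};
//   AbstractTensorHandle* numerical_grad = nullptr;
//   Status s = CalcNumericalGrad(ctx.get(), MulModel, absl::MakeSpan(inputs),
//                                /*input_index=*/0, /*use_function=*/false,
//                                &numerical_grad);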
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,265 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradient_checker.h"
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
class GradientCheckerTest
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
void SetUp() override {
TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
}
};
Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
TF_RETURN_IF_ERROR(
registry->Register("SparseSoftmaxCrossEntropyWithLogits",
SparseSoftmaxCrossEntropyWithLogitsRegisterer));
return Status::OK();
}
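// Note: registered gradients are only needed on the analytical path
// (SoftmaxLossGradModel below); CalcNumericalGrad only runs the forward
// model and internally uses an empty GradientRegistry.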
TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t A_dims[] = {2, 2};
float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f};
int64_t B_dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr A =
GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
AbstractTensorHandlePtr B =
GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);
std::vector<AbstractTensorHandle*> inputs;
inputs.push_back(A.get());
inputs.push_back(B.get());
AbstractTensorHandle* grad_approx;
Status s = CalcNumericalGrad(
ctx.get(), MatMulModel, absl::MakeSpan(inputs), /*input_index=*/0,
/*use_function=*/!std::get<2>(GetParam()), &grad_approx);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* gt;
s = GetValue(grad_approx, &gt);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt));
float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f};
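// These values follow from f(A) = sum(A x B), for which df/dA = ones x B^T
// (x denoting matrix multiplication), so every row of the gradient equals
// [.5 - 1.0, 1.0 + 1.0] = [-.5, 2.0].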
float tolerance = 1e-2;
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(expected_dA[j], result_data[j], tolerance);
}
TF_DeleteTensor(gt);
}
TEST_P(GradientCheckerTest, TestGradCheckMul) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = ScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
Status s = ScalarTensorHandle(ctx.get(), 7.0f, &y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
y.reset(y_raw);
}
// Will perform z = x*y.
// dz/dx = y
std::vector<AbstractTensorHandle*> inputs;
inputs.push_back(x.get());
inputs.push_back(y.get());
AbstractTensorHandle* g;
Status s = CalcNumericalGrad(ctx.get(), MulModel, absl::MakeSpan(inputs),
/*input_index=*/0,
/*use_function=*/!std::get<2>(GetParam()), &g);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* gt;
s = GetValue(g, &gt);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[1] = {0};
memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt));
ASSERT_NEAR(result_data[0], 7.0f, /*abs_error=*/1e-2);
TF_DeleteTensor(gt);
}
TEST_P(GradientCheckerTest, TestGradCheckSoftmax) {
bool use_function = !std::get<2>(GetParam());
if (use_function) {
// TODO(b/168850692): Enable this.
GTEST_SKIP() << "Can't take gradient of "
"SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
}
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
/** Test showing how to use this API with analytical gradients:
*
* We have `SoftmaxLossGradModel`, which is a wrapper for the
* Softmax analytical gradient found in c/experimental/gradients/nn_grad.
*
* We use the GradientChecker by applying finite differences to the
* forward pass wrapped in `SoftmaxModel` and verify that the analytical
* and numerical gradients are relatively close.
*/
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = scores
float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, 1.0f};
int64_t X_dims[] = {3, 3};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
// y = labels
int y_vals[] = {1, 0, 1};
int64_t y_dims[] = {3};
num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
AbstractTensorHandlePtr y =
GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
std::vector<AbstractTensorHandle*> inputs;
inputs.push_back(X.get());
inputs.push_back(y.get());
// Run analytical gradient and get its data.
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(SoftmaxLossGradModel, ctx.get(), absl::MakeSpan(inputs),
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* dX_tensor;
s = GetValue(outputs[0], &dX_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float danalytical[9] = {0}; // Contains data from analytical gradient.
memcpy(&danalytical[0], TF_TensorData(dX_tensor),
TF_TensorByteSize(dX_tensor));
// Run numerical gradient approximation using the GradientChecker API.
AbstractTensorHandle* g; // Will contain numerical approximation data.
s = CalcNumericalGrad(ctx.get(), SoftmaxModel, absl::MakeSpan(inputs),
/*input_index=*/0,
/*use_function=*/!std::get<2>(GetParam()), &g);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* gt;
s = GetValue(g, &gt);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float dnumerical[9] = {0};
memcpy(&dnumerical[0], TF_TensorData(gt), TF_TensorByteSize(gt));
// Now compare the two implementations:
for (int j = 0; j < 9; j++) {
ASSERT_NEAR(dnumerical[j], danalytical[j], /*abs_error=*/1e-2);
}
// Only Unref() the first output; the second is the nullptr gradient for the labels.
outputs[0]->Unref();
TF_DeleteTensor(dX_tensor);
TF_DeleteTensor(gt);
}
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, GradientCheckerTest,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, GradientCheckerTest,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#endif
} // namespace
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@ -122,14 +122,12 @@ int64 ToId(AbstractTensorHandle* t) {
return static_cast<int64>(reinterpret_cast<uintptr_t>(t));
}
- TapeTensor::TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx)
-     : handle_(handle), ctx_(ctx) {
TapeTensor::TapeTensor(AbstractTensorHandle* handle) : handle_(handle) {
handle_->Ref();
}
TapeTensor::TapeTensor(const TapeTensor& other) {
handle_ = other.handle_;
handle_->Ref();
- ctx_ = other.ctx_;
}
TapeTensor::~TapeTensor() { handle_->Unref(); }
@ -138,33 +136,7 @@ tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); }
tensorflow::DataType TapeTensor::GetDType() const {
return handle_->DataType();
}
AbstractTensorHandle* TapeTensor::GetHandle() const { return handle_; }
- AbstractTensorHandle* TapeTensor::OnesLike() const {
- AbstractOperationPtr op(ctx_->CreateOperation());
- Status s = op->Reset("OnesLike", /*raw_device_name=*/nullptr);
- if (!s.ok()) {
- return nullptr;
- }
- if (isa<tracing::TracingOperation>(op.get())) {
- s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
- absl::StrCat("OnesLike", ToId(handle_)).c_str());
- if (!s.ok()) {
- return nullptr;
- }
- }
- s = op->AddInput(handle_);
- if (!s.ok()) {
- return nullptr;
- }
- int num_outputs = 1;
- // TODO(srbs): Figure out who is in charge of releasing this.
- std::vector<AbstractTensorHandle*> outputs(num_outputs);
- s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
- if (!s.ok()) {
- return nullptr;
- }
- return outputs[0];
- }
AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; }
@ -219,6 +191,23 @@ Status TapeVSpace::CallBackwardFunction(
&ctx, incoming_gradients, result);
}
Status TapeVSpace::BuildOnesLike(const TapeTensor& t,
AbstractTensorHandle** result) const {
AbstractOperationPtr op(ctx_->CreateOperation());
TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
absl::StrCat("OnesLike", ToId(t.GetHandle())).c_str()));
}
TF_RETURN_IF_ERROR(op->AddInput(t.GetHandle()));
int num_outputs = 1;
std::vector<AbstractTensorHandle*> outputs(num_outputs);
TF_RETURN_IF_ERROR(
op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
*result = outputs[0];
return Status::OK();
}
// Looks up the ID of a Gradient.
int64 TapeVSpace::TensorId(AbstractTensorHandle* tensor) const {
return ToId(tensor);
@ -226,7 +215,7 @@ int64 TapeVSpace::TensorId(AbstractTensorHandle* tensor) const {
// Converts a Gradient to a TapeTensor.
TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const {
- return TapeTensor(g, ctx_);
return TapeTensor(g);
}
void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {}
@ -242,6 +231,7 @@ namespace internal {
Status Reset(AbstractOperation* op_, const char* op,
const char* raw_device_name, ForwardOperation* forward_op_) {
forward_op_->op_name = op;
forward_op_->attrs.Reset(op);
return op_->Reset(op, raw_device_name);
}
Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input,
@ -418,9 +408,14 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx,
// TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
forward_op_->outputs.push_back(retvals[i]);
}
// TODO(b/166669239): This is needed to support AttrBuilder::Get for string
// attributes. Number type attrs and DataType attrs work fine without this.
// Consider getting rid of this and making the behavior between number types
// and string consistent.
forward_op_->attrs.BuildNodeDef();
std::vector<TapeTensor> tape_tensors;
for (auto t : retvals) {
- tape_tensors.push_back(TapeTensor(t, ctx));
tape_tensors.push_back(TapeTensor(t));
}
tape->RecordOperation(
op_->Name(), tape_tensors, input_ids, input_dtypes,

View File

@ -80,7 +80,6 @@ struct ForwardOperation {
std::vector<AbstractTensorHandle*> inputs;
std::vector<AbstractTensorHandle*> outputs;
AttrBuilder attrs;
- AbstractContext* ctx;
};
// Interface for building default zeros gradients for op outputs which are
@ -181,10 +180,6 @@ int64 ToId(AbstractTensorHandle* t);
// allow us to trace the data dependencies between operations and hence compute
// gradients.
//
- // This also implements `OnesLike` to create the default
- // incoming gradients for tensors which do not already have an incoming
- // gradient.
- //
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
// of default zeros grads is handled by the `DefaultGradientFunction` registered
// for each op.
@ -193,20 +188,19 @@ int64 ToId(AbstractTensorHandle* t);
// TODO(srbs): Should ZerosLike check-fail instead of returning nullptr?
class TapeTensor {
public:
- TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx);
explicit TapeTensor(AbstractTensorHandle* handle);
TapeTensor(const TapeTensor& other);
~TapeTensor();
tensorflow::int64 GetID() const;
tensorflow::DataType GetDType() const;
- AbstractTensorHandle* OnesLike() const;
AbstractTensorHandle* ZerosLike() const;
AbstractTensorHandle* GetHandle() const;
private:
AbstractTensorHandle* handle_;
- // The context where OnesLike ops are to be created.
- AbstractContext* ctx_;
};
// Vector space for actually computing gradients. Implements methods for calling
@ -234,6 +228,10 @@ class TapeVSpace
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
std::vector<AbstractTensorHandle*>* result) const override;
// Builds a tensor filled with ones with the same shape and dtype as `t`.
Status BuildOnesLike(const TapeTensor& t,
AbstractTensorHandle** result) const override;
// Looks up the ID of a Gradient.
int64 TensorId(AbstractTensorHandle* tensor) const override;

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
@ -26,7 +27,9 @@ limitations under the License.
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/gradients/array_grad.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
@ -38,84 +41,32 @@ namespace gradients {
namespace internal {
namespace {
using std::vector;
using tensorflow::TF_StatusPtr;
using tracing::TracingOperation;
class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
void SetUp() override {
- TF_SetTracingImplementation(std::get<0>(GetParam()));
TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
}
};
Status RegisterGradients(GradientRegistry* registry) {
- TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
// TODO(srbs): Rename ops::Add to ops::AddV2 and AddRegister to
// AddV2Registerer.
TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
return Status::OK();
}
- // Computes `inputs[0] + inputs[1]` and records it on the tape.
- Status Add(AbstractContext* ctx, Tape* tape,
- absl::Span<AbstractTensorHandle* const> inputs,
- absl::Span<AbstractTensorHandle*> outputs,
- const GradientRegistry& registry) {
- AbstractOperationPtr add_op(ctx->CreateOperation());
- ForwardOperation forward_op;
- forward_op.ctx = ctx;
- TF_RETURN_IF_ERROR(
- Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
- if (isa<TracingOperation>(add_op.get())) {
- TF_RETURN_IF_ERROR(
- dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
- }
- TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
- TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
- int num_retvals = 1;
- return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
- registry);
- }
- // Computes `exp(inputs[0])` and records it on the tape.
- Status Exp(AbstractContext* ctx, Tape* tape,
- absl::Span<AbstractTensorHandle* const> inputs,
- absl::Span<AbstractTensorHandle*> outputs,
- const GradientRegistry& registry) {
- AbstractOperationPtr exp_op(ctx->CreateOperation());
- ForwardOperation forward_op;
- forward_op.ctx = ctx;
- TF_RETURN_IF_ERROR(
- Reset(exp_op.get(), "Exp", /*raw_device_name=*/nullptr, &forward_op));
- if (isa<TracingOperation>(exp_op.get())) {
- TF_RETURN_IF_ERROR(
- dyn_cast<TracingOperation>(exp_op.get())->SetOpName("my_exp"));
- }
- TF_RETURN_IF_ERROR(AddInput(exp_op.get(), inputs[0], &forward_op));
- int num_retvals = 1;
- return Execute(exp_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
- registry);
- }
- // Computes `IdentityN(inputs)` and records it on the tape.
- Status IdentityN(AbstractContext* ctx, Tape* tape,
- absl::Span<AbstractTensorHandle* const> inputs,
- absl::Span<AbstractTensorHandle*> outputs,
- const GradientRegistry& registry) {
- AbstractOperationPtr identity_n_op(ctx->CreateOperation());
- ForwardOperation forward_op;
- forward_op.ctx = ctx;
- TF_RETURN_IF_ERROR(Reset(identity_n_op.get(), "IdentityN",
- /*raw_device_name=*/nullptr, &forward_op));
- if (isa<TracingOperation>(identity_n_op.get())) {
- TF_RETURN_IF_ERROR(dyn_cast<TracingOperation>(identity_n_op.get())
- ->SetOpName("my_identity_n"));
- }
- TF_RETURN_IF_ERROR(AddInputList(identity_n_op.get(), inputs, &forward_op));
- int num_retvals = outputs.size();
- return Execute(identity_n_op.get(), ctx, outputs, &num_retvals, &forward_op,
- tape, registry);
- }
// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
@ -128,8 +79,10 @@ Status AddGradModel(AbstractContext* ctx,
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
std::vector<AbstractTensorHandle*> add_outputs(1);
- TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
- registry)); // Compute x+y.
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs,
absl::MakeSpan(add_outputs),
"Add")); // Compute x+y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
@ -160,8 +113,9 @@ Status ExpGradModel(AbstractContext* ctx,
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
std::vector<AbstractTensorHandle*> exp_outputs(1);
- TF_RETURN_IF_ERROR(Exp(ctx, tape, inputs, absl::MakeSpan(exp_outputs),
- registry)); // Compute x+y.
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(
ops::Exp(tape_ctx.get(), inputs, absl::MakeSpan(exp_outputs), "Exp"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
@ -179,6 +133,37 @@ Status ExpGradModel(AbstractContext* ctx,
return Status::OK();
}
// Computes
// y = sqrt(inputs[0])
// return grad(y, {inputs[0]})
Status SqrtGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
std::vector<AbstractTensorHandle*> sqrt_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(
ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(sqrt_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto sqrt_output : sqrt_outputs) {
sqrt_output->Unref();
}
outputs[0] = out_grads[0];
delete tape;
return Status::OK();
}
// Computes
// ignored, y = IdentityN(inputs[0], inputs[1])
// return grad(y, {inputs[0], inputs[1]})
@ -193,8 +178,9 @@ Status IdentityNGradModel(AbstractContext* ctx,
tape->Watch(ToId(inputs[1]));
vector<AbstractTensorHandle*> identity_n_outputs(2);
- TF_RETURN_IF_ERROR(IdentityN(ctx, tape, inputs,
- absl::MakeSpan(identity_n_outputs), registry));
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::IdentityN(
tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
@ -214,6 +200,73 @@ Status IdentityNGradModel(AbstractContext* ctx,
return Status::OK();
}
// Computes
// y = - inputs[0]
// return grad(y, {inputs[0]})
Status NegGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0]));
std::vector<AbstractTensorHandle*> neg_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(
ops::Neg(tape_ctx.get(), inputs, absl::MakeSpan(neg_outputs), "Neg"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(neg_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto neg_output : neg_outputs) {
neg_output->Unref();
}
outputs[0] = out_grads[0];
delete tape;
return Status::OK();
}
// Computes
// y = inputs[0] - inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status SubGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
std::vector<AbstractTensorHandle*> sub_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs,
absl::MakeSpan(sub_outputs),
"Sub")); // Compute x-y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(sub_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto sub_output : sub_outputs) {
sub_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
AbstractContext* BuildFunction(const char* fn_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -448,6 +501,50 @@ TEST_P(CppGradients, TestExpGrad) {
result_tensor = nullptr;
}
TEST_P(CppGradients, TestSqrtGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// y = sqrt(x)
// outputs = tape.gradient(y, x)
std::vector<AbstractTensorHandle*> outputs(1);
s = RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_NEAR(*result_value, 0.5, 0.001);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
TEST_P(CppGradients, TestIdentityNGrad) {
// Pseudo-code:
//
@ -507,6 +604,161 @@ TEST_P(CppGradients, TestIdentityNGrad) {
result_tensor = nullptr;
}
TEST_P(CppGradients, TestNegGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// y = - x
// outputs = tape.gradient(y, x)
std::vector<AbstractTensorHandle*> outputs(1);
s = RunModel(NegGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, -1.0);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
TEST_P(CppGradients, TestSubGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
y.reset(y_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// tape.watch(y)
// y = x - y
// outputs = tape.gradient(y, [x, y])
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(SubGradModel, ctx.get(), {x.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
s = getValue(outputs[1], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, -1.0);
outputs[1]->Unref();
TF_DeleteTensor(result_tensor);
}
TEST_P(CppGradients, TestSetAttrString) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr t;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
t.reset(x_raw);
}
AbstractOperationPtr check_numerics_op(ctx->CreateOperation());
ForwardOperation forward_op;
Status s = Reset(check_numerics_op.get(), "CheckNumerics",
/*raw_device_name=*/nullptr, &forward_op);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
if (isa<TracingOperation>(check_numerics_op.get())) {
s = dyn_cast<TracingOperation>(check_numerics_op.get())
->SetOpName("check_numerics");
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}
s = AddInput(check_numerics_op.get(), t.get(), &forward_op);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
string message = "This is the way!";
s = SetAttrString(check_numerics_op.get(), "message", message.data(),
message.length(), &forward_op);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
int num_retvals = 1;
std::vector<AbstractTensorHandle*> outputs(1);
GradientRegistry registry;
std::unique_ptr<Tape> tape(new Tape(/*persistent=*/false));
s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs),
&num_retvals, &forward_op, tape.get(), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
string read_message;
s = forward_op.attrs.Get("message", &read_message);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ASSERT_EQ(read_message, message);
}
// TODO(b/164171226): Enable this test with tfrt after AddInputList is
// supported. It is needed for IdentityN.
#ifdef PLATFORM_GOOGLE

View File

@ -0,0 +1,317 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradients_util.h"
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace gradients {
using std::vector;
Status ScalarTensorHandleHelper(TFE_Context* ctx, float value,
TFE_TensorHandle** result) {
float data[] = {value};
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status.get());
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
*result = th;
TF_DeleteTensor(t);
return StatusFromTF_Status(status.get());
}
Status TensorHandleWithDimsFloatHelper(TFE_Context* ctx, float data[],
int64_t dims[], int num_dims,
TFE_TensorHandle** result) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status.get());
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
*result = th;
TF_DeleteTensor(t);
return StatusFromTF_Status(status.get());
}
Status TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[],
int64_t dims[], int num_dims,
TFE_TensorHandle** result) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status.get());
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
*result = th;
TF_DeleteTensor(t);
return StatusFromTF_Status(status.get());
}
// Get a scalar TensorHandle with given value
Status ScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager;
TF_RETURN_IF_ERROR(ScalarTensorHandleHelper(eager_ctx, value, &input_eager));
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return StatusFromTF_Status(status.get());
}
// Get a TensorHandle with given float values and dimensions
Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager;
TF_RETURN_IF_ERROR(TensorHandleWithDimsFloatHelper(eager_ctx, data, dims,
num_dims, &input_eager));
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return StatusFromTF_Status(status.get());
}
// Get a TensorHandle with given int values and dimensions
Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[],
int num_dims, AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager;
TF_RETURN_IF_ERROR(TensorHandleWithDimsIntHelper(eager_ctx, data, dims,
num_dims, &input_eager));
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return StatusFromTF_Status(status.get());
}
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
return StatusFromTF_Status(status.get());
}
AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
float vals[], int64_t dims[],
int num_dims) {
AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr;
Status s = TensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
if (s.ok()) {
A.reset(a_raw);
}
return A;
}
AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
int64_t dims[], int num_dims) {
AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr;
Status s = TensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
if (s.ok()) {
A.reset(a_raw);
}
return A;
}
AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
float val) {
AbstractTensorHandlePtr y;
AbstractTensorHandle* y_raw = nullptr;
Status s = ScalarTensorHandle(ctx, val, &y_raw);
if (s.ok()) {
y.reset(y_raw);
}
return y;
}
Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads,
vector<AbstractTensorHandle*>& weights,
AbstractTensorHandle* learning_rate) {
/* Updates the weights one by one using the gradient update rule:
*
* w -= lr * grad[w]
*
* NOTE: assumes the learning rate is positive.
*/
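// Worked example (illustrative numbers): with w = 1.0, grad[w] = 0.5, and
// lr = 0.1, one iteration computes dW = (-0.1) * 0.5 = -0.05 and replaces
// w with 1.0 + (-0.05) = 0.95.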
int num_grads = grads.size();
vector<AbstractTensorHandle*> temp_outputs(1);
std::string update_str;
// Negate learning rate for gradient descent
TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate},
absl::MakeSpan(temp_outputs),
"neg_lr")); // Compute -lr
learning_rate = temp_outputs[0];
for (int i = 0; i < num_grads; i++) {
// Compute dW = -lr * grad(w[i])
update_str = "update_mul_" + std::to_string(i);
TF_RETURN_IF_ERROR(ops::Mul(ctx, {learning_rate, grads[i]},
absl::MakeSpan(temp_outputs),
update_str.c_str()));
AbstractTensorHandle* dW = temp_outputs[0];
// Compute temp = weights[i] + dW
update_str = "update_add_" + std::to_string(i);
TF_RETURN_IF_ERROR(ops::Add(ctx, {weights[i], dW},
absl::MakeSpan(temp_outputs),
update_str.c_str()));
// Update the weights
weights[i] = temp_outputs[0];
}
return Status::OK();
}
AbstractContext* BuildFunction(const char* fn_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
return unwrap(graph_ctx);
}
Status CreateParamsForInputs(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
vector<AbstractTensorHandle*>* params) {
tracing::TracingTensorHandle* handle = nullptr;
for (auto input : inputs) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
input->DataType(), &handle));
params->emplace_back(handle);
}
return Status::OK();
}
Status RunModel(Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
const GradientRegistry& registry) {
if (use_function) {
const char* fn_name = "test_fn";
std::unique_ptr<AbstractFunction> scoped_func;
// Returning null tensors from a tf.function is not supported, so we keep
// track of the indices in the model's outputs that are nullptr in this set.
// The FunctionDef only outputs the non-null tensors. We later pad the
// function op outputs to have nullptrs at the `null_indices`.
absl::flat_hash_set<int> null_indices;
{
AbstractContextPtr func_ctx(BuildFunction(fn_name));
vector<AbstractTensorHandle*> func_inputs;
func_inputs.reserve(inputs.size());
TF_RETURN_IF_ERROR(
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
vector<AbstractTensorHandle*> model_outputs;
model_outputs.resize(outputs.size());
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
absl::MakeSpan(model_outputs), registry));
for (auto func_input : func_inputs) {
func_input->Unref();
}
AbstractFunction* func = nullptr;
OutputList output_list;
output_list.expected_num_outputs = 0;
output_list.outputs.reserve(outputs.size());
for (int i = 0; i < model_outputs.size(); i++) {
if (model_outputs[i]) {
output_list.outputs.emplace_back(model_outputs[i]);
output_list.expected_num_outputs += 1;
} else {
null_indices.insert(i);
}
}
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
->Finalize(&output_list, &func));
scoped_func.reset(func);
for (auto output : output_list.outputs) {
output->Unref();
}
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
}
AbstractOperationPtr fn_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
for (auto input : inputs) {
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
}
int retvals = outputs.size() - null_indices.size();
vector<AbstractTensorHandle*> fn_outputs(retvals);
TF_RETURN_IF_ERROR(fn_op->Execute(
absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
&retvals));
int skipped_indices = 0;
for (int i = 0; i < outputs.size(); i++) {
if (!null_indices.contains(i)) {
outputs[i] = fn_outputs[i - skipped_indices];
} else {
skipped_indices += 1;
}
}
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
return Status::OK();
} else {
return model(ctx, inputs, outputs, registry);
}
}
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
*ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_DeleteContextOptions(opts);
return Status::OK();
}
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,88 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace gradients {
// Get a scalar TensorHandle with given value
Status ScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor);
// Get a TensorHandle with given float values and dimensions
Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor);
// Get a TensorHandle with given int values and dimensions
Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[],
int num_dims, AbstractTensorHandle** tensor);
// Places data from `t` into *result_tensor.
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor);
// Util function that wraps an AbstractTensorHandle* with given data and dims.
AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
float vals[], int64_t dims[],
int num_dims);
// Util function that wraps an AbstractTensorHandle* with given data and dims.
AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
int64_t dims[], int num_dims);
// Util function that wraps an AbstractTensorHandle* with given data.
AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
float val);
// Performs gradient update for each weight using given learning rate.
Status UpdateWeights(AbstractContext* ctx,
std::vector<AbstractTensorHandle*>& grads,
std::vector<AbstractTensorHandle*>& weights,
AbstractTensorHandle* learning_rate);
using Model = std::function<Status(
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
// Runs the given model in either graph mode or eager mode, depending on the
// value of `use_function`.
Status RunModel(Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
const GradientRegistry& registry);
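//
// Minimal sketch of defining and running a `Model` (the lambda and names are
// illustrative; error handling elided):
//
//   Model square = [](AbstractContext* ctx,
//                     absl::Span<AbstractTensorHandle* const> inputs,
//                     absl::Span<AbstractTensorHandle*> outputs,
//                     const GradientRegistry& registry) -> Status {
//     return ops::Mul(ctx, {inputs[0], inputs[0]}, outputs, "square");
//   };
//   std::vector<AbstractTensorHandle*> outputs(1);
//   Status s = RunModel(square, ctx, {x.get()}, absl::MakeSpan(outputs),
//                       /*use_function=*/true, registry);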
// Builds an immediate execution context and returns it inside *ctx.
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
} // namespace gradients
} // namespace tensorflow

View File

@ -29,8 +29,25 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h" #include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow { namespace tensorflow {
class EagerExecutor;
// LINT.IfChange
// Note: Keep in sync with exported copy of enum in eager/c_api.h.
enum ContextDevicePlacementPolicy {
// Running operations with input tensors on the wrong device will fail.
DEVICE_PLACEMENT_EXPLICIT = 0,
// Copy the tensor to the right device but log a warning.
DEVICE_PLACEMENT_WARN = 1,
// Silently copy the tensor, which has a performance cost since the operation
// will be blocked till the copy completes. This is the default policy.
DEVICE_PLACEMENT_SILENT = 2,
// Placement policy which silently copies int32 tensors but not other dtypes.
DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
};
// LINT.ThenChange(//tensorflow/c/eager/c_api.h)
// Abstract interface to a context. // Abstract interface to a context.
// //
@ -81,14 +98,6 @@ class ImmediateExecutionContext : public AbstractContext {
// List attributes of available devices // List attributes of available devices
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0; virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
virtual void ClearCachesAndThreadExecutors() = 0;
// Initialize the step resource container for a training step. This is used
// in current TF runtime. For tfrt, it is used by fallback op handler.
virtual void StartStep() = 0;
// Destroy the step resource container for a training step.
virtual void EndStep() = 0;
// Block until all pending nodes are finished. // Block until all pending nodes are finished.
virtual Status AsyncWait() = 0; virtual Status AsyncWait() = 0;
@ -97,11 +106,52 @@ class ImmediateExecutionContext : public AbstractContext {
// already exists. // already exists.
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0; virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
// Find and return a added function by its name.
virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;
// Return the ParsedName of Host CPU device.
virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
// Configure soft device placement policy.
virtual void SetAllowSoftPlacement(bool enable) = 0;
// Configure device placement policy logging.
virtual void SetLogDevicePlacement(bool enable) = 0;
// Sets the device placement policy for the current thread.
virtual void SetThreadLocalDevicePlacementPolicy(
ContextDevicePlacementPolicy policy) = 0;
// Returns the device placement policy for the current thread.
virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;
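// Usage sketch (assuming `ctx` is a valid ImmediateExecutionContext*):
//   ctx->SetThreadLocalDevicePlacementPolicy(DEVICE_PLACEMENT_SILENT);
//   DCHECK_EQ(ctx->GetDevicePlacementPolicy(), DEVICE_PLACEMENT_SILENT);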
// For LLVM style RTTI. // For LLVM style RTTI.
static bool classof(const AbstractContext* ptr) { static bool classof(const AbstractContext* ptr) {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt; return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
} }
//===--------------------------------------------------------------------===//
// The following are legacy features in the TF Eager Runtime.
// TODO(tf-runtime): Figure out a way to deprecate the following features
// after migrating to TFRT.
//===--------------------------------------------------------------------===//
// Clear pending nodes in thread executors and kernel caches.
virtual void ClearCachesAndThreadExecutors() = 0;
// Initialize the step resource container for a training step. This is used
// in the current TF runtime. For TFRT, it is used by the fallback op handler.
virtual void StartStep() = 0;
// Destroy the step resource container for a training step.
virtual void EndStep() = 0;
// Return the Eager Executor for the current thread. Note that the Eager
// Executor is only used in the current TF runtime, not in TFRT.
virtual EagerExecutor& Executor() = 0;
// Update the Eager Executor for the current thread.
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
// Configure graph collection in RunMetadata.
virtual void SetShouldStoreGraphs(bool value) = 0;
protected: protected:
explicit ImmediateExecutionContext(AbstractContextKind kind) explicit ImmediateExecutionContext(AbstractContextKind kind)
: AbstractContext(kind) {} : AbstractContext(kind) {}

View File

@ -47,9 +47,6 @@ class ImmediateExecutionOperation : public AbstractOperation {
virtual Status InputLength(const char* input_name, int* length) = 0; virtual Status InputLength(const char* input_name, int* length) = 0;
virtual Status OutputLength(const char* output_name, int* length) = 0; virtual Status OutputLength(const char* output_name, int* length) = 0;
// Experimental
virtual Status SetUseXla(bool enable) = 0;
// Set stack trace to be used for potential async error reporting. // Set stack trace to be used for potential async error reporting.
virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0; virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0;

View File

@ -44,6 +44,10 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
virtual const char* DeviceName(Status* status) const = 0; virtual const char* DeviceName(Status* status) const = 0;
// Returns the device where the tensor was placed. // Returns the device where the tensor was placed.
virtual const char* BackingDeviceName(Status* status) const = 0; virtual const char* BackingDeviceName(Status* status) const = 0;
// Returns the type of the device which created the handle.
virtual const char* DeviceType(Status* status) const = 0;
// Returns the ID of the device which created the handle.
virtual int DeviceId(Status* status) const = 0;
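// Usage sketch (assuming `handle` is a valid ImmediateExecutionTensorHandle*):
//   Status status;
//   const char* device_type = handle->DeviceType(&status);  // e.g. "GPU"
//   int device_id = handle->DeviceId(&status);              // e.g. 0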
// Returns a tensor for the handle. If tensor is remote, it will be copied. // Returns a tensor for the handle. If tensor is remote, it will be copied.
virtual AbstractTensorInterface* Resolve(Status* status) = 0; virtual AbstractTensorInterface* Resolve(Status* status) = 0;

View File

@ -14,11 +14,11 @@ limitations under the License.
#include "absl/types/span.h" #include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h" #include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h" #include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/eager/mnist_gradients_testutil.h" #include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include "tensorflow/c/experimental/gradients/math_grad.h" #include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h" #include "tensorflow/c/experimental/gradients/nn_grad.h"
@ -33,12 +33,16 @@ namespace tensorflow {
namespace gradients { namespace gradients {
namespace internal { namespace internal {
namespace { namespace {
using tensorflow::TF_StatusPtr;
class CppGradients class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> { : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected: protected:
void SetUp() override { void SetUp() override {
TF_SetTracingImplementation(std::get<0>(GetParam())); TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
} }
}; };
@ -49,89 +53,10 @@ Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer)); TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
registry->Register("SparseSoftmaxCrossEntropyWithLogits", registry->Register("SparseSoftmaxCrossEntropyWithLogits",
SparseSoftmaxCrossEntropyLossRegisterer)); SparseSoftmaxCrossEntropyWithLogitsRegisterer));
return Status::OK(); return Status::OK();
} }
// ========================= Test Util Functions ==============================
// Get a scalar TensorHandle with given value
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
// Get a Matrix TensorHandle with given float values and dimensions
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager =
TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
// Get a Matrix TensorHandle with given int values and dimensions
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager =
TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
return Status::OK();
}
AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
float vals[], int64_t dims[],
int num_dims) {
AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr;
Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
A.reset(a_raw);
return A;
}
AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
int64_t dims[], int num_dims) {
AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr;
Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
A.reset(a_raw);
return A;
}
// =========================== Start Tests ================================
TEST_P(CppGradients, TestMatMulGrad) { TEST_P(CppGradients, TestMatMulGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);
@ -465,6 +390,12 @@ TEST_P(CppGradients, TestReluGrad) {
} }
TEST_P(CppGradients, TestSoftmaxLossGrad) { TEST_P(CppGradients, TestSoftmaxLossGrad) {
bool use_function = !std::get<2>(GetParam());
if (use_function) {
// TODO(b/168850692): Enable this.
GTEST_SKIP() << "Can't take gradient of "
"SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
}
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);
@ -533,6 +464,12 @@ TEST_P(CppGradients, TestSoftmaxLossGrad) {
} }
TEST_P(CppGradients, TestMNISTGrad) { TEST_P(CppGradients, TestMNISTGrad) {
bool use_function = !std::get<2>(GetParam());
if (use_function) {
// TODO(b/168850692): Enable this.
GTEST_SKIP() << "Can't take gradient of "
"SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
}
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx; AbstractContextPtr ctx;
@ -603,7 +540,6 @@ TEST_P(CppGradients, TestMNISTGrad) {
TF_TensorByteSize(dW1_tensor)); TF_TensorByteSize(dW1_tensor));
float expected_dW1[4] = {0.0f, 3.2f, 0.0f, 4.8f}; float expected_dW1[4] = {0.0f, 3.2f, 0.0f, 4.8f};
; // dLoss
for (int j = 0; j < 4; j++) { for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], expected_dW1[j], tolerance); ASSERT_NEAR(result_data[j], expected_dW1[j], tolerance);
} }
@ -643,7 +579,7 @@ TEST_P(CppGradients, TestScalarMul) {
AbstractTensorHandlePtr eta; AbstractTensorHandlePtr eta;
{ {
AbstractTensorHandle* x_raw = nullptr; AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.5f, &x_raw); Status s = ScalarTensorHandle(ctx.get(), 1.5f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message(); ASSERT_EQ(errors::OK, s.code()) << s.error_message();
eta.reset(x_raw); eta.reset(x_raw);
} }
@ -681,6 +617,12 @@ TEST_P(CppGradients, TestScalarMul) {
} }
TEST_P(CppGradients, TestMNIST_Training) { TEST_P(CppGradients, TestMNIST_Training) {
bool use_function = !std::get<2>(GetParam());
if (use_function) {
// TODO(b/168850692): Enable this.
GTEST_SKIP() << "Can't take gradient of "
"SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
}
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);
@ -733,7 +675,7 @@ TEST_P(CppGradients, TestMNIST_Training) {
// Set learning rate to be 1e-1 // Set learning rate to be 1e-1
AbstractTensorHandle* learning_rate = nullptr; AbstractTensorHandle* learning_rate = nullptr;
s = TestScalarTensorHandle(ctx.get(), 1e-1, &learning_rate); s = ScalarTensorHandle(ctx.get(), 1e-1, &learning_rate);
ASSERT_EQ(errors::OK, s.code()) << s.error_message(); ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Train // Train
@ -765,13 +707,13 @@ TEST_P(CppGradients, TestMNIST_Training) {
#ifdef PLATFORM_GOOGLE #ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients, UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"), ::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false), /*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false))); /*executing_eagerly*/ ::testing::Values(true, false)));
#else #else
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients, UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"), ::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false), /*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false))); /*executing_eagerly*/ ::testing::Values(true, false)));
#endif #endif

View File

@ -24,136 +24,19 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h" #include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h" #include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/experimental/ops/array_ops.h" #include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h" #include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h" #include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
namespace tensorflow {
namespace gradients {
namespace internal {
using std::vector; using std::vector;
using tracing::TracingOperation;
// ========================== Tape Ops ==============================
// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractOperationPtr add_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(add_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
}
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
int num_retvals = 1;
return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
Status MatMul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
bool transpose_a, bool transpose_b,
const GradientRegistry& registry) {
AbstractOperationPtr matmul_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul",
/*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(matmul_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(matmul_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op));
TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
matmul_op.get(), "transpose_a", transpose_a, &forward_op));
TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
matmul_op.get(), "transpose_b", transpose_b, &forward_op));
int num_retvals = 1;
return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
Status Mul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry) {
AbstractOperationPtr mul_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(mul_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op));
int num_retvals = 1;
return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes `Relu(inputs[0])` and records it on the tape.
Status Relu(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry) {
AbstractOperationPtr relu_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(relu_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(relu_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op));
int num_retvals = 1;
return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
// tape.
Status SparseSoftmaxCrossEntropyLoss(
AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry) {
AbstractTensorHandle* scores = inputs[0];
AbstractTensorHandle* labels = inputs[1];
AbstractOperationPtr sm_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits",
/*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(sm_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(sm_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op));
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op));
int num_retvals = 2; // returns loss values and backprop
return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
//===================== Test Models to run ========================= //===================== Test Models to run =========================
@ -169,8 +52,9 @@ Status AddGradModel(AbstractContext* ctx,
tape->Watch(ToId(inputs[0])); // Watch x. tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y. tape->Watch(ToId(inputs[1])); // Watch y.
std::vector<AbstractTensorHandle*> add_outputs(1); std::vector<AbstractTensorHandle*> add_outputs(1);
TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs), AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
registry)); // Compute x+y. TF_RETURN_IF_ERROR(
ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));
std::unordered_map<tensorflow::int64, TapeTensor> std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets; source_tensors_that_are_targets;
@ -202,9 +86,11 @@ Status MatMulGradModel(AbstractContext* ctx,
tape->Watch(ToId(inputs[0])); // Watch x. tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y. tape->Watch(ToId(inputs[1])); // Watch y.
vector<AbstractTensorHandle*> mm_outputs(1); vector<AbstractTensorHandle*> mm_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs), AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
"matmul0", /*transpose_a=*/false, TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs,
/*transpose_b=*/false, registry)); // Compute x*y. absl::MakeSpan(mm_outputs), "matmul0",
/*transpose_a=*/false,
/*transpose_b=*/false)); // Compute x*y.
std::unordered_map<tensorflow::int64, TapeTensor> std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets; source_tensors_that_are_targets;
@ -238,8 +124,9 @@ Status MNISTForwardModel(AbstractContext* ctx,
* hidden_layer = tf.nn.relu(mm_out_1) * hidden_layer = tf.nn.relu(mm_out_1)
* scores = tf.matmul(hidden_layer,W2) * scores = tf.matmul(hidden_layer,W2)
* softmax = * softmax =
* tf.nn.sparse_softmax_cross_entropy_with_logits(scores,y_labels) return * tf.nn.sparse_softmax_cross_entropy_with_logits(scores,
* scores, softmax * y_labels)
* return scores, softmax
* *
* Use this convention for inputs: * Use this convention for inputs:
* *
@ -257,24 +144,27 @@ Status MNISTForwardModel(AbstractContext* ctx,
tape->Watch(ToId(W2)); // Watch W2. tape->Watch(ToId(W2)); // Watch W2.
vector<AbstractTensorHandle*> temp_outputs(1); vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs), AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
"matmul0", /*transpose_a=*/false, TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
/*transpose_b=*/false, registry)); // Compute X*W1 absl::MakeSpan(temp_outputs), "matmul0",
/*transpose_a=*/false,
/*transpose_b=*/false)); // Compute X*W1
TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]}, TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {temp_outputs[0]},
absl::MakeSpan(temp_outputs), "relu", absl::MakeSpan(temp_outputs),
registry)); // Compute Relu(X*W1) "relu")); // Compute Relu(X*W1)
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2}, TF_RETURN_IF_ERROR(ops::MatMul(
absl::MakeSpan(temp_outputs), "matmul1", tape_ctx.get(), {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs),
/*transpose_a=*/false, /*transpose_b=*/false, "matmul1",
registry)); // Compute W2*Relu(X*W1) /*transpose_a=*/false, /*transpose_b=*/false)); // Compute W2*Relu(X*W1)
AbstractTensorHandle* scores = temp_outputs[0]; AbstractTensorHandle* scores = temp_outputs[0];
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss( temp_outputs.resize(2);
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs), TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
"softmax_loss", registry)); // Compute Softmax(Scores,labels) tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmax_loss")); // Compute Softmax(Scores,labels)
AbstractTensorHandle* loss_vals = temp_outputs[0]; AbstractTensorHandle* loss_vals = temp_outputs[0];
@ -297,9 +187,11 @@ Status MatMulTransposeModel(AbstractContext* ctx,
tape->Watch(ToId(W1)); tape->Watch(ToId(W1));
vector<AbstractTensorHandle*> temp_outputs(1); vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs), AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
"matmul0", /*transpose_a=*/true, TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
/*transpose_b=*/false, registry)); // Compute X*W1 absl::MakeSpan(temp_outputs), "matmul0",
/*transpose_a=*/true,
/*transpose_b=*/false)); // Compute X*W1
outputs[0] = temp_outputs[0]; outputs[0] = temp_outputs[0];
@ -315,8 +207,10 @@ Status ReluGradModel(AbstractContext* ctx,
auto tape = new Tape(/*persistent=*/false); auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch X tape->Watch(ToId(inputs[0])); // Watch X
vector<AbstractTensorHandle*> relu_outputs(1); vector<AbstractTensorHandle*> relu_outputs(1);
TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs), AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
"relu0", registry)); // Relu(X) TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
absl::MakeSpan(relu_outputs),
"relu0")); // Relu(X)
std::unordered_map<tensorflow::int64, TapeTensor> std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets; source_tensors_that_are_targets;
@ -346,8 +240,9 @@ Status SoftmaxLossGradModel(AbstractContext* ctx,
tape->Watch(ToId(inputs[0])); // Watch scores. tape->Watch(ToId(inputs[0])); // Watch scores.
tape->Watch(ToId(inputs[1])); // Watch labels. tape->Watch(ToId(inputs[1])); // Watch labels.
vector<AbstractTensorHandle*> sm_outputs(2); vector<AbstractTensorHandle*> sm_outputs(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss( AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry)); TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
tape_ctx.get(), inputs, absl::MakeSpan(sm_outputs), "softmax0"));
std::unordered_map<tensorflow::int64, TapeTensor> std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets; source_tensors_that_are_targets;
@ -381,29 +276,30 @@ Status MNISTGradModel(AbstractContext* ctx,
tape->Watch(ToId(W1)); // Watch W1. tape->Watch(ToId(W1)); // Watch W1.
tape->Watch(ToId(W2)); // Watch W2. tape->Watch(ToId(W2)); // Watch W2.
vector<AbstractTensorHandle*> temp_outputs(1); vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs), AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
"matmul0", /*transpose_a=*/false, TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
/*transpose_b=*/false, registry)); // Compute X*W1 absl::MakeSpan(temp_outputs), "matmul0",
/*transpose_a=*/false,
/*transpose_b=*/false)); // Compute X*W1
AbstractTensorHandle* mm = temp_outputs[0]; AbstractTensorHandle* mm = temp_outputs[0];
TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm}, TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {mm},
absl::MakeSpan(temp_outputs), // Relu(X*W1) absl::MakeSpan(temp_outputs), // Relu(X*W1)
"relu0", registry)); "relu0"));
AbstractTensorHandle* hidden = temp_outputs[0]; AbstractTensorHandle* hidden = temp_outputs[0];
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2}, TF_RETURN_IF_ERROR(ops::MatMul(
absl::MakeSpan(temp_outputs), "matmul1", tape_ctx.get(), {hidden, W2}, absl::MakeSpan(temp_outputs), "matmul1",
/*transpose_a=*/false, /*transpose_b=*/false, /*transpose_a=*/false, /*transpose_b=*/false)); // W2*Relu(X*W1)
registry)); // W2*Relu(X*W1)
AbstractTensorHandle* scores = temp_outputs[0]; AbstractTensorHandle* scores = temp_outputs[0];
temp_outputs.resize(2); temp_outputs.resize(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss( TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs), tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmaxloss", registry)); // W2*Relu(X*W1) "softmaxloss")); // W2*Relu(X*W1)
AbstractTensorHandle* loss = temp_outputs[0]; AbstractTensorHandle* loss = temp_outputs[0];
@ -440,8 +336,10 @@ Status ScalarMulModel(AbstractContext* ctx,
auto tape = new Tape(/*persistent=*/false); auto tape = new Tape(/*persistent=*/false);
vector<AbstractTensorHandle*> temp_outputs(1); vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs), AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
"scalarMul0", registry)); // Compute eta*A TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {eta, A},
absl::MakeSpan(temp_outputs),
"scalarMul0")); // Compute eta*A
outputs[0] = temp_outputs[0]; outputs[0] = temp_outputs[0];
@ -449,146 +347,69 @@ Status ScalarMulModel(AbstractContext* ctx,
return Status::OK(); return Status::OK();
} }
Status MatMulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractTensorHandle* X = inputs[0];
AbstractTensorHandle* W1 = inputs[1];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
std::vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
absl::MakeSpan(temp_outputs), "matmul0",
/*transpose_a=*/false,
/*transpose_b=*/false)); // Compute X*W1
outputs[0] = temp_outputs[0];
delete tape;
return Status::OK();
}
Status MulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractTensorHandle* x = inputs[0];
AbstractTensorHandle* y = inputs[1];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
std::vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {x, y},
absl::MakeSpan(temp_outputs),
"mul0")); // Compute x*y
outputs[0] = temp_outputs[0];
delete tape;
return Status::OK();
}
Status SoftmaxModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractTensorHandle* x = inputs[0];
AbstractTensorHandle* labels = inputs[1];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
std::vector<AbstractTensorHandle*> temp_outputs(2);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
tape_ctx.get(), {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss"));
outputs[0] = temp_outputs[0]; // loss values
delete tape;
return Status::OK();
}
// ============================= End Models ================================ // ============================= End Models ================================
Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads, } // namespace internal
vector<AbstractTensorHandle*>& weights, } // namespace gradients
AbstractTensorHandle* learning_rate) { } // namespace tensorflow
/* Update weights one by one using the gradient update rule:
*
* w -= lr*grad[w]
*
* NOTE: assumes the learning rate is positive
*/
Status s;
int num_grads = grads.size();
vector<AbstractTensorHandle*> temp_outputs(1);
std::string update_str;
// Negate learning rate for gradient descent
TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate},
absl::MakeSpan(temp_outputs),
"neg_lr")); // Compute -lr
learning_rate = temp_outputs[0];
for (int i = 0; i < num_grads; i++) {
// Compute dW = -lr * grad(w[i])
update_str = "update_mul_" + std::to_string(i);
s = ops::Mul(ctx, {learning_rate, grads[i]}, absl::MakeSpan(temp_outputs),
update_str.c_str());
AbstractTensorHandle* dW = temp_outputs[0];
// Compute temp = weights[i] + dW
update_str = "update_add_" + std::to_string(i);
s = ops::Add(ctx, {weights[i], dW}, absl::MakeSpan(temp_outputs),
update_str.c_str());
// Update the weights
weights[i] = temp_outputs[0];
}
return Status::OK();
}
AbstractContext* BuildFunction(const char* fn_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
return unwrap(graph_ctx);
}
Status CreateParamsForInputs(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
vector<AbstractTensorHandle*>* params) {
tracing::TracingTensorHandle* handle = nullptr;
for (auto input : inputs) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
input->DataType(), &handle));
params->emplace_back(handle);
}
return Status::OK();
}
Status RunModel(Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
const GradientRegistry& registry) {
if (use_function) {
const char* fn_name = "test_fn";
std::unique_ptr<AbstractFunction> scoped_func;
// Returning null tensors from a tf.function is not supported, so we keep
// track of which indices in the model's outputs are nullptr in this set.
// The FunctionDef only outputs the non-null tensors. We later pad the
// function op outputs to have nullptrs at the `null_indices`.
absl::flat_hash_set<int> null_indices;
{
AbstractContextPtr func_ctx(BuildFunction(fn_name));
vector<AbstractTensorHandle*> func_inputs;
func_inputs.reserve(inputs.size());
TF_RETURN_IF_ERROR(
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
vector<AbstractTensorHandle*> model_outputs;
model_outputs.resize(outputs.size());
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
absl::MakeSpan(model_outputs), registry));
for (auto func_input : func_inputs) {
func_input->Unref();
}
AbstractFunction* func = nullptr;
OutputList output_list;
output_list.expected_num_outputs = 0;
output_list.outputs.reserve(outputs.size());
for (int i = 0; i < model_outputs.size(); i++) {
if (model_outputs[i]) {
output_list.outputs.emplace_back(model_outputs[i]);
output_list.expected_num_outputs += 1;
} else {
null_indices.insert(i);
}
}
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
->Finalize(&output_list, &func));
scoped_func.reset(func);
for (auto output : output_list.outputs) {
output->Unref();
}
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
}
AbstractOperationPtr fn_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
for (auto input : inputs) {
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
}
int retvals = outputs.size() - null_indices.size();
vector<AbstractTensorHandle*> fn_outputs(retvals);
TF_RETURN_IF_ERROR(fn_op->Execute(
absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
&retvals));
int skipped_indices = 0;
for (int i = 0; i < outputs.size(); i++) {
if (!null_indices.contains(i)) {
outputs[i] = fn_outputs[i - skipped_indices];
} else {
skipped_indices += 1;
}
}
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
return Status::OK();
} else {
return model(ctx, inputs, outputs, registry);
}
}
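// Worked example of the null-output padding above (illustrative): if the
// model produces {t0, nullptr, t2}, the traced function is finalized with
// outputs {t0, t2} and null_indices = {1}; after execution the results are
// padded back to {fn_outputs[0], nullptr, fn_outputs[1]}.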
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
*ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_DeleteContextOptions(opts);
return Status::OK();
}

View File

@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
#define TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
#include <memory> #include <memory>
#include "absl/types/span.h" #include "absl/types/span.h"
@ -24,50 +26,13 @@ limitations under the License.
#include "tensorflow/c/experimental/ops/array_ops.h" #include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h" #include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h" #include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/status.h"
using namespace tensorflow;
using namespace tensorflow::gradients;
using namespace tensorflow::gradients::internal;
// ========================== Tape Ops ============================== namespace tensorflow {
namespace gradients {
// Computes `inputs[0] + inputs[1]` and records it on the tape. namespace internal {
Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
Status MatMul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
bool transpose_a, bool transpose_b,
const GradientRegistry& registry);
// Computes `inputs[0] * inputs[1]` and records it on the tape.
Status Mul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry);
// Computes `Relu(inputs[0])` and records it on the tape.
Status Relu(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry);
// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
// tape.
Status SparseSoftmaxCrossEntropyLoss(
AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry);
// ====================== End Tape Ops ============================
// Computes // Computes
// y = inputs[0] + inputs[1] // y = inputs[0] + inputs[1]
@ -121,26 +86,23 @@ Status ScalarMulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> outputs, absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry); const GradientRegistry& registry);
// Updates the weights for a neural network given incoming grads and learning Status MatMulModel(AbstractContext* ctx,
// rate absl::Span<AbstractTensorHandle* const> inputs,
Status UpdateWeights(AbstractContext* ctx, absl::Span<AbstractTensorHandle*> outputs,
std::vector<AbstractTensorHandle*>& grads, const GradientRegistry& registry);
std::vector<AbstractTensorHandle*>& weights,
AbstractTensorHandle* learning_rate);
AbstractContext* BuildFunction(const char* fn_name); Status MulModel(AbstractContext* ctx,
Status CreateParamsForInputs(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
std::vector<AbstractTensorHandle*>* params);
using Model = std::function<Status(
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
Status RunModel(Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, bool use_function, absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry); const GradientRegistry& registry);
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx); Status SoftmaxModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
} // namespace internal
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_

View File

@ -1,3 +1,5 @@
load("//tensorflow:tensorflow.bzl", "filegroup")
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load( load(
"//tensorflow:tensorflow.bzl", "//tensorflow:tensorflow.bzl",
"tf_cc_test", "tf_cc_test",
@ -103,7 +105,6 @@ cc_library(
hdrs = ["parallel_device_testlib.h"], hdrs = ["parallel_device_testlib.h"],
deps = [ deps = [
":parallel_device", ":parallel_device",
":parallel_device_ops",
"//tensorflow/c:c_api", "//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental", "//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api",
@ -118,7 +119,6 @@ tf_cc_test(
srcs = ["parallel_device_test.cc"], srcs = ["parallel_device_test.cc"],
deps = [ deps = [
":parallel_device", ":parallel_device",
":parallel_device_ops",
":parallel_device_testlib", ":parallel_device_testlib",
"//tensorflow/c:c_api", "//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental", "//tensorflow/c:c_api_experimental",
@ -138,7 +138,6 @@ tf_cc_test(
args = ["--heap_check=local"], args = ["--heap_check=local"],
deps = [ deps = [
":parallel_device", ":parallel_device",
":parallel_device_ops",
":parallel_device_testlib", ":parallel_device_testlib",
"//tensorflow/c:c_api", "//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental", "//tensorflow/c:c_api_experimental",
@ -150,19 +149,3 @@ tf_cc_test(
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
], ],
) )
# Note: ParallelDevice-specific ops are experimental and not currently linked in
# to TensorFlow by default, just used in a few tests.
filegroup(
name = "parallel_device_ops_srcs",
srcs = ["parallel_device_ops.cc"],
visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
)
cc_library(
name = "parallel_device_ops",
srcs = [":parallel_device_ops_srcs"],
visibility = ["//tensorflow:internal"],
deps = ["//tensorflow/core:framework"],
alwayslink = 1,
)

View File

@ -136,13 +136,6 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
} }
result.emplace(std::move(outputs)); result.emplace(std::move(outputs));
return result; return result;
} else if (operation_name == std::string("DeviceID")) {
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
result_content.push_back(parallel_device.DeviceIDs(context, status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
} }
std::vector<ParallelTensor*> parallel_inputs; std::vector<ParallelTensor*> parallel_inputs;
std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors; std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
@ -255,28 +248,44 @@ TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
// Since this function is used to satisfy the TFE_CustomDevice C API, // Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a // device_info is passed in using a C-style generic. It must always be a
// ParallelDevice. // ParallelDevice.
void ParallelDeviceExecute(TFE_Context* context, int num_inputs, void ParallelDeviceExecute(const TFE_Op* original_op, int* num_outputs,
TFE_TensorHandle** inputs,
const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* status, TFE_TensorHandle** outputs, TF_Status* status,
void* device_info) { void* device_info) {
const char* requested_placement = TFE_OpGetDevice(original_op, status);
if (*requested_placement == '\0') {
TF_SetStatus(
status, TF_INTERNAL,
"Ops must be placed on the parallel device explicitly, or their inputs "
"first un-packed. Got an un-placed op with an input placed on the "
"parallel device.");
return;
}
TFE_Context* context = TFE_OpGetContext(original_op, status);
if (TF_GetCode(status) != TF_OK) return;
const char* operation_name = TFE_OpGetName(original_op, status);
if (TF_GetCode(status) != TF_OK) return;
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
NamedParallelDevice* named_device = NamedParallelDevice* named_device =
reinterpret_cast<NamedParallelDevice*>(device_info); reinterpret_cast<NamedParallelDevice*>(device_info);
std::vector<MaybeParallelTensorUnowned> typed_inputs; std::vector<MaybeParallelTensorUnowned> typed_inputs;
int num_inputs = TFE_OpGetFlatInputCount(original_op, status);
if (TF_GetCode(status) != TF_OK) return;
typed_inputs.reserve(num_inputs); typed_inputs.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) { for (int i = 0; i < num_inputs; ++i) {
TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, i, status);
if (TF_GetCode(status) != TF_OK) return;
const char* tensor_handle_device = const char* tensor_handle_device =
TFE_TensorHandleDeviceName(inputs[i], status); TFE_TensorHandleDeviceName(input, status);
if (TF_GetCode(status) != TF_OK) return; if (TF_GetCode(status) != TF_OK) return;
if (named_device->name() == tensor_handle_device) { if (named_device->name() == tensor_handle_device) {
// We assume that any tensors already placed on this device are // We assume that any tensors already placed on this device are
// ParallelTensors. // ParallelTensors.
typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>( typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
TFE_TensorHandleDevicePointer(inputs[i], status))); TFE_TensorHandleDevicePointer(input, status)));
if (TF_GetCode(status) != TF_OK) return; if (TF_GetCode(status) != TF_OK) return;
} else { } else {
typed_inputs.emplace_back(inputs[i]); typed_inputs.emplace_back(input);
} }
} }

View File

@ -58,7 +58,7 @@ using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
class DeviceThread { class DeviceThread {
public: public:
// Starts a background thread waiting for `StartExecute`. // Starts a background thread waiting for `StartExecute`.
explicit DeviceThread(const std::string& device) explicit DeviceThread(const std::string& device, const bool is_async)
: status_(TF_NewStatus()), : status_(TF_NewStatus()),
device_(device), device_(device),
// If the context's default exector is set to async, re-using that in // If the context's default exector is set to async, re-using that in
@ -67,7 +67,7 @@ class DeviceThread {
// //
// TODO(allenl): We should have an async API that works with the // TODO(allenl): We should have an async API that works with the
// parallel device. // parallel device.
executor_(TFE_NewExecutor(/*is_async=*/false)), executor_(TFE_NewExecutor(is_async)),
op_(nullptr), op_(nullptr),
thread_(tensorflow::Env::Default()->StartThread( thread_(tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "parallel_device_execute", tensorflow::ThreadOptions(), "parallel_device_execute",
@ -236,12 +236,13 @@ void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
} }
} }
ParallelDevice::ParallelDevice(const std::vector<std::string>& devices) ParallelDevice::ParallelDevice(const std::vector<std::string>& devices,
const bool is_async)
: underlying_devices_(devices) { : underlying_devices_(devices) {
device_threads_.reserve(devices.size()); device_threads_.reserve(devices.size());
for (int device_index = 0; device_index < devices.size(); ++device_index) { for (int device_index = 0; device_index < devices.size(); ++device_index) {
device_threads_.emplace_back( device_threads_.emplace_back(
new DeviceThread(devices[device_index].c_str())); new DeviceThread(devices[device_index].c_str(), is_async));
} }
} }

View File

@ -49,7 +49,10 @@ class DeviceThread;
// placed on each underlying device. // placed on each underlying device.
class ParallelDevice { class ParallelDevice {
public: public:
explicit ParallelDevice(const std::vector<std::string>& devices); // Eager async execution is only supported when remote eager is not in use
// (b/157523095).
explicit ParallelDevice(const std::vector<std::string>& devices,
const bool is_async = false);
~ParallelDevice(); ~ParallelDevice();

View File

@ -279,30 +279,4 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device); ASSERT_EQ(underlying_devices[1], second_device);
} }
// Compute the device ID twice and verify the result
for (int i = 0; i < 2; ++i) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetDevice(op.get(), device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, result_handle, &components, status.get());
TFE_DeleteTensorHandle(result_handle);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<int32_t>(components[0].get(), 0);
ExpectScalarEq<int32_t>(components[1].get(), 1);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
} }

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
namespace tensorflow { namespace tensorflow {
@ -98,6 +99,10 @@ class VSpace {
gtl::ArraySlice<Gradient*> output_gradients, gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) const = 0; std::vector<Gradient*>* result) const = 0;
// Builds a tensor filled with ones with the same shape and dtype as `t`.
virtual Status BuildOnesLike(const TapeTensor& t,
Gradient** result) const = 0;
// Looks up the ID of a Gradient. // Looks up the ID of a Gradient.
virtual int64 TensorId(Gradient* tensor) const = 0; virtual int64 TensorId(Gradient* tensor) const = 0;
@ -121,7 +126,7 @@ class GradientTape {
// functions (and hence the tensors they keep alive). Instead, everything // functions (and hence the tensors they keep alive). Instead, everything
// is deleted in ~GradientTape. Persistent GradientTapes are useful when // is deleted in ~GradientTape. Persistent GradientTapes are useful when
// users want to compute multiple gradients over the same tape. // users want to compute multiple gradients over the same tape.
GradientTape(bool persistent) : persistent_(persistent) {} explicit GradientTape(bool persistent) : persistent_(persistent) {}
~GradientTape() { ~GradientTape() {
for (const auto& pair : op_tape_) { for (const auto& pair : op_tape_) {
pair.second.backward_function_deleter(pair.second.backward_function); pair.second.backward_function_deleter(pair.second.backward_function);
@ -595,8 +600,10 @@ Status InitialGradients(
for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
if (op_it->second.output_tensor_info[j].GetID() == id) { if (op_it->second.output_tensor_info[j].GetID() == id) {
found = true; found = true;
(*result)[id].push_back( Gradient* ones_like = nullptr;
op_it->second.output_tensor_info[j].OnesLike()); TF_RETURN_IF_ERROR(vspace.BuildOnesLike(
op_it->second.output_tensor_info[j], &ones_like));
(*result)[id].push_back(ones_like);
break; break;
} }
} }
@ -611,7 +618,10 @@ Status InitialGradients(
// target is also a source. // target is also a source.
auto source_tensor = sources_that_are_targets.find(id); auto source_tensor = sources_that_are_targets.find(id);
if (source_tensor != sources_that_are_targets.end()) { if (source_tensor != sources_that_are_targets.end()) {
(*result)[id].push_back(source_tensor->second.OnesLike()); Gradient* ones_like = nullptr;
TF_RETURN_IF_ERROR(
vspace.BuildOnesLike(source_tensor->second, &ones_like));
(*result)[id].push_back(ones_like);
} }
} }
} else { } else {
@ -934,7 +944,7 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
// TODO(allenl): Figure out why using zeros_like everywhere causes issues // TODO(allenl): Figure out why using zeros_like everywhere causes issues
// for some gradient functions and if there's another way to work around // for some gradient functions and if there's another way to work around
// it (e.g. conds instead of ifs). The value shouldn't really matter. // it (e.g. conds instead of ifs). The value shouldn't really matter.
aid = output_tensor.OnesLike(); TF_RETURN_IF_ERROR(vspace_.BuildOnesLike(output_tensor, &aid));
} }
if (TF_PREDICT_FALSE(aid == nullptr)) { if (TF_PREDICT_FALSE(aid == nullptr)) {
return tensorflow::errors::Internal( return tensorflow::errors::Internal(

View File

@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/tracing_utils.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace tracing {
Status MaybeSetOpName(AbstractOperation* op, const char* op_name) {
if (isa<TracingOperation>(op)) {
TF_RETURN_IF_ERROR(dyn_cast<TracingOperation>(op)->SetOpName(op_name));
}
if (isa<gradients::TapeOperation>(op)) {
TF_RETURN_IF_ERROR(MaybeSetOpName(
dyn_cast<gradients::TapeOperation>(op)->GetBackingOperation(),
op_name));
}
return Status::OK();
}
} // namespace tracing
} // namespace tensorflow
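// Usage sketch (illustrative; assumes `ctx` is an AbstractContext*): names an
// op only when a tracing or tape context is active, and is a no-op otherwise:
//   AbstractOperationPtr op(ctx->CreateOperation());
//   TF_RETURN_IF_ERROR(op->Reset("MatMul", /*raw_device_name=*/nullptr));
//   TF_RETURN_IF_ERROR(tracing::MaybeSetOpName(op.get(), "matmul0"));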

View File

@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_UTILS_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_UTILS_H_
#include "tensorflow/c/eager/abstract_operation.h"
namespace tensorflow {
namespace tracing {
Status MaybeSetOpName(AbstractOperation*, const char* op_name);
} // namespace tracing
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_UTILS_H_

View File

@ -1,3 +1,5 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
# Experimental filesystem C APIs for TensorFlow. # Experimental filesystem C APIs for TensorFlow.
# Will be moved in proper place once all filesystems are converted to the # Will be moved in proper place once all filesystems are converted to the
# modular framework. # modular framework.

View File

@ -1,3 +1,5 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
# Experimental gcs filesystem plugin. # Experimental gcs filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test")
@ -29,6 +31,7 @@ cc_library(
":gcs_helper", ":gcs_helper",
":ram_file_block_cache", ":ram_file_block_cache",
"//tensorflow/c:env", "//tensorflow/c:env",
"//tensorflow/c:logging",
"//tensorflow/c:tf_status", "//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface", "//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client", "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
@ -59,6 +62,7 @@ cc_library(
deps = [ deps = [
":cleanup", ":cleanup",
"//tensorflow/c:env", "//tensorflow/c:env",
"//tensorflow/c:logging",
"//tensorflow/c:tf_status", "//tensorflow/c:tf_status",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "google/cloud/storage/client.h" #include "google/cloud/storage/client.h"
#include "tensorflow/c/env.h" #include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h" #include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
#include "tensorflow/c/logging.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for GCS environments. // Implementation of a filesystem for GCS environments.
@ -120,20 +121,20 @@ static int64_t LoadBufferFromGCS(const std::string& path, size_t offset,
return -1;
}
int64_t read;
auto content_length = stream.headers().find("content-length");
if (content_length == stream.headers().end()) {
// When we read a file with an offset bigger than the actual file size,
// GCS will return an empty header (e.g. no `content-length` header). In
// this case, we set read to `0` and continue.
read = 0;
} else if (!absl::SimpleAtoi(content_length->second, &read)) {
TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header");
return -1;
}
// `TF_OUT_OF_RANGE` isn't considered as an error. So we clear it here.
TF_SetStatus(status, TF_OK, "");
TF_VLog(1, "Successful read of %s @ %u of size: %u", path.c_str(), offset,
read);
stream.read(buffer, read);
read = stream.gcount();
if (read < buffer_size) {
@ -146,6 +147,8 @@ static int64_t LoadBufferFromGCS(const std::string& path, size_t offset,
path, " @ ", offset) path, " @ ", offset)
.c_str()); .c_str());
} }
TF_VLog(2, "Successful integrity check for: %s @ %u", path.c_str(),
offset);
} }
} }
return read; return read;
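
Since the branching above hinges on it: `absl::SimpleAtoi` reports malformed input by returning `false` rather than throwing, which is why the missing-header case must be distinguished from the parse-failure case before parsing. A standalone sketch of the same logic (names are illustrative):

#include <cstdint>
#include "absl/strings/numbers.h"

// Mirrors the branch structure above: absent header => 0 bytes,
// malformed header => parse error, otherwise the parsed length.
bool ParseContentLength(const char* header_value, int64_t* length) {
  if (header_value == nullptr) {  // header absent (e.g. out-of-range read)
    *length = 0;
    return true;
  }
  return absl::SimpleAtoi(header_value, length);  // false on malformed value
}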
@ -259,7 +262,8 @@ static void SyncImpl(const std::string& bucket, const std::string& object,
if (*offset == -1 || *offset == 0) {
// UploadFile will automatically switch to resumable upload based on Client
// configuration.
auto metadata = gcs_client->UploadFile(outfile->getName(), bucket, object,
gcs::Fields("size"));
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
@ -278,15 +282,18 @@ static void SyncImpl(const std::string& bucket, const std::string& object,
} else {
std::string temporary_object =
gcs::CreateRandomPrefixName("tf_writable_file_gcs");
auto metadata = gcs_client->UploadFile(outfile->getName(), bucket,
temporary_object, gcs::Fields(""));
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
}
TF_VLog(3, "AppendObject: gs://%s/%s to gs://%s/%s", bucket.c_str(),
temporary_object.c_str(), bucket.c_str(), object.c_str());
const std::vector<gcs::ComposeSourceObject> source_objects = {
{object, {}, {}}, {temporary_object, {}, {}}};
metadata = gcs_client->ComposeObject(bucket, source_objects, object,
gcs::Fields("size"));
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
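
The `gcs::Fields(...)` arguments added throughout this file are google-cloud-cpp's partial-response option: the server returns only the listed metadata fields (an empty `Fields("")` requests none at all), which trims response payloads on hot paths. A sketch under assumed bucket/object names:

#include "google/cloud/storage/client.h"
#include "tensorflow/c/logging.h"

namespace gcs = google::cloud::storage;

// Hypothetical call: only the `size` field is requested and transferred.
void PrintObjectSize(gcs::Client& client) {
  auto metadata =
      client.GetObjectMetadata("my-bucket", "my-object", gcs::Fields("size"));
  if (metadata) TF_VLog(1, "size = %u", metadata->size());
}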
@ -321,6 +328,8 @@ void Append(const TF_WritableFile* file, const char* buffer, size_t n,
"The internal temporary file is not writable."); "The internal temporary file is not writable.");
return; return;
} }
TF_VLog(3, "Append: gs://%s/%s size %u", gcs_file->bucket.c_str(),
gcs_file->object.c_str(), n);
gcs_file->sync_need = true; gcs_file->sync_need = true;
gcs_file->outfile.write(buffer, n); gcs_file->outfile.write(buffer, n);
if (!gcs_file->outfile) if (!gcs_file->outfile)
@ -346,6 +355,8 @@ int64_t Tell(const TF_WritableFile* file, TF_Status* status) {
void Flush(const TF_WritableFile* file, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
if (gcs_file->sync_need) {
TF_VLog(3, "Flush started: gs://%s/%s", gcs_file->bucket.c_str(),
gcs_file->object.c_str());
if (!gcs_file->outfile) {
TF_SetStatus(status, TF_INTERNAL,
"Could not append to the internal temporary file.");
@ -353,6 +364,8 @@ void Flush(const TF_WritableFile* file, TF_Status* status) {
}
SyncImpl(gcs_file->bucket, gcs_file->object, &gcs_file->offset,
&gcs_file->outfile, gcs_file->gcs_client, status);
TF_VLog(3, "Flush finished: gs://%s/%s", gcs_file->bucket.c_str(),
gcs_file->object.c_str());
if (TF_GetCode(status) != TF_OK) return;
gcs_file->sync_need = false;
} else {
@ -361,11 +374,16 @@ void Flush(const TF_WritableFile* file, TF_Status* status) {
}
void Sync(const TF_WritableFile* file, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
TF_VLog(3, "Sync: gs://%s/%s", gcs_file->bucket.c_str(),
gcs_file->object.c_str());
Flush(file, status);
}
void Close(const TF_WritableFile* file, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
TF_VLog(3, "Close: gs://%s/%s", gcs_file->bucket.c_str(),
gcs_file->object.c_str());
if (gcs_file->sync_need) {
Flush(file, status);
}
@ -428,6 +446,8 @@ GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client)
if (absl::SimpleAtoi(std::getenv(kMaxStaleness), &value)) {
max_staleness = value;
}
TF_VLog(1, "GCS cache max size = %u ; block size = %u ; max staleness = %u",
max_bytes, block_size, max_staleness);
file_block_cache = std::make_unique<RamFileBlockCache>(
block_size, max_bytes, max_staleness,
@ -504,13 +524,18 @@ void Cleanup(TF_Filesystem* filesystem) {
static void UncachedStatForObject(const std::string& bucket,
const std::string& object, GcsFileStat* stat,
gcs::Client* gcs_client, TF_Status* status) {
auto metadata = gcs_client->GetObjectMetadata(
bucket, object, gcs::Fields("generation,size,timeStorageClassUpdated"));
if (!metadata) return TF_SetStatusFromGCSStatus(metadata.status(), status);
stat->generation_number = metadata->generation();
stat->base.length = metadata->size();
stat->base.mtime_nsec =
metadata->time_storage_class_updated().time_since_epoch().count();
stat->base.is_directory = object.back() == '/';
TF_VLog(1,
"Stat of: gs://%s/%s -- length: %u generation: %u; mtime_nsec: %u;",
bucket.c_str(), object.c_str(), stat->base.length,
stat->generation_number, stat->base.mtime_nsec);
return TF_SetStatus(status, TF_OK, "");
}
@ -545,9 +570,10 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
if (TF_GetCode(status) != TF_OK) return -1;
if (!gcs_file->file_block_cache->ValidateAndUpdateFileSignature(
path, stat.generation_number)) {
TF_VLog(
1,
"File signature has been changed. Refreshing the cache. Path: %s",
path.c_str());
}
read = gcs_file->file_block_cache->Read(path, offset, n, buffer, status);
} else {
@ -579,6 +605,7 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
(gcs_file->compose ? 0 : -1)});
// We are responsible for freeing the pointer returned by TF_GetTempFileName
free(temp_file_name);
TF_VLog(3, "GcsWritableFile: %s", path);
TF_SetStatus(status, TF_OK, "");
}
@ -608,7 +635,8 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
} else {
// If compose is true, we do not download anything.
// Instead we only check if this file exists on server or not.
auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object,
gcs::Fields("size"));
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) == TF_OK) {
file->plugin_file = new tf_writable_file::GCSFile(
@ -624,7 +652,8 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
return;
}
}
TF_VLog(3, "GcsWritableFile: %s with existing file %s", path,
temp_file_name.c_str());
TF_SetStatus(status, TF_OK, "");
}
@ -639,7 +668,8 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object,
gcs::Fields("size"));
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
@ -670,7 +700,8 @@ static void StatForObject(GCSFile* gcs_file, const std::string& path,
if (object.empty())
return TF_SetStatus(
status, TF_INVALID_ARGUMENT,
absl::StrCat("'object' must be a non-empty string. (File: ", path, ")")
.c_str());
TF_SetStatus(status, TF_OK, "");
gcs_file->stat_cache->LookupOrCompute(
path, stat,
@ -698,7 +729,8 @@ static bool ObjectExists(GCSFile* gcs_file, const std::string& path,
static bool BucketExists(GCSFile* gcs_file, const std::string& bucket,
TF_Status* status) {
auto metadata =
gcs_file->gcs_client.GetBucketMetadata(bucket, gcs::Fields(""));
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_NOT_FOUND)
return false;
@ -721,7 +753,8 @@ static std::vector<std::string> GetChildrenBounded(
std::string delimiter = recursive ? "" : "/"; std::string delimiter = recursive ? "" : "/";
for (auto&& item : gcs_file->gcs_client.ListObjectsAndPrefixes( for (auto&& item : gcs_file->gcs_client.ListObjectsAndPrefixes(
bucket, gcs::Prefix(prefix), gcs::Delimiter(delimiter))) { bucket, gcs::Prefix(prefix), gcs::Delimiter(delimiter),
gcs::Fields("items(name),prefixes"))) {
if (count == max_results) { if (count == max_results) {
TF_SetStatus(status, TF_OK, ""); TF_SetStatus(status, TF_OK, "");
return result; return result;
@ -737,8 +770,8 @@ static std::vector<std::string> GetChildrenBounded(
auto pos = children.find(prefix);
if (pos != 0) {
TF_SetStatus(status, TF_INTERNAL,
absl::StrCat("Unexpected response: the returned file name ",
children, " doesn't match the prefix ", prefix)
.c_str());
return result;
}
@ -812,6 +845,10 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string dir = path;
MaybeAppendSlash(&dir);
TF_VLog(3,
"CreateDir: creating directory with path: %s and "
"path_with_slash: %s",
path, dir.c_str());
std::string bucket, object;
ParseGCSPath(dir, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
@ -821,19 +858,23 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,
if (TF_GetCode(status) != TF_OK) return;
if (!is_directory)
TF_SetStatus(status, TF_NOT_FOUND,
absl::StrCat("The specified bucket ", dir, " was not found.")
.c_str());
return;
}
PathExists(filesystem, dir.c_str(), status);
if (TF_GetCode(status) == TF_OK) {
// Use the original name for a correct error here.
TF_VLog(3, "CreateDir: directory already exists, not uploading %s", path);
return TF_SetStatus(status, TF_ALREADY_EXISTS, path);
}
auto metadata = gcs_file->gcs_client.InsertObject(
bucket, object, "",
// Adding this parameter means HTTP_CODE_PRECONDITION_FAILED
// will be returned if the object already exists, so avoid reuploading.
gcs::IfGenerationMatch(0), gcs::Fields(""));
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) == TF_FAILED_PRECONDITION)
TF_SetStatus(status, TF_ALREADY_EXISTS, path);
@ -891,7 +932,8 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto metadata = gcs_file->gcs_client.RewriteObjectBlocking(
bucket_src, object_src, bucket_dst, object_dst,
gcs::Fields("done,rewriteToken"));
TF_SetStatusFromGCSStatus(metadata.status(), status);
}
@ -908,7 +950,8 @@ bool IsDirectory(const TF_Filesystem* filesystem, const char* path,
if (!result)
TF_SetStatus(
status, TF_NOT_FOUND,
absl::StrCat("The specified bucket gs://", bucket, " was not found.")
.c_str());
return result;
}
@ -933,6 +976,7 @@ bool IsDirectory(const TF_Filesystem* filesystem, const char* path,
static void RenameObject(const TF_Filesystem* filesystem,
const std::string& src, const std::string& dst,
TF_Status* status) {
TF_VLog(3, "RenameObject: started %s to %s", src.c_str(), dst.c_str());
std::string bucket_src, object_src;
ParseGCSPath(src, false, &bucket_src, &object_src, status);
if (TF_GetCode(status) != TF_OK) return;
@ -943,9 +987,11 @@ static void RenameObject(const TF_Filesystem* filesystem,
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto metadata = gcs_file->gcs_client.RewriteObjectBlocking(
bucket_src, object_src, bucket_dst, object_dst,
gcs::Fields("done,rewriteToken"));
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) != TF_OK) return;
TF_VLog(3, "RenameObject: finished %s to %s", src.c_str(), dst.c_str());
ClearFileCaches(gcs_file, dst);
DeleteFile(filesystem, src.c_str(), status);
@ -954,8 +1000,10 @@ static void RenameObject(const TF_Filesystem* filesystem,
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
if (!IsDirectory(filesystem, src, status)) {
if (TF_GetCode(status) == TF_FAILED_PRECONDITION) {
TF_SetStatus(status, TF_OK, "");
RenameObject(filesystem, src, dst, status);
}
return;
}
@ -1032,7 +1080,8 @@ void Stat(const TF_Filesystem* filesystem, const char* path,
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
if (object.empty()) {
auto bucket_metadata =
gcs_file->gcs_client.GetBucketMetadata(bucket, gcs::Fields(""));
TF_SetStatusFromGCSStatus(bucket_metadata.status(), status);
if (TF_GetCode(status) == TF_OK) {
stats->is_directory = true;
@ -1047,8 +1096,9 @@ void Stat(const TF_Filesystem* filesystem, const char* path,
stats->mtime_nsec = 0;
return TF_SetStatus(status, TF_OK, "");
}
if (TF_GetCode(status) == TF_FAILED_PRECONDITION) {
auto metadata = gcs_file->gcs_client.GetObjectMetadata(
bucket, object, gcs::Fields("size,timeStorageClassUpdated"));
if (metadata) {
stats->is_directory = false;
stats->length = metadata.value().size();
@ -1061,6 +1111,18 @@ void Stat(const TF_Filesystem* filesystem, const char* path,
}
}
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
// Only validate the name.
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return -1;
TF_FileStatistics stat;
Stat(filesystem, path, &stat, status);
return stat.length;
}
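
A hedged usage sketch for the new entry point (direct call shown only for illustration; production code reaches it through the registered `TF_FilesystemOps` table, and the helper name is hypothetical):

// Illustrative only: returns -1 on failure, the object size otherwise.
int64_t SizeOrNegativeOne(const TF_Filesystem* filesystem, const char* path,
                          TF_Status* status) {
  int64_t size = tf_gcs_filesystem::GetFileSize(filesystem, path, status);
  return TF_GetCode(status) == TF_OK ? size : -1;
}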
static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) {
return strdup(uri);
}

View File

@ -87,6 +87,24 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
const char* path,
TF_ReadOnlyMemoryRegion* region,
TF_Status* status);
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status);
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status);
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
TF_Status* status);
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status);
} // namespace tf_gcs_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x) #define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)
#define EXPECT_TF_OK(x) EXPECT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)
static const char* content = "abcdefghijklmnopqrstuvwxyz1234567890"; static const char* content = "abcdefghijklmnopqrstuvwxyz1234567890";
// We will work with content_view instead of content. // We will work with content_view instead of content.
@ -94,6 +95,70 @@ class GCSFilesystemTest : public ::testing::Test {
return translated_name;
}
std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile* file)>
GetWriter() {
std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile * file)> writer(
new TF_WritableFile, [](TF_WritableFile* file) {
if (file != nullptr) {
if (file->plugin_file != nullptr) tf_writable_file::Cleanup(file);
delete file;
}
});
writer->plugin_file = nullptr;
return writer;
}
std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile* file)>
GetReader() {
std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile * file)>
reader(new TF_RandomAccessFile, [](TF_RandomAccessFile* file) {
if (file != nullptr) {
if (file->plugin_file != nullptr)
tf_random_access_file::Cleanup(file);
delete file;
}
});
reader->plugin_file = nullptr;
return reader;
}
void WriteString(const std::string& path, const std::string& content) {
auto writer = GetWriter();
tf_gcs_filesystem::NewWritableFile(filesystem_, path.c_str(), writer.get(),
status_);
if (TF_GetCode(status_) != TF_OK) return;
tf_writable_file::Append(writer.get(), content.c_str(), content.length(),
status_);
if (TF_GetCode(status_) != TF_OK) return;
tf_writable_file::Close(writer.get(), status_);
if (TF_GetCode(status_) != TF_OK) return;
}
std::string ReadAll(const std::string& path) {
auto reader = GetReader();
tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
reader.get(), status_);
if (TF_GetCode(status_) != TF_OK) return "";
auto file_size =
tf_gcs_filesystem::GetFileSize(filesystem_, path.c_str(), status_);
if (TF_GetCode(status_) != TF_OK) return "";
std::string content;
content.resize(file_size);
auto read = tf_random_access_file::Read(reader.get(), 0, file_size,
&content[0], status_);
if (TF_GetCode(status_) != TF_OK) return "";
if (read >= 0) content.resize(read);
if (file_size != content.size())
TF_SetStatus(
status_, TF_DATA_LOSS,
std::string("expected " + std::to_string(file_size) + " got " +
std::to_string(content.size()) + " bytes")
.c_str());
return content;
}
protected:
TF_Filesystem* filesystem_;
TF_Status* status_;
@ -326,6 +391,145 @@ TEST_F(GCSFilesystemTest, ReadOnlyMemoryRegion) {
delete region;
}
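
The `GetWriter()`/`GetReader()` helpers above rely on `std::unique_ptr` with a captureless-lambda deleter so the plugin's `Cleanup` runs on every exit path, including early `return`s on error. The same pattern in isolation (the alias and factory names are illustrative):

#include <memory>

using WriterPtr = std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile*)>;

WriterPtr MakeWriter() {
  return WriterPtr(new TF_WritableFile{nullptr},  // plugin_file starts null
                   [](TF_WritableFile* file) {
                     if (file->plugin_file != nullptr)
                       tf_writable_file::Cleanup(file);
                     delete file;
                   });
}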
TEST_F(GCSFilesystemTest, PathExists) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_);
const std::string path = GetURIForPath("PathExists");
tf_gcs_filesystem::PathExists(filesystem_, path.c_str(), status_);
EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status_)) << TF_Message(status_);
TF_SetStatus(status_, TF_OK, "");
WriteString(path, "test");
ASSERT_TF_OK(status_);
tf_gcs_filesystem::PathExists(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
}
TEST_F(GCSFilesystemTest, GetChildren) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_);
const std::string base = GetURIForPath("GetChildren");
tf_gcs_filesystem::CreateDir(filesystem_, base.c_str(), status_);
EXPECT_TF_OK(status_);
const std::string file = io::JoinPath(base, "TestFile.csv");
WriteString(file, "test");
EXPECT_TF_OK(status_);
const std::string subdir = io::JoinPath(base, "SubDir");
tf_gcs_filesystem::CreateDir(filesystem_, subdir.c_str(), status_);
EXPECT_TF_OK(status_);
const std::string subfile = io::JoinPath(subdir, "TestSubFile.csv");
WriteString(subfile, "test");
EXPECT_TF_OK(status_);
char** entries;
auto num_entries = tf_gcs_filesystem::GetChildren(filesystem_, base.c_str(),
&entries, status_);
EXPECT_TF_OK(status_);
std::vector<std::string> childrens;
for (int i = 0; i < num_entries; ++i) {
childrens.push_back(entries[i]);
}
std::sort(childrens.begin(), childrens.end());
EXPECT_EQ(std::vector<string>({"SubDir/", "TestFile.csv"}), childrens);
}
TEST_F(GCSFilesystemTest, DeleteFile) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_);
const std::string path = GetURIForPath("DeleteFile");
WriteString(path, "test");
ASSERT_TF_OK(status_);
tf_gcs_filesystem::DeleteFile(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
tf_gcs_filesystem::PathExists(filesystem_, path.c_str(), status_);
EXPECT_EQ(TF_GetCode(status_), TF_NOT_FOUND);
}
TEST_F(GCSFilesystemTest, CreateDir) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_);
const std::string dir = GetURIForPath("CreateDir");
tf_gcs_filesystem::CreateDir(filesystem_, dir.c_str(), status_);
EXPECT_TF_OK(status_);
TF_FileStatistics stat;
tf_gcs_filesystem::Stat(filesystem_, dir.c_str(), &stat, status_);
EXPECT_TF_OK(status_);
EXPECT_TRUE(stat.is_directory);
}
TEST_F(GCSFilesystemTest, DeleteDir) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_);
const std::string dir = GetURIForPath("DeleteDir");
const std::string file = io::JoinPath(dir, "DeleteDirFile.csv");
WriteString(file, "test");
ASSERT_TF_OK(status_);
tf_gcs_filesystem::DeleteDir(filesystem_, dir.c_str(), status_);
EXPECT_EQ(TF_GetCode(status_), TF_FAILED_PRECONDITION);
TF_SetStatus(status_, TF_OK, "");
tf_gcs_filesystem::DeleteFile(filesystem_, file.c_str(), status_);
EXPECT_TF_OK(status_);
tf_gcs_filesystem::DeleteDir(filesystem_, dir.c_str(), status_);
EXPECT_TF_OK(status_);
TF_FileStatistics stat;
tf_gcs_filesystem::Stat(filesystem_, dir.c_str(), &stat, status_);
EXPECT_EQ(TF_GetCode(status_), TF_NOT_FOUND) << TF_Message(status_);
}
TEST_F(GCSFilesystemTest, StatFile) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_);
const std::string path = GetURIForPath("StatFile");
WriteString(path, "test");
ASSERT_TF_OK(status_);
TF_FileStatistics stat;
tf_gcs_filesystem::Stat(filesystem_, path.c_str(), &stat, status_);
EXPECT_TF_OK(status_);
EXPECT_EQ(4, stat.length);
EXPECT_FALSE(stat.is_directory);
}
TEST_F(GCSFilesystemTest, RenameFile) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_);
const std::string src = GetURIForPath("RenameFileSrc");
const std::string dst = GetURIForPath("RenameFileDst");
WriteString(src, "test");
ASSERT_TF_OK(status_);
tf_gcs_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), status_);
EXPECT_TF_OK(status_);
auto result = ReadAll(dst);
EXPECT_TF_OK(status_);
EXPECT_EQ("test", result);
}
TEST_F(GCSFilesystemTest, RenameFileOverwrite) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_);
const std::string src = GetURIForPath("RenameFileOverwriteSrc");
const std::string dst = GetURIForPath("RenameFileOverwriteDst");
WriteString(src, "test_old");
ASSERT_TF_OK(status_);
WriteString(dst, "test_new");
ASSERT_TF_OK(status_);
tf_gcs_filesystem::PathExists(filesystem_, dst.c_str(), status_);
EXPECT_TF_OK(status_);
tf_gcs_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), status_);
EXPECT_TF_OK(status_);
auto result = ReadAll(dst);
EXPECT_TF_OK(status_);
EXPECT_EQ("test_old", result);
}
// These tests below are ported from
// `//tensorflow/core/platform/cloud:gcs_file_system_test`
TEST_F(GCSFilesystemTest, NewRandomAccessFile_NoBlockCache) {

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h" #include "absl/synchronization/notification.h"
#include "tensorflow/c/env.h" #include "tensorflow/c/env.h"
#include "tensorflow/c/logging.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
namespace tf_gcs_filesystem { namespace tf_gcs_filesystem {
@ -65,8 +66,8 @@ class RamFileBlockCache {
pruning_thread_.reset(
TF_StartThread(&thread_options, "TF_prune_FBC", PruneThread, this));
}
TF_VLog(1, "GCS file block cache is %s.\n",
(IsCacheEnabled() ? "enabled" : "disabled"));
}
~RamFileBlockCache() {

View File

@ -1,5 +1,7 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
# Experimental hadoop filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test")
package(
licenses = ["notice"],  # Apache 2.0
@ -20,12 +22,14 @@ cc_library(
name = "hadoop_filesystem_impl", name = "hadoop_filesystem_impl",
srcs = ["hadoop_filesystem.cc"], srcs = ["hadoop_filesystem.cc"],
hdrs = ["hadoop_filesystem.h"], hdrs = ["hadoop_filesystem.h"],
compatible_with = [],
copts = select({ copts = select({
"//conditions:default": [], "//conditions:default": [],
"//tensorflow:windows": get_win_copts(), "//tensorflow:windows": get_win_copts(),
}), }),
deps = [ deps = [
"//tensorflow/c:env", "//tensorflow/c:env",
"//tensorflow/c:logging",
"//tensorflow/c:tf_status", "//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface", "//tensorflow/c/experimental/filesystem:filesystem_interface",
"//third_party/hadoop:hdfs", "//third_party/hadoop:hdfs",
@ -33,3 +37,38 @@ cc_library(
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
], ],
) )
# This test is set to manual because it requires downloading the Hadoop
# distribution to run. To run this test:
# 1. Ensure $JAVA_HOME is set to the location of a JDK 8 installation.
# 2. Download the binary Hadoop distribution from:
# http://hadoop.apache.org/releases.html
# 3. Extract the Hadoop distribution and run:
# source libexec/hadoop-config.sh
# 4. Optionally set up HDFS cluster configurations (optionally Kerberos) within
# $HADOOP_HDFS_HOME/etc/hadoop if you want to test against real
# distributed HDFS cluster
# 5. bazel test \
# --test_env=LD_LIBRARY_PATH=$JAVA_HOME/jre/lib/amd64/server \
# --test_env=HADOOP_HDFS_HOME=$HADOOP_HDFS_HOME \
# --test_env=CLASSPATH=$($HADOOP_HDFS_HOME/bin/hadoop classpath --glob) \
# :hadoop_file_system_test
# To test against the real distributed cluster, add the following option for
# bazel test:
# --test_env=HADOOP_TEST_TMPDIR=hdfs://cluster/test/tmp/dir
tf_cc_test(
name = "hadoop_filesystem_test",
srcs = [
"hadoop_filesystem_test.cc",
],
tags = [
"manual",
"notap",
],
deps = [
":hadoop_filesystem_impl",
"//tensorflow/core/platform:path",
"//tensorflow/core/platform:stacktrace_handler",
"//tensorflow/core/platform:test",
],
)

View File

@ -22,11 +22,10 @@ limitations under the License.
#include <sstream>
#include <string>
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/logging.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for HADOOP environments.
// This filesystem will support `hdfs://`, `viewfs://` and `har://` URI schemes.
@ -37,11 +36,17 @@ static void plugin_memory_free(void* ptr) { free(ptr); }
void ParseHadoopPath(const std::string& fname, std::string* scheme,
std::string* namenode, std::string* path) {
size_t scheme_end = fname.find("://") + 2;
// We don't want `://` in scheme.
*scheme = fname.substr(0, scheme_end - 2);
size_t nn_end = fname.find("/", scheme_end + 1);
if (nn_end == std::string::npos) {
*namenode = fname.substr(scheme_end + 1);
*path = "";
return;
}
*namenode = fname.substr(scheme_end + 1, nn_end - scheme_end - 1);
// We keep `/` in path.
*path = fname.substr(nn_end);
}
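
Illustrative expectations for the reworked parser (sample URIs, not from the change itself): the scheme no longer includes `://`, a missing path now yields an empty string instead of leaving the outputs untouched, and the leading `/` stays in the path.

#include <cassert>
#include <string>

void ParseHadoopPathExamples() {
  std::string scheme, namenode, path;
  ParseHadoopPath("hdfs://namenode:8020/path/to/file", &scheme, &namenode,
                  &path);
  assert(scheme == "hdfs");             // no trailing `://`
  assert(namenode == "namenode:8020");
  assert(path == "/path/to/file");      // leading `/` preserved

  ParseHadoopPath("hdfs://namenode", &scheme, &namenode, &path);
  assert(namenode == "namenode" && path.empty());
}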
void SplitArchiveNameAndPath(std::string* path, std::string* nn, void SplitArchiveNameAndPath(std::string* path, std::string* nn,
@ -54,7 +59,7 @@ void SplitArchiveNameAndPath(std::string* path, std::string* nn,
}
// Case of hadoop archive. Namenode is the path to the archive.
std::ostringstream namenodestream;
namenodestream << "har://" << *nn
<< path->substr(0, index_end_archive_name + 4);
*nn = namenodestream.str();
path->erase(0, index_end_archive_name + 4);
@ -143,15 +148,20 @@ class LibHDFS {
char* hdfs_home = getenv("HADOOP_HDFS_HOME"); char* hdfs_home = getenv("HADOOP_HDFS_HOME");
if (hdfs_home != nullptr) { if (hdfs_home != nullptr) {
auto JoinPath = [](std::string home, std::string lib) { auto JoinPath = [](std::string home, std::string lib) {
#if defined(_WIN32)
if (home.back() != '\\') home.push_back('\\');
return home + "lib\\native\\" + lib;
#else
if (home.back() != '/') home.push_back('/'); if (home.back() != '/') home.push_back('/');
return home + "lib/native/" + lib; return home + "lib/native/" + lib;
#endif
}; };
std::string path = JoinPath(hdfs_home, kLibHdfsDso); std::string path = JoinPath(hdfs_home, kLibHdfsDso);
TryLoadAndBind(path.c_str(), &handle_, status); TryLoadAndBind(path.c_str(), &handle_, status);
if (TF_GetCode(status) == TF_OK) { if (TF_GetCode(status) == TF_OK) {
return; return;
} else { } else {
std::cerr << "HadoopFileSystem load error: " << TF_Message(status); TF_Log(TF_FATAL, "HadoopFileSystem load error: %s", TF_Message(status));
} }
} }
@ -163,13 +173,15 @@ class LibHDFS {
void* handle_;
};
// We implement connection caching in TensorFlow, which can significantly
// improve performance. Fixes #43187
hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file,
const std::string& path, TF_Status* status) {
auto libhdfs = hadoop_file->libhdfs;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
std::string cacheKey(scheme);
hdfsBuilder* builder = libhdfs->hdfsNewBuilder();
if (scheme == "file") {
libhdfs->hdfsBuilderSetNameNode(builder, nullptr);
@ -194,15 +206,24 @@ hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) {
SplitArchiveNameAndPath(&path_har, &namenode, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str());
cacheKey += namenode;
} else {
libhdfs->hdfsBuilderSetNameNode(
builder, namenode.empty() ? "default" : namenode.c_str());
cacheKey += namenode;
}
absl::MutexLock l(&hadoop_file->connection_cache_lock);
if (hadoop_file->connection_cache.find(cacheKey) ==
hadoop_file->connection_cache.end()) {
auto cacheFs = libhdfs->hdfsBuilderConnect(builder);
if (cacheFs == nullptr) {
TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno));
return cacheFs;
}
hadoop_file->connection_cache[cacheKey] = cacheFs;
}
auto fs = hadoop_file->connection_cache[cacheKey];
TF_SetStatus(status, TF_OK, "");
return fs;
}
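
The cache key is simply the scheme concatenated with the namenode, so every path on one cluster reuses a single `hdfsFS` handle; note the lock is held across `hdfsBuilderConnect`, so only first-time connections per key are serialized. A sketch of the keying (values illustrative, helper name hypothetical):

#include <string>

// "hdfs://namenode:8020/a" and "hdfs://namenode:8020/b" produce the same
// key ("hdfs" + "namenode:8020"), hence one cached connection.
std::string ConnectionCacheKey(const std::string& scheme,
                               const std::string& namenode) {
  return scheme + namenode;
}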
@ -216,6 +237,7 @@ typedef struct HDFSFile {
LibHDFS* libhdfs;
absl::Mutex mu;
hdfsFile handle ABSL_GUARDED_BY(mu);
bool disable_eof_retried;
HDFSFile(std::string path, std::string hdfs_path, hdfsFS fs, LibHDFS* libhdfs, HDFSFile(std::string path, std::string hdfs_path, hdfsFS fs, LibHDFS* libhdfs,
hdfsFile handle) hdfsFile handle)
: path(std::move(path)),
@ -223,7 +245,15 @@ typedef struct HDFSFile {
fs(fs),
libhdfs(libhdfs),
mu(),
handle(handle) {
const char* disable_eof_retried_str =
getenv("HDFS_DISABLE_READ_EOF_RETRIED");
if (disable_eof_retried_str && disable_eof_retried_str[0] == '1') {
disable_eof_retried = true;
} else {
disable_eof_retried = false;
}
}
} HDFSFile;
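
A hedged sketch of opting out of the reopen-on-EOF behavior: the variable is read per file handle at construction, so it must be set before files are opened (`setenv` is POSIX-only; the helper name is hypothetical).

#include <cstdlib>

void DisableHdfsEofRetry() {
  // Process-wide; affects HDFSFile handles created after this call.
  setenv("HDFS_DISABLE_READ_EOF_RETRIED", "1", /*overwrite=*/1);
}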
void Cleanup(TF_RandomAccessFile* file) {
@ -247,8 +277,12 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* dst = buffer;
bool eof_retried = false;
if (hdfs_file->disable_eof_retried) {
// Setting eof_retried to true avoids calling hdfsOpenFile again in Read.
// Fixes #42597
eof_retried = true;
}
int64_t read = 0;
while (TF_GetCode(status) == TF_OK && n > 0) {
// We lock inside the loop rather than outside so we don't block other
// concurrent readers.
absl::MutexLock l(&hdfs_file->mu);
@ -257,12 +291,13 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
// of int32. -2 offset can avoid JVM OutOfMemoryError.
size_t read_n =
(std::min)(n, static_cast<size_t>(std::numeric_limits<int>::max() - 2));
int64_t r = libhdfs->hdfsPread(fs, handle, static_cast<tOffset>(offset),
dst, static_cast<tSize>(read_n));
if (r > 0) {
dst += r;
n -= r;
offset += r;
read += r;
} else if (!eof_retried && r == 0) {
// Always reopen the file upon reaching EOF to see if there's more data.
// If writers are streaming contents while others are concurrently
@ -274,11 +309,13 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
TF_SetStatusFromIOError(status, errno, path);
return -1;
}
hdfs_file->handle =
libhdfs->hdfsOpenFile(fs, hdfs_path, O_RDONLY, 0, 0, 0);
if (hdfs_file->handle == nullptr) {
TF_SetStatusFromIOError(status, errno, path);
return -1;
}
handle = hdfs_file->handle;
eof_retried = true;
} else if (eof_retried && r == 0) {
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
@ -288,7 +325,7 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
TF_SetStatusFromIOError(status, errno, path);
}
}
return read;
}
} // namespace tf_random_access_file
@ -308,7 +345,7 @@ typedef struct HDFSFile {
handle(handle) {}
} HDFSFile;
void Cleanup(TF_WritableFile* file) {
auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
hdfs_file->libhdfs->hdfsCloseFile(hdfs_file->fs, hdfs_file->handle);
hdfs_file->fs = nullptr;
@ -387,30 +424,36 @@ void Close(const TF_WritableFile* file, TF_Status* status) {
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
// ----------------------------------------------------------------------------
namespace tf_read_only_memory_region {
// Hadoop doesn't support Readonly Memory Region
} // namespace tf_read_only_memory_region
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_hadoop_filesystem {
HadoopFile::HadoopFile(TF_Status* status)
: libhdfs(new LibHDFS(status)),
connection_cache_lock(),
connection_cache() {}
void Init(TF_Filesystem* filesystem, TF_Status* status) {
filesystem->plugin_filesystem = new HadoopFile(status);
if (TF_GetCode(status) != TF_OK) return;
TF_SetStatus(status, TF_OK, "");
}
void Cleanup(TF_Filesystem* filesystem) {
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
delete libhdfs;
delete hadoop_file;
}
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
TF_RandomAccessFile* file, TF_Status* status) {
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
@ -426,8 +469,27 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
auto handle = libhdfs->hdfsOpenFile(fs, hdfs_path.c_str(), O_WRONLY, 0, 0, 0);
if (handle == nullptr) return TF_SetStatusFromIOError(status, errno, path);
file->plugin_file =
new tf_writable_file::HDFSFile(hdfs_path, fs, libhdfs, handle);
TF_SetStatus(status, TF_OK, "");
}
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
@ -458,8 +520,9 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
@ -474,8 +537,9 @@ void PathExists(const TF_Filesystem* filesystem, const char* path,
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status) {
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
@ -493,8 +557,9 @@ void Stat(const TF_Filesystem* filesystem, const char* path,
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return -1;
std::string scheme, namenode, hdfs_path;
@ -514,8 +579,9 @@ int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
@ -529,8 +595,9 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path,
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
@ -544,8 +611,9 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
@ -580,8 +648,9 @@ void DeleteDir(const TF_Filesystem* filesystem, const char* path,
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, src, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path_src, hdfs_path_dst;
@ -601,8 +670,9 @@ void RenameFile(const TF_Filesystem* filesystem, const char* src,
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status) {
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return -1;
std::string scheme, namenode, hdfs_path;
@ -638,7 +708,9 @@ int GetChildren(const TF_Filesystem* filesystem, const char* path,
return num_entries;
}
static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) {
return strdup(uri);
}
} // namespace tf_hadoop_filesystem
@ -646,6 +718,42 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
ops->random_access_file_ops->read = tf_random_access_file::Read;
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->writable_file_ops->append = tf_writable_file::Append;
ops->writable_file_ops->tell = tf_writable_file::Tell;
ops->writable_file_ops->flush = tf_writable_file::Flush;
ops->writable_file_ops->sync = tf_writable_file::Sync;
ops->writable_file_ops->close = tf_writable_file::Close;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_hadoop_filesystem::Init;
ops->filesystem_ops->cleanup = tf_hadoop_filesystem::Cleanup;
ops->filesystem_ops->new_random_access_file =
tf_hadoop_filesystem::NewRandomAccessFile;
ops->filesystem_ops->new_writable_file =
tf_hadoop_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file =
tf_hadoop_filesystem::NewAppendableFile;
ops->filesystem_ops->new_read_only_memory_region_from_file =
tf_hadoop_filesystem::NewReadOnlyMemoryRegionFromFile;
ops->filesystem_ops->path_exists = tf_hadoop_filesystem::PathExists;
ops->filesystem_ops->stat = tf_hadoop_filesystem::Stat;
ops->filesystem_ops->get_file_size = tf_hadoop_filesystem::GetFileSize;
ops->filesystem_ops->delete_file = tf_hadoop_filesystem::DeleteFile;
ops->filesystem_ops->create_dir = tf_hadoop_filesystem::CreateDir;
ops->filesystem_ops->delete_dir = tf_hadoop_filesystem::DeleteDir;
ops->filesystem_ops->rename_file = tf_hadoop_filesystem::RenameFile;
ops->filesystem_ops->get_children = tf_hadoop_filesystem::GetChildren;
ops->filesystem_ops->translate_name = tf_hadoop_filesystem::TranslateName;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {

View File

@ -15,7 +15,73 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
#include <map>
#include <string>
#include "absl/synchronization/mutex.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" #include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
#include "third_party/hadoop/hdfs.h"
void ParseHadoopPath(const std::string& fname, std::string* scheme,
std::string* namenode, std::string* path);
void SplitArchiveNameAndPath(std::string* path, std::string* nn,
TF_Status* status);
class LibHDFS;
namespace tf_random_access_file {
void Cleanup(TF_RandomAccessFile* file);
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status);
} // namespace tf_random_access_file
namespace tf_writable_file {
void Cleanup(TF_WritableFile* file);
void Append(const TF_WritableFile* file, const char* buffer, size_t n,
TF_Status* status);
int64_t Tell(const TF_WritableFile* file, TF_Status* status);
void Sync(const TF_WritableFile* file, TF_Status* status);
void Flush(const TF_WritableFile* file, TF_Status* status);
void Close(const TF_WritableFile* file, TF_Status* status);
} // namespace tf_writable_file
namespace tf_hadoop_filesystem {
typedef struct HadoopFile {
LibHDFS* libhdfs;
absl::Mutex connection_cache_lock;
std::map<std::string, hdfsFS> connection_cache
ABSL_GUARDED_BY(connection_cache_lock);
HadoopFile(TF_Status* status);
} HadoopFile;
void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem);
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
TF_RandomAccessFile* file, TF_Status* status);
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status);
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status);
void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
const char* path,
TF_ReadOnlyMemoryRegion* region,
TF_Status* status);
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status);
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status);
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status);
} // namespace tf_hadoop_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_ #endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
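Taken together, the two path helpers declared above split a URI in two stages. A minimal usage sketch (all literal values are illustrative only):

  std::string scheme, namenode, path;
  ParseHadoopPath("hdfs://namenode:8020/dir/file.txt", &scheme, &namenode, &path);
  // scheme == "hdfs", namenode == "namenode:8020", path == "/dir/file.txt"
  // For har:// URIs, SplitArchiveNameAndPath(&path, &namenode, status) then
  // folds the .har archive into the namenode, as the tests below demonstrate.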

View File

@ -0,0 +1,460 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/stacktrace_handler.h"
#include "tensorflow/core/platform/test.h"
#include "third_party/hadoop/hdfs.h"
#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)
#define EXPECT_TF_OK(x) EXPECT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)
namespace tensorflow {
namespace {
class HadoopFileSystemTest : public ::testing::Test {
public:
void SetUp() override {
status_ = TF_NewStatus();
filesystem_ = new TF_Filesystem;
tf_hadoop_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
}
void TearDown() override {
TF_DeleteStatus(status_);
tf_hadoop_filesystem::Cleanup(filesystem_);
delete filesystem_;
}
std::string TmpDir(const std::string& path) {
char* test_dir = getenv("HADOOP_TEST_TMPDIR");
if (test_dir != nullptr) {
return io::JoinPath(std::string(test_dir), path);
} else {
return "file://" + io::JoinPath(testing::TmpDir(), path);
}
}
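Pointing the suite at a real cluster is purely environmental: setting, say, HADOOP_TEST_TMPDIR=hdfs://localhost:9000/tmp (host and port illustrative) runs every test against HDFS, while by default the tests exercise the plugin through Hadoop's local file:// filesystem.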
std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile* file)>
GetWriter() {
std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile * file)> writer(
new TF_WritableFile, [](TF_WritableFile* file) {
if (file != nullptr) {
if (file->plugin_file != nullptr) tf_writable_file::Cleanup(file);
delete file;
}
});
writer->plugin_file = nullptr;
return writer;
}
std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile* file)>
GetReader() {
std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile * file)>
reader(new TF_RandomAccessFile, [](TF_RandomAccessFile* file) {
if (file != nullptr) {
if (file->plugin_file != nullptr)
tf_random_access_file::Cleanup(file);
delete file;
}
});
reader->plugin_file = nullptr;
return reader;
}
void WriteString(const std::string& path, const std::string& content) {
auto writer = GetWriter();
tf_hadoop_filesystem::NewWritableFile(filesystem_, path.c_str(),
writer.get(), status_);
if (TF_GetCode(status_) != TF_OK) return;
tf_writable_file::Append(writer.get(), content.c_str(), content.length(),
status_);
if (TF_GetCode(status_) != TF_OK) return;
tf_writable_file::Close(writer.get(), status_);
if (TF_GetCode(status_) != TF_OK) return;
}
std::string ReadAll(const std::string& path) {
auto reader = GetReader();
tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
reader.get(), status_);
if (TF_GetCode(status_) != TF_OK) return "";
auto file_size =
tf_hadoop_filesystem::GetFileSize(filesystem_, path.c_str(), status_);
if (TF_GetCode(status_) != TF_OK) return "";
std::string content;
content.resize(file_size);
auto read = tf_random_access_file::Read(reader.get(), 0, file_size,
&content[0], status_);
if (TF_GetCode(status_) != TF_OK) return "";
if (read >= 0) content.resize(read);
if (file_size != content.size())
TF_SetStatus(
status_, TF_DATA_LOSS,
std::string("expected " + std::to_string(file_size) + " got " +
std::to_string(content.size()) + " bytes")
.c_str());
return content;
}
protected:
TF_Filesystem* filesystem_;
TF_Status* status_;
};
TEST_F(HadoopFileSystemTest, RandomAccessFile) {
const std::string path = TmpDir("RandomAccessFile");
const std::string content = "abcdefghijklmn";
WriteString(path, content);
ASSERT_TF_OK(status_);
auto reader = GetReader();
tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
reader.get(), status_);
EXPECT_TF_OK(status_);
std::string result;
result.resize(content.size());
auto read = tf_random_access_file::Read(reader.get(), 0, content.size(),
&result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(content.size(), result.size());
EXPECT_EQ(content, result);
result.clear();
result.resize(4);
read = tf_random_access_file::Read(reader.get(), 2, 4, &result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(4, result.size());
EXPECT_EQ(content.substr(2, 4), result);
}
TEST_F(HadoopFileSystemTest, WritableFile) {
auto writer = GetWriter();
const std::string path = TmpDir("WritableFile");
tf_hadoop_filesystem::NewWritableFile(filesystem_, path.c_str(), writer.get(),
status_);
EXPECT_TF_OK(status_);
tf_writable_file::Append(writer.get(), "content1,", strlen("content1,"),
status_);
EXPECT_TF_OK(status_);
auto pos = tf_writable_file::Tell(writer.get(), status_);
EXPECT_TF_OK(status_);
EXPECT_EQ(pos, 9);
tf_writable_file::Append(writer.get(), "content2", strlen("content2"),
status_);
EXPECT_TF_OK(status_);
tf_writable_file::Flush(writer.get(), status_);
EXPECT_TF_OK(status_);
tf_writable_file::Sync(writer.get(), status_);
EXPECT_TF_OK(status_);
tf_writable_file::Close(writer.get(), status_);
EXPECT_TF_OK(status_);
auto content = ReadAll(path);
EXPECT_TF_OK(status_);
EXPECT_EQ("content1,content2", content);
}
TEST_F(HadoopFileSystemTest, PathExists) {
const std::string path = TmpDir("PathExists");
tf_hadoop_filesystem::PathExists(filesystem_, path.c_str(), status_);
EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status_)) << TF_Message(status_);
TF_SetStatus(status_, TF_OK, "");
WriteString(path, "test");
ASSERT_TF_OK(status_);
tf_hadoop_filesystem::PathExists(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
}
TEST_F(HadoopFileSystemTest, GetChildren) {
const std::string base = TmpDir("GetChildren");
tf_hadoop_filesystem::CreateDir(filesystem_, base.c_str(), status_);
EXPECT_TF_OK(status_);
const std::string file = io::JoinPath(base, "TestFile.csv");
WriteString(file, "test");
EXPECT_TF_OK(status_);
const std::string subdir = io::JoinPath(base, "SubDir");
tf_hadoop_filesystem::CreateDir(filesystem_, subdir.c_str(), status_);
EXPECT_TF_OK(status_);
const std::string subfile = io::JoinPath(subdir, "TestSubFile.csv");
WriteString(subfile, "test");
EXPECT_TF_OK(status_);
char** entries;
auto num_entries = tf_hadoop_filesystem::GetChildren(
filesystem_, base.c_str(), &entries, status_);
EXPECT_TF_OK(status_);
std::vector<std::string> children;
for (int i = 0; i < num_entries; ++i) {
children.push_back(entries[i]);
}
std::sort(children.begin(), children.end());
EXPECT_EQ(std::vector<std::string>({"SubDir", "TestFile.csv"}), children);
}
TEST_F(HadoopFileSystemTest, DeleteFile) {
const std::string path = TmpDir("DeleteFile");
WriteString(path, "test");
ASSERT_TF_OK(status_);
tf_hadoop_filesystem::DeleteFile(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
}
TEST_F(HadoopFileSystemTest, GetFileSize) {
const std::string path = TmpDir("GetFileSize");
WriteString(path, "test");
ASSERT_TF_OK(status_);
auto file_size =
tf_hadoop_filesystem::GetFileSize(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
EXPECT_EQ(4, file_size);
}
TEST_F(HadoopFileSystemTest, CreateDirStat) {
const std::string path = TmpDir("CreateDirStat");
tf_hadoop_filesystem::CreateDir(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
TF_FileStatistics stat;
tf_hadoop_filesystem::Stat(filesystem_, path.c_str(), &stat, status_);
EXPECT_TF_OK(status_);
EXPECT_TRUE(stat.is_directory);
}
TEST_F(HadoopFileSystemTest, DeleteDir) {
const std::string path = TmpDir("DeleteDir");
tf_hadoop_filesystem::DeleteDir(filesystem_, path.c_str(), status_);
EXPECT_NE(TF_GetCode(status_), TF_OK);
tf_hadoop_filesystem::CreateDir(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
tf_hadoop_filesystem::DeleteDir(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
TF_FileStatistics stat;
tf_hadoop_filesystem::Stat(filesystem_, path.c_str(), &stat, status_);
EXPECT_NE(TF_GetCode(status_), TF_OK);
}
TEST_F(HadoopFileSystemTest, RenameFile) {
const std::string src = TmpDir("RenameFileSrc");
const std::string dst = TmpDir("RenameFileDst");
WriteString(src, "test");
ASSERT_TF_OK(status_);
tf_hadoop_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(),
status_);
EXPECT_TF_OK(status_);
auto result = ReadAll(dst);
EXPECT_TF_OK(status_);
EXPECT_EQ("test", result);
}
TEST_F(HadoopFileSystemTest, RenameFileOverwrite) {
const std::string src = TmpDir("RenameFileOverwriteSrc");
const std::string dst = TmpDir("RenameFileOverwriteDst");
WriteString(src, "test_old");
ASSERT_TF_OK(status_);
WriteString(dst, "test_new");
ASSERT_TF_OK(status_);
tf_hadoop_filesystem::PathExists(filesystem_, dst.c_str(), status_);
EXPECT_TF_OK(status_);
tf_hadoop_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(),
status_);
EXPECT_TF_OK(status_);
auto result = ReadAll(dst);
EXPECT_TF_OK(status_);
EXPECT_EQ("test_old", result);
}
TEST_F(HadoopFileSystemTest, StatFile) {
const std::string path = TmpDir("StatFile");
WriteString(path, "test");
ASSERT_TF_OK(status_);
TF_FileStatistics stat;
tf_hadoop_filesystem::Stat(filesystem_, path.c_str(), &stat, status_);
EXPECT_TF_OK(status_);
EXPECT_EQ(4, stat.length);
EXPECT_FALSE(stat.is_directory);
}
TEST_F(HadoopFileSystemTest, WriteWhileReading) {
const std::string path = TmpDir("WriteWhileReading");
// Skip the test if we're not testing on HDFS. Hadoop's local filesystem
// implementation makes no guarantees that writable files are readable while
// being written.
if (path.rfind("hdfs://", 0) != 0) GTEST_SKIP();
auto writer = GetWriter();
tf_hadoop_filesystem::NewWritableFile(filesystem_, path.c_str(), writer.get(),
status_);
EXPECT_TF_OK(status_);
const std::string content1 = "content1";
tf_writable_file::Append(writer.get(), content1.c_str(), content1.size(),
status_);
EXPECT_TF_OK(status_);
tf_writable_file::Flush(writer.get(), status_);
EXPECT_TF_OK(status_);
auto reader = GetReader();
tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
reader.get(), status_);
EXPECT_TF_OK(status_);
std::string result;
result.resize(content1.size());
auto read = tf_random_access_file::Read(reader.get(), 0, content1.size(),
&result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(content1, result);
const std::string content2 = "content2";
tf_writable_file::Append(writer.get(), content2.c_str(), content2.size(),
status_);
EXPECT_TF_OK(status_);
tf_writable_file::Flush(writer.get(), status_);
EXPECT_TF_OK(status_);
result.resize(content2.size());
read = tf_random_access_file::Read(reader.get(), content1.size(),
content2.size(), &result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(content2, result);
tf_writable_file::Close(writer.get(), status_);
EXPECT_TF_OK(status_);
}
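The interleaved Flush/Read above leans on HDFS hflush semantics: once flushed, bytes are visible to readers of the still-open file. Hadoop's local ChecksumFileSystem makes no such guarantee, which is what the hdfs:// prefix check guards against.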
TEST_F(HadoopFileSystemTest, ReadWhileOverwriting) {
static char set_disable_var[] = "HDFS_DISABLE_READ_EOF_RETRIED=1";
putenv(set_disable_var);
const std::string path = TmpDir("ReadWhileOverwriting");
if (path.rfind("hdfs://", 0) != 0) GTEST_SKIP();
const std::string content1 = "content1";
WriteString(path, content1);
ASSERT_TF_OK(status_);
auto reader = GetReader();
tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
reader.get(), status_);
EXPECT_TF_OK(status_);
std::string result;
result.resize(content1.size());
auto read = tf_random_access_file::Read(reader.get(), 0, content1.size(),
&result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(content1, result);
tf_hadoop_filesystem::DeleteFile(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
const std::string content2 = "overwrite";
WriteString(path, content1 + content2);
ASSERT_TF_OK(status_);
result.resize(content2.size());
read = tf_random_access_file::Read(reader.get(), content1.size(),
content2.size(), &result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(0, result.size());
static char set_enable_var[] = "HDFS_DISABLE_READ_EOF_RETRIED=0";
putenv(set_enable_var);
}
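The environment variable appears to toggle the plugin's retry-on-EOF behavior (it is read by the plugin implementation, which is not shown in this hunk): with retries disabled, reading past the end of the recreated file returns zero bytes instead of retrying until the data reappears.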
TEST_F(HadoopFileSystemTest, HarSplit) {
const std::string har_path =
"har://hdfs-root/user/j.doe/my_archive.har/dir0/dir1/file.txt";
std::string scheme, namenode, path;
ParseHadoopPath(har_path, &scheme, &namenode, &path);
EXPECT_EQ("har", scheme);
EXPECT_EQ("hdfs-root", namenode);
EXPECT_EQ("/user/j.doe/my_archive.har/dir0/dir1/file.txt", path);
SplitArchiveNameAndPath(&path, &namenode, status_);
EXPECT_TF_OK(status_);
EXPECT_EQ("har://hdfs-root/user/j.doe/my_archive.har", namenode);
EXPECT_EQ("/dir0/dir1/file.txt", path);
}
TEST_F(HadoopFileSystemTest, NoHarExtension) {
const std::string har_path =
"har://hdfs-root/user/j.doe/my_archive/dir0/dir1/file.txt";
std::string scheme, namenode, path;
ParseHadoopPath(har_path, &scheme, &namenode, &path);
EXPECT_EQ("har", scheme);
EXPECT_EQ("hdfs-root", namenode);
EXPECT_EQ("/user/j.doe/my_archive/dir0/dir1/file.txt", path);
SplitArchiveNameAndPath(&path, &namenode, status_);
EXPECT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT) << TF_Message(status_);
}
TEST_F(HadoopFileSystemTest, HarRootPath) {
const std::string har_path = "har://hdfs-root/user/j.doe/my_archive.har";
std::string scheme, namenode, path;
ParseHadoopPath(har_path, &scheme, &namenode, &path);
EXPECT_EQ("har", scheme);
EXPECT_EQ("hdfs-root", namenode);
EXPECT_EQ("/user/j.doe/my_archive.har", path);
SplitArchiveNameAndPath(&path, &namenode, status_);
EXPECT_TF_OK(status_);
EXPECT_EQ("har://hdfs-root/user/j.doe/my_archive.har", namenode);
EXPECT_EQ("/", path);
}
TEST_F(HadoopFileSystemTest, WriteLargeFile) {
const char* test_large_file = std::getenv("HADOOP_TEST_LARGE_FILE");
if (test_large_file == nullptr || std::string(test_large_file) != "1")
GTEST_SKIP();
const std::string path = TmpDir("WriteLargeFile");
const size_t file_size =
static_cast<size_t>(std::numeric_limits<tSize>::max()) + 1024;
// Synthesize deterministic test content.
std::string source(file_size, {});
for (size_t i = 0; i < file_size; ++i) source[i] = (i % 128);
WriteString(path, source);
ASSERT_TF_OK(status_);
auto result = ReadAll(path);
EXPECT_TF_OK(status_);
EXPECT_EQ(source, result);
}
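The chosen size just exceeds the range of libhdfs's 32-bit tSize, so no single hdfsWrite or hdfsRead call can cover the payload; the test thereby exercises the plugin's chunked append/read path (an inference from the tSize cast above).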
// NewAppendableFile() is not testable here: the local filesystem maps to
// Hadoop's ChecksumFileSystem, which does not support append.
} // namespace
} // namespace tensorflow
GTEST_API_ int main(int argc, char** argv) {
tensorflow::testing::InstallStacktraceHandler();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -1,3 +1,5 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
# Experimental posix filesystem plugin. # Experimental posix filesystem plugin.
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")

View File

@ -1,3 +1,5 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
# Experimental windows filesystem plugin. # Experimental windows filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object") load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")

View File

@ -1,3 +1,6 @@
load("//tensorflow:tensorflow.bzl", "filegroup")
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
# Library of gradient functions. # Library of gradient functions.
package( package(
licenses = ["notice"], # Apache 2.0 licenses = ["notice"], # Apache 2.0
@ -16,7 +19,7 @@ cc_library(
"//tensorflow/c/eager:abstract_operation", "//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal", "//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients", "//tensorflow/c/eager:gradients_internal",
"//tensorflow/core/lib/llvm_rtti", "//tensorflow/core/lib/llvm_rtti",
], ],
) )
@ -31,14 +34,11 @@ cc_library(
"//tensorflow:internal", "//tensorflow:internal",
], ],
deps = [ deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal", "//tensorflow/c/eager:gradients_internal",
"//tensorflow/c/eager:gradients",
"//tensorflow/c/experimental/ops:array_ops", "//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops", "//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops", "//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/core/lib/llvm_rtti",
], ],
) )
@ -52,13 +52,46 @@ cc_library(
"//tensorflow:internal", "//tensorflow:internal",
], ],
deps = [ deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal", "//tensorflow/c/eager:gradients_internal",
"//tensorflow/c/eager:gradients", "//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/ops:array_ops", "//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops", "//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops", "//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/core/lib/llvm_rtti", "//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "gradients",
hdrs = [
"array_grad.h",
"math_grad.h",
"nn_grad.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":array_grad",
":math_grad",
":nn_grad",
"//tensorflow/c/eager:gradients_internal",
],
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"array_grad.h",
"math_grad.h",
"nn_grad.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
], ],
) )

View File

@ -22,10 +22,10 @@ limitations under the License.
using std::vector; using std::vector;
using tensorflow::ops::Conj; using tensorflow::ops::Conj;
using tensorflow::ops::Identity;
using tensorflow::ops::MatMul; using tensorflow::ops::MatMul;
using tensorflow::ops::Mul; using tensorflow::ops::Mul;
using tensorflow::ops::ZerosLike; using tensorflow::ops::Neg;
using tensorflow::ops::SqrtGrad;
namespace tensorflow { namespace tensorflow {
namespace gradients { namespace gradients {
@ -36,21 +36,14 @@ class AddGradientFunction : public GradientFunction {
Status Compute(Context* ctx, const IncomingGradients& grad_inputs, Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override { vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2); grad_outputs->resize(2);
vector<AbstractTensorHandle*> identity_outputs(1);
// TODO(b/145674566): Handle name unification in tracing code.
// TODO(b/161805092): Support broadcasting. // TODO(b/161805092): Support broadcasting.
std::string name = "Identity_A"; DCHECK(grad_inputs[0]);
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]}, (*grad_outputs)[0] = grad_inputs[0];
absl::MakeSpan(identity_outputs), (*grad_outputs)[1] = grad_inputs[0];
name.c_str()));
(*grad_outputs)[0] = identity_outputs[0];
name = "Identity_B"; (*grad_outputs)[0]->Ref();
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]}, (*grad_outputs)[1]->Ref();
absl::MakeSpan(identity_outputs),
name.c_str()));
(*grad_outputs)[1] = identity_outputs[0];
return Status::OK(); return Status::OK();
} }
~AddGradientFunction() override {} ~AddGradientFunction() override {}
@ -81,6 +74,25 @@ class ExpGradientFunction : public GradientFunction {
AbstractTensorHandlePtr exp_; AbstractTensorHandlePtr exp_;
}; };
class SqrtGradientFunction : public GradientFunction {
public:
explicit SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) {
sqrt->Ref();
}
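// For y = sqrt(x), dy/dx = 0.5 / y, so SqrtGrad(y, grad) yields
// grad * 0.5 / y; keeping a ref on the forward output avoids recomputing
// the sqrt during the backward pass.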
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
std::string name = "Sqrt_Grad";
grad_outputs->resize(1);
TF_RETURN_IF_ERROR(SqrtGrad(ctx->ctx, {sqrt_.get(), grad_inputs[0]},
absl::MakeSpan(*grad_outputs), name.c_str()));
return Status::OK();
}
~SqrtGradientFunction() override {}
private:
AbstractTensorHandlePtr sqrt_;
};
class MatMulGradientFunction : public GradientFunction { class MatMulGradientFunction : public GradientFunction {
public: public:
explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs, explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
@ -190,6 +202,56 @@ class MatMulGradientFunction : public GradientFunction {
AttrBuilder forward_attrs; AttrBuilder forward_attrs;
}; };
class NegGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
/* Given upstream grad U and a Neg op Y = -X, the gradients are:
*
* dX = -U
*
*/
grad_outputs->resize(1);
std::string name = "Neg_Grad";
TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(*grad_outputs), name.c_str()));
return Status::OK();
}
~NegGradientFunction() override {}
};
class SubGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
/* Given upstream grad U and a Sub op A-B, the gradients are:
*
* dA = U
* dB = -U
*
*/
grad_outputs->resize(2);
// Grad for A
DCHECK(grad_inputs[0]);
(*grad_outputs)[0] = grad_inputs[0];
(*grad_outputs)[0]->Ref();
// Grad for B
// negate the upstream grad
std::vector<AbstractTensorHandle*> neg_outputs(1);
std::string name = "Neg_Sub_Grad_B";
TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(neg_outputs), name.c_str()));
(*grad_outputs)[1] = neg_outputs[0];
return Status::OK();
}
~SubGradientFunction() override {}
};
} // namespace } // namespace
BackwardFunction* AddRegisterer(const ForwardOperation& op) { BackwardFunction* AddRegisterer(const ForwardOperation& op) {
@ -219,5 +281,32 @@ BackwardFunction* MatMulRegisterer(const ForwardOperation& op) {
return new BackwardFunction(gradient_function, default_gradients); return new BackwardFunction(gradient_function, default_gradients);
} }
BackwardFunction* SqrtRegisterer(const ForwardOperation& op) {
auto gradient_function = new SqrtGradientFunction(op.outputs[0]);
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zero
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
BackwardFunction* NegRegisterer(const ForwardOperation& op) {
auto gradient_function = new NegGradientFunction;
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zero
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
BackwardFunction* SubRegisterer(const ForwardOperation& op) {
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zero
// grads in this case.
auto gradient_function = new SubGradientFunction;
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
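These registerers take effect only once wired into a GradientRegistry. A minimal sketch, assuming the Register(op_name, factory) API from gradients.h and that the op names match the forward ops being differentiated:

  Status RegisterMathGradients(GradientRegistry* registry) {
    TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
    TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
    TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
    return Status::OK();
  }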
} // namespace gradients } // namespace gradients
} // namespace tensorflow } // namespace tensorflow

View File

@ -19,10 +19,15 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace gradients { namespace gradients {
BackwardFunction* AddRegisterer(const ForwardOperation& op); BackwardFunction* AddRegisterer(const ForwardOperation& op);
BackwardFunction* ExpRegisterer(const ForwardOperation& op); BackwardFunction* ExpRegisterer(const ForwardOperation& op);
BackwardFunction* MatMulRegisterer(const ForwardOperation& op); BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
BackwardFunction* NegRegisterer(const ForwardOperation& op);
BackwardFunction* SubRegisterer(const ForwardOperation& op);
} // namespace gradients } // namespace gradients
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_ #endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_

View File

@ -14,17 +14,19 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/c/experimental/gradients/nn_grad.h" #include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/ops/array_ops.h" #include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h" #include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h" #include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
using std::vector; using std::vector;
using tensorflow::ops::Conj;
using tensorflow::ops::Identity;
using tensorflow::ops::Mul; using tensorflow::ops::Mul;
using tensorflow::ops::ReluGrad; using tensorflow::ops::ReluGrad;
using tensorflow::ops::SparseSoftmaxCrossEntropyLoss;
using tensorflow::ops::ZerosLike;
namespace tensorflow { namespace tensorflow {
namespace gradients { namespace gradients {
@ -58,9 +60,31 @@ class ReluGradientFunction : public GradientFunction {
vector<AbstractTensorHandle*> forward_outputs; vector<AbstractTensorHandle*> forward_outputs;
}; };
class SparseSoftmaxCrossEntropyLossGradientFunction : public GradientFunction { Status BroadcastMul(AbstractContext* ctx, AbstractTensorHandle* vec,
AbstractTensorHandle* mat,
absl::Span<AbstractTensorHandle*> outputs) {
if (!isa<ImmediateExecutionContext>(ctx)) {
// TODO(b/168850692): Fix this.
return errors::Unimplemented(
"BroadcastMul is not supported in tracing mode yet.");
}
auto imm_ctx = dyn_cast<ImmediateExecutionContext>(ctx);
AbstractTensorPtr minus_1(imm_ctx->CreateInt32Scalar(-1));
ImmediateTensorHandlePtr dim(imm_ctx->CreateLocalHandle(minus_1.get()));
vector<AbstractTensorHandle*> expand_dims_outputs(1);
TF_RETURN_IF_ERROR(ops::ExpandDims(ctx, {vec, dim.get()},
absl::MakeSpan(expand_dims_outputs),
"ExpandDims"));
TF_RETURN_IF_ERROR(
ops::Mul(ctx, {expand_dims_outputs[0], mat}, outputs, "Mul"));
expand_dims_outputs[0]->Unref();
return Status::OK();
}
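// Shape walkthrough (illustrative, not from this change): with vec of shape
// [batch] and mat of shape [batch, num_classes], ExpandDims produces
// [batch, 1], which Mul then broadcasts across the class dimension to give
// an output of shape [batch, num_classes].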
class SparseSoftmaxCrossEntropyWithLogitsGradientFunction
: public GradientFunction {
public: public:
explicit SparseSoftmaxCrossEntropyLossGradientFunction( explicit SparseSoftmaxCrossEntropyWithLogitsGradientFunction(
vector<AbstractTensorHandle*> f_outputs) vector<AbstractTensorHandle*> f_outputs)
: forward_outputs(f_outputs) {} : forward_outputs(f_outputs) {}
@ -69,12 +93,10 @@ class SparseSoftmaxCrossEntropyLossGradientFunction : public GradientFunction {
grad_outputs->resize(2); grad_outputs->resize(2);
// Grad for Softmax Input // Grad for Softmax Input
std::string name = "Mul_Softmax_Grad";
vector<AbstractTensorHandle*> mul_outputs(1); vector<AbstractTensorHandle*> mul_outputs(1);
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(BroadcastMul(
ops::Mul(ctx->ctx, {grad_inputs[0], forward_outputs[1]}, ctx->ctx, grad_inputs[0], forward_outputs[1],
absl::MakeSpan(mul_outputs), absl::MakeSpan(mul_outputs))); // upstream_grad * local softmax grad
name.c_str())); // upstream_grad * local softmax grad
(*grad_outputs)[0] = mul_outputs[0]; (*grad_outputs)[0] = mul_outputs[0];
// Grad for labels is null // Grad for labels is null
@ -82,7 +104,7 @@ class SparseSoftmaxCrossEntropyLossGradientFunction : public GradientFunction {
return Status::OK(); return Status::OK();
} }
~SparseSoftmaxCrossEntropyLossGradientFunction() override {} ~SparseSoftmaxCrossEntropyWithLogitsGradientFunction() override {}
private: private:
vector<AbstractTensorHandle*> forward_outputs; vector<AbstractTensorHandle*> forward_outputs;
@ -99,10 +121,10 @@ BackwardFunction* ReluRegisterer(const ForwardOperation& op) {
return new BackwardFunction(gradient_function, default_gradients); return new BackwardFunction(gradient_function, default_gradients);
} }
BackwardFunction* SparseSoftmaxCrossEntropyLossRegisterer( BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
const ForwardOperation& op) { const ForwardOperation& op) {
auto gradient_function = auto gradient_function =
new SparseSoftmaxCrossEntropyLossGradientFunction(op.outputs); new SparseSoftmaxCrossEntropyWithLogitsGradientFunction(op.outputs);
auto default_gradients = new PassThroughDefaultGradients(op); auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients); return new BackwardFunction(gradient_function, default_gradients);
} }

View File

@ -20,9 +20,9 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace gradients { namespace gradients {
BackwardFunction* ReluRegisterer(const ForwardOperation& op); BackwardFunction* ReluRegisterer(const ForwardOperation& op);
BackwardFunction* SparseSoftmaxCrossEntropyLossRegisterer( BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
const ForwardOperation& op); const ForwardOperation& op);
} // namespace gradients } // namespace gradients
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_ #endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_

View File

@ -0,0 +1,66 @@
# A tape built on top of unified execution APIs.
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "tape_context",
srcs = ["tape_context.cc"],
hdrs = [
"tape_context.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":tape_operation",
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_function",
"//tensorflow/c/eager:abstract_operation",
],
)
cc_library(
name = "tape_operation",
srcs = ["tape_operation.cc"],
hdrs = [
"tape_operation.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_function",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:gradients_internal",
],
)
cc_library(
name = "tape",
hdrs = [
"tape_context.h",
"tape_operation.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":tape_context",
":tape_operation",
],
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"tape_context.h",
"tape_operation.h",
],
visibility = [
"//tensorflow:internal",
],
)

View File

@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
namespace tensorflow {
namespace gradients {
TapeContext::TapeContext(AbstractContext* c, Tape* tape,
const GradientRegistry& registry)
: AbstractContext(kTape), parent_ctx_(c), tape_(tape), registry_(registry) {
// TODO(srbs): Make AbstractContext ref counted.
// parent_ctx_->Ref();
}
void TapeContext::Release() {
// TODO(srbs): Change to Unref()
delete this;
}
TapeContext::~TapeContext() {
// TODO(srbs): Make AbstractContext ref counted.
// parent_ctx_->Unref();
}
TapeOperation* TapeContext::CreateOperation() {
return new TapeOperation(parent_ctx_->CreateOperation(), tape_, registry_);
}
Status TapeContext::RegisterFunction(AbstractFunction* f) {
return parent_ctx_->RegisterFunction(f);
}
Status TapeContext::RemoveFunction(const string& func) {
return parent_ctx_->RemoveFunction(func);
}
} // namespace gradients
} // namespace tensorflow
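How the tape classes compose (a sketch, not code from this change; eager_ctx, registry, and the handles a and b are assumed to exist, and Tape is assumed to take a persistent flag as in gradients.h):

  Tape tape(/*persistent=*/false);
  AbstractContextPtr tape_ctx(new TapeContext(eager_ctx, &tape, registry));
  std::vector<AbstractTensorHandle*> outputs(1);
  // Executes Mul on the parent context and records it on the tape in one step.
  Status s = ops::Mul(tape_ctx.get(), {a, b}, absl::MakeSpan(outputs), "Mul");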

View File

@ -0,0 +1,44 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
namespace tensorflow {
namespace gradients {
class TapeContext : public AbstractContext {
public:
explicit TapeContext(AbstractContext*, Tape*, const GradientRegistry&);
void Release() override;
TapeOperation* CreateOperation() override;
Status RegisterFunction(AbstractFunction*) override;
Status RemoveFunction(const string& func) override;
// For LLVM style RTTI.
static bool classof(const AbstractContext* ptr) {
return ptr->getKind() == kTape;
}
~TapeContext() override;
private:
AbstractContext* parent_ctx_; // Not owned.
Tape* tape_;
const GradientRegistry& registry_;
};
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_

View File

@ -0,0 +1,238 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/gradients.h"
namespace tensorflow {
namespace gradients {
TapeOperation::TapeOperation(AbstractOperation* parent_op, Tape* tape,
const GradientRegistry& registry)
: AbstractOperation(kTape),
parent_op_(parent_op),
tape_(tape),
registry_(registry) {
// TODO(srbs): Make AbstractOperation RefCounted.
// parent_op_->Ref();
}
void TapeOperation::Release() {
// TODO(srbs): Change to Unref().
delete this;
}
TapeOperation::~TapeOperation() {
// TODO(srbs): Make AbstractOperation RefCounted.
// parent_op->Unref();
}
Status TapeOperation::Reset(const char* op, const char* raw_device_name) {
forward_op_.op_name = op;
forward_op_.attrs.Reset(op);
forward_op_.inputs.clear();
forward_op_.outputs.clear();
return parent_op_->Reset(op, raw_device_name);
}
const string& TapeOperation::Name() const { return parent_op_->Name(); }
const string& TapeOperation::DeviceName() const {
return parent_op_->DeviceName();
}
Status TapeOperation::SetDeviceName(const char* name) {
return parent_op_->SetDeviceName(name);
}
Status TapeOperation::AddInput(AbstractTensorHandle* input) {
TF_RETURN_IF_ERROR(parent_op_->AddInput(input));
forward_op_.inputs.push_back(input);
return Status::OK();
}
Status TapeOperation::AddInputList(
absl::Span<AbstractTensorHandle* const> inputs) {
TF_RETURN_IF_ERROR(parent_op_->AddInputList(inputs));
for (auto input : inputs) {
forward_op_.inputs.push_back(input);
}
return Status::OK();
}
Status TapeOperation::SetAttrString(const char* attr_name, const char* data,
size_t length) {
forward_op_.attrs.Set(attr_name, StringPiece(data, length));
return parent_op_->SetAttrString(attr_name, data, length);
}
Status TapeOperation::SetAttrInt(const char* attr_name, int64_t value) {
forward_op_.attrs.Set(attr_name, static_cast<int64>(value));
return parent_op_->SetAttrInt(attr_name, value);
}
Status TapeOperation::SetAttrFloat(const char* attr_name, float value) {
forward_op_.attrs.Set(attr_name, value);
return parent_op_->SetAttrFloat(attr_name, value);
}
Status TapeOperation::SetAttrBool(const char* attr_name, bool value) {
forward_op_.attrs.Set(attr_name, value);
return parent_op_->SetAttrBool(attr_name, value);
}
Status TapeOperation::SetAttrType(const char* attr_name, DataType value) {
forward_op_.attrs.Set(attr_name, value);
return parent_op_->SetAttrType(attr_name, value);
}
Status TapeOperation::SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) {
if (num_dims > TensorShape::MaxDimensions()) {
return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
num_dims,
" dimensions which is over the limit of ",
TensorShape::MaxDimensions(), ".");
}
TensorShapeProto proto;
if (num_dims < 0) {
proto.set_unknown_rank(true);
} else {
for (int d = 0; d < num_dims; ++d) {
proto.add_dim()->set_size(dims[d]);
}
}
forward_op_.attrs.Set(attr_name, proto);
return parent_op_->SetAttrShape(attr_name, dims, num_dims);
}
Status TapeOperation::SetAttrFunction(const char* attr_name,
const AbstractOperation* value) {
return tensorflow::errors::Unimplemented(
"SetAttrFunction has not been implemented yet.");
}
Status TapeOperation::SetAttrFunctionName(const char* attr_name,
const char* value, size_t length) {
return tensorflow::errors::Unimplemented(
"SetAttrFunctionName has not been implemented "
"yet.");
}
Status TapeOperation::SetAttrTensor(const char* attr_name,
AbstractTensorInterface* tensor) {
return tensorflow::errors::Unimplemented(
"SetAttrTensor has not been implemented yet.");
}
Status TapeOperation::SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths, int num_values) {
std::vector<StringPiece> v(num_values);
for (int i = 0; i < num_values; ++i) {
v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
}
forward_op_.attrs.Set(attr_name, v);
return parent_op_->SetAttrStringList(attr_name, values, lengths, num_values);
}
Status TapeOperation::SetAttrFloatList(const char* attr_name,
const float* values, int num_values) {
forward_op_.attrs.Set(attr_name,
gtl::ArraySlice<const float>(values, num_values));
return parent_op_->SetAttrFloatList(attr_name, values, num_values);
}
Status TapeOperation::SetAttrIntList(const char* attr_name,
const int64_t* values, int num_values) {
forward_op_.attrs.Set(
attr_name, gtl::ArraySlice<const int64>(
reinterpret_cast<const int64*>(values), num_values));
return parent_op_->SetAttrIntList(attr_name, values, num_values);
}
Status TapeOperation::SetAttrTypeList(const char* attr_name,
const DataType* values, int num_values) {
forward_op_.attrs.Set(attr_name,
gtl::ArraySlice<const DataType>(values, num_values));
return parent_op_->SetAttrTypeList(attr_name, values, num_values);
}
Status TapeOperation::SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) {
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
forward_op_.attrs.Set(attr_name,
gtl::ArraySlice<const bool>(b.get(), num_values));
return parent_op_->SetAttrBoolList(attr_name, values, num_values);
}
Status TapeOperation::SetAttrShapeList(const char* attr_name,
const int64_t** dims,
const int* num_dims, int num_values) {
std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
for (int i = 0; i < num_values; ++i) {
const auto num_dims_i = num_dims[i];
if (num_dims_i > TensorShape::MaxDimensions()) {
return errors::InvalidArgument(
strings::StrCat("Value specified for `", attr_name, "` has ",
num_dims_i, " dimensions which is over the limit of ",
TensorShape::MaxDimensions(), "."));
}
if (num_dims_i < 0) {
proto[i].set_unknown_rank(true);
} else {
const int64_t* dims_i = dims[i];
auto proto_i = &proto[i];
for (int d = 0; d < num_dims_i; ++d) {
proto_i->add_dim()->set_size(dims_i[d]);
}
}
}
forward_op_.attrs.Set(
attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
return parent_op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
Status TapeOperation::SetAttrFunctionList(
const char* attr_name, absl::Span<const AbstractOperation*> values) {
return tensorflow::errors::Unimplemented(
"SetAttrFunctionList has not been "
"implemented yet.");
}
AbstractOperation* TapeOperation::GetBackingOperation() { return parent_op_; }
Status TapeOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) {
TF_RETURN_IF_ERROR(parent_op_->Execute(retvals, num_retvals));
std::vector<int64> input_ids(forward_op_.inputs.size());
std::vector<tensorflow::DataType> input_dtypes(forward_op_.inputs.size());
for (int i = 0; i < forward_op_.inputs.size(); i++) {
input_ids[i] = ToId(forward_op_.inputs[i]);
input_dtypes[i] = forward_op_.inputs[i]->DataType();
}
for (int i = 0; i < *num_retvals; i++) {
// TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
forward_op_.outputs.push_back(retvals[i]);
}
// TODO(b/166669239): This is needed to support AttrBuilder::Get for string
// attributes. Number type attrs and DataType attrs work fine without this.
// Consider getting rid of this and making the behavior between number types
// and string consistent.
forward_op_.attrs.BuildNodeDef();
std::vector<TapeTensor> tape_tensors;
for (auto t : retvals) {
tape_tensors.push_back(TapeTensor(t));
}
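// Note that a factory is recorded rather than a constructed backward
// function: the registry lookup below runs lazily, only if this op ends up
// on a path that is actually differentiated.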
tape_->RecordOperation(
parent_op_->Name(), tape_tensors, input_ids, input_dtypes,
[this]() -> BackwardFunction* {
std::unique_ptr<BackwardFunction> backward_fn;
Status s = registry_.Lookup(forward_op_, &backward_fn);
if (!s.ok()) {
return nullptr;
}
return backward_fn.release();
},
[](BackwardFunction* ptr) {
if (ptr) {
delete ptr;
}
});
return Status::OK();
}
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,80 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/gradients.h"
namespace tensorflow {
namespace gradients {
class TapeOperation : public AbstractOperation {
public:
explicit TapeOperation(AbstractOperation*, Tape*, const GradientRegistry&);
void Release() override;
Status Reset(const char* op, const char* raw_device_name) override;
const string& Name() const override;
const string& DeviceName() const override;
Status SetDeviceName(const char* name) override;
Status AddInput(AbstractTensorHandle* input) override;
Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) override;
Status SetAttrString(const char* attr_name, const char* data,
size_t length) override;
Status SetAttrInt(const char* attr_name, int64_t value) override;
Status SetAttrFloat(const char* attr_name, float value) override;
Status SetAttrBool(const char* attr_name, bool value) override;
Status SetAttrType(const char* attr_name, DataType value) override;
Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) override;
Status SetAttrFunction(const char* attr_name,
const AbstractOperation* value) override;
Status SetAttrFunctionName(const char* attr_name, const char* value,
size_t length) override;
Status SetAttrTensor(const char* attr_name,
AbstractTensorInterface* tensor) override;
Status SetAttrStringList(const char* attr_name, const void* const* values,
const size_t* lengths, int num_values) override;
Status SetAttrFloatList(const char* attr_name, const float* values,
int num_values) override;
Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) override;
Status SetAttrTypeList(const char* attr_name, const DataType* values,
int num_values) override;
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
int num_values) override;
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) override;
Status SetAttrFunctionList(
const char* attr_name,
absl::Span<const AbstractOperation*> values) override;
AbstractOperation* GetBackingOperation();
// For LLVM style RTTI.
static bool classof(const AbstractOperation* ptr) {
return ptr->getKind() == kTape;
}
~TapeOperation() override;
private:
AbstractOperation* parent_op_;
ForwardOperation forward_op_;
Tape* tape_;
const GradientRegistry& registry_;
};
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_

View File

@ -1,3 +1,6 @@
load("//tensorflow:tensorflow.bzl", "filegroup")
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
# Experimental ops. These will eventually be replaced by machine-generated versions. # Experimental ops. These will eventually be replaced by machine-generated versions.
package( package(
licenses = ["notice"], # Apache 2.0 licenses = ["notice"], # Apache 2.0
@ -19,7 +22,7 @@ cc_library(
"//tensorflow/c/eager:abstract_operation", "//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal", "//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/core/lib/llvm_rtti", "//tensorflow/c/eager:tracing_utils",
"//tensorflow/core/platform:errors", "//tensorflow/core/platform:errors",
], ],
) )
@ -40,8 +43,8 @@ cc_library(
"//tensorflow/c/eager:abstract_context", "//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal", "//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:tracing_utils",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors", "//tensorflow/core/platform:errors",
], ],
) )
@ -61,7 +64,41 @@ cc_library(
"//tensorflow/c/eager:abstract_operation", "//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal", "//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/core/lib/llvm_rtti", "//tensorflow/c/eager:tracing_utils",
"//tensorflow/core/platform:errors", "//tensorflow/core/platform:errors",
], ],
) )
cc_library(
name = "ops",
hdrs = [
"array_ops.h",
"math_ops.h",
"nn_ops.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":array_ops",
":math_ops",
":nn_ops",
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
],
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"array_ops.h",
"math_ops.h",
"nn_ops.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)

View File

@ -14,9 +14,11 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/c/experimental/ops/array_ops.h" #include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/tracing_utils.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
using tensorflow::tracing::MaybeSetOpName;
namespace tensorflow { namespace tensorflow {
namespace ops { namespace ops {
@ -26,28 +28,58 @@ Status Identity(AbstractContext* ctx,
AbstractOperationPtr identity_op(ctx->CreateOperation()); AbstractOperationPtr identity_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
identity_op->Reset("Identity", /*raw_device_name=*/nullptr)); identity_op->Reset("Identity", /*raw_device_name=*/nullptr));
if (isa<tensorflow::tracing::TracingOperation>(identity_op.get())) { TF_RETURN_IF_ERROR(MaybeSetOpName(identity_op.get(), name));
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(identity_op.get())
->SetOpName(name));
}
TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0])); TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0]));
int num_retvals = 1; int num_retvals = 1;
return identity_op->Execute(outputs, &num_retvals); return identity_op->Execute(outputs, &num_retvals);
} }
Status IdentityN(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr identity_n_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
identity_n_op->Reset("IdentityN", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(identity_n_op.get(), name));
TF_RETURN_IF_ERROR(identity_n_op->AddInputList(inputs));
int num_retvals = inputs.size();
return identity_n_op->Execute(outputs, &num_retvals);
}
Status ZerosLike(AbstractContext* ctx, Status ZerosLike(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) { absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr z_op(ctx->CreateOperation()); AbstractOperationPtr z_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(z_op->Reset("ZerosLike", /*raw_device_name=*/nullptr)); TF_RETURN_IF_ERROR(z_op->Reset("ZerosLike", /*raw_device_name=*/nullptr));
if (isa<tensorflow::tracing::TracingOperation>(z_op.get())) { TF_RETURN_IF_ERROR(MaybeSetOpName(z_op.get(), name));
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(z_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(z_op->AddInput(inputs[0])); TF_RETURN_IF_ERROR(z_op->AddInput(inputs[0]));
int num_retvals = 1; int num_retvals = 1;
return z_op->Execute(outputs, &num_retvals); return z_op->Execute(outputs, &num_retvals);
} }
Status Shape(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr shape_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(shape_op->Reset("Shape", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(shape_op.get(), name));
TF_RETURN_IF_ERROR(shape_op->AddInput(inputs[0])); // input
int num_retvals = 1;
TF_RETURN_IF_ERROR(shape_op->Execute(outputs, &num_retvals));
return Status::OK();
}
Status ExpandDims(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(op->Reset("ExpandDims", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(op.get(), name));
TF_RETURN_IF_ERROR(op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(op->AddInput(inputs[1]));
int num_retvals = 1;
return op->Execute(outputs, &num_retvals);
}
} // namespace ops } // namespace ops
} // namespace tensorflow } // namespace tensorflow

View File

@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_operation.h" #include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
namespace tensorflow { namespace tensorflow {
namespace ops { namespace ops {
@ -27,10 +26,22 @@ Status Identity(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name); absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status IdentityN(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status ZerosLike(AbstractContext* ctx, Status ZerosLike(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name); absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Shape(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status ExpandDims(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops } // namespace ops
} // namespace tensorflow } // namespace tensorflow

View File

@ -16,22 +16,21 @@ limitations under the License.
#include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/tracing_utils.h"
#include "tensorflow/c/experimental/ops/array_ops.h" #include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
using tensorflow::tracing::MaybeSetOpName;
namespace tensorflow { namespace tensorflow {
namespace ops { namespace ops {
using tensorflow::tracing::TracingOperation;
Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs, Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) { absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr mul_op(ctx->CreateOperation()); AbstractOperationPtr mul_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(mul_op->Reset("Mul", /*raw_device_name=*/nullptr)); TF_RETURN_IF_ERROR(mul_op->Reset("Mul", /*raw_device_name=*/nullptr));
if (isa<TracingOperation>(mul_op.get())) { TF_RETURN_IF_ERROR(MaybeSetOpName(mul_op.get(), name));
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[0])); TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[1])); TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[1]));
int num_retvals = 1; int num_retvals = 1;
@ -55,12 +54,7 @@ Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr add_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(add_op->Reset("AddV2", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(add_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(add_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(MaybeSetOpName(add_op.get(), name));
TF_RETURN_IF_ERROR(add_op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(add_op->AddInput(inputs[1]));
@ -69,18 +63,26 @@ Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
return Status::OK();
}
Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr sub_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(sub_op->Reset("Sub", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(sub_op.get(), name));
TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[1]));
int num_retvals = 1;
TF_RETURN_IF_ERROR(sub_op->Execute(outputs, &num_retvals));
return Status::OK();
}
Status MatMul(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
bool transpose_a = false, bool transpose_b = false) {
AbstractOperationPtr matmul_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(matmul_op->Reset("MatMul", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(matmul_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(matmul_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(MaybeSetOpName(matmul_op.get(), name));
TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[1]));
@ -96,15 +98,79 @@ Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr neg_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(neg_op->Reset("Neg", /*raw_device_name=*/nullptr));
if (isa<TracingOperation>(neg_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(neg_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(MaybeSetOpName(neg_op.get(), name));
TF_RETURN_IF_ERROR(neg_op->AddInput(inputs[0]));
int num_retvals = 1;
return neg_op->Execute(outputs, &num_retvals);
}
Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr sum_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(sum_op->Reset("Sum", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(sum_op.get(), name));
TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[0])); // input_vals
TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[1])); // reduction_indices
int num_retvals = 1;
TF_RETURN_IF_ERROR(sum_op->Execute(outputs, &num_retvals));
return Status::OK();
}
Status DivNoNan(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr div_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(div_op->Reset("DivNoNan", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(div_op.get(), name));
TF_RETURN_IF_ERROR(div_op->AddInput(inputs[0])); // x
TF_RETURN_IF_ERROR(div_op->AddInput(inputs[1])); // y
int num_retvals = 1;
TF_RETURN_IF_ERROR(div_op->Execute(
outputs, &num_retvals)); // z = x / y, (z_i = 0 if y_i = 0)
return Status::OK();
}
Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr exp_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(exp_op->Reset("Exp", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(exp_op.get(), name));
TF_RETURN_IF_ERROR(exp_op->AddInput(inputs[0]));
int num_retvals = 1;
return exp_op->Execute(outputs, &num_retvals);
}
Status Sqrt(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr sqrt_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(sqrt_op->Reset("Sqrt", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_op.get(), name));
TF_RETURN_IF_ERROR(sqrt_op->AddInput(inputs[0]));
int num_retvals = 1;
Status s = sqrt_op->Execute(outputs, &num_retvals);
return s;
}
Status SqrtGrad(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr sqrt_grad_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
sqrt_grad_op->Reset("SqrtGrad", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_grad_op.get(), name));
TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[1]));
int num_retvals = 1;
Status s = sqrt_grad_op->Execute(outputs, &num_retvals);
return s;
}
} // namespace ops
} // namespace tensorflow
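Each of these math wrappers shares one calling convention: inputs are packed into a span, outputs is a caller-sized span, and num_retvals is fixed at 1. A hedged usage sketch, assuming ctx is an already-initialized AbstractContext* and x, y are live AbstractTensorHandle* values:

// z = x / y elementwise, with z_i = 0 wherever y_i = 0.
AbstractTensorHandle* div_outputs[1];
TF_RETURN_IF_ERROR(
    ops::DivNoNan(ctx, {x, y}, absl::MakeSpan(div_outputs), "div_no_nan"));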


@ -22,18 +22,43 @@ namespace tensorflow {
namespace ops {
Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Conj(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status MatMul(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
bool transpose_a, bool transpose_b);
Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status DivNoNan(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Sqrt(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status SqrtGrad(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow


@ -15,24 +15,22 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/eager/tracing_utils.h"
#include "tensorflow/core/platform/errors.h"
using tensorflow::tracing::MaybeSetOpName;
namespace tensorflow {
namespace ops {
// Softmax Loss given scores and labels, used by the SoftMaxLossGradient
Status SparseSoftmaxCrossEntropyLoss(
Status SparseSoftmaxCrossEntropyWithLogits(
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr sm_loss_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(sm_loss_op->Reset("SparseSoftmaxCrossEntropyWithLogits",
/*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(sm_loss_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(sm_loss_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(MaybeSetOpName(sm_loss_op.get(), name));
TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[0])); // input scores
TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[1])); // labels
@ -49,12 +47,7 @@ Status ReluGrad(AbstractContext* ctx,
AbstractOperationPtr relugrad_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
relugrad_op->Reset("ReluGrad", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(relugrad_op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(relugrad_op.get())
->SetOpName(name));
}
TF_RETURN_IF_ERROR(MaybeSetOpName(relugrad_op.get(), name));
TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[0])); // upstream grads
TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[1])); // relu inputs
@ -63,5 +56,18 @@ Status ReluGrad(AbstractContext* ctx,
return Status::OK();
}
Status Relu(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr relu_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(relu_op->Reset("Relu", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(relu_op.get(), name));
TF_RETURN_IF_ERROR(relu_op->AddInput(inputs[0]));
int num_retvals = 1;
TF_RETURN_IF_ERROR(relu_op->Execute(outputs, &num_retvals));
return Status::OK();
}
} // namespace ops
} // namespace tensorflow
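The loss wrapper is renamed to match the op it wraps, and unlike the unary wrappers it produces two results. A usage sketch under the assumption that scores and labels are existing handles and that the wrapper surfaces both outputs of the underlying op:

AbstractTensorHandle* loss_outputs[2];
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
    ctx, {scores, labels}, absl::MakeSpan(loss_outputs), "sm_loss"));
// Assumed layout: loss_outputs[0] holds the per-example loss and
// loss_outputs[1] the gradient with respect to the scores.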


@ -18,12 +18,11 @@ limitations under the License.
#include "tensorflow/c/eager/abstract_operation.h" #include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
namespace tensorflow { namespace tensorflow {
namespace ops { namespace ops {
Status SparseSoftmaxCrossEntropyLoss( Status SparseSoftmaxCrossEntropyWithLogits(
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs, AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name); absl::Span<AbstractTensorHandle*> outputs, const char* name);
@ -31,6 +30,10 @@ Status ReluGrad(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Relu(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow


@ -1,3 +1,5 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
# Experimental SavedModel C APIs for TensorFlow. See RFC
# https://github.com/tensorflow/community/pull/207
# Targets in this directory are pure C++ "Classes" underlying the C API types
@ -62,13 +64,21 @@ cc_library(
":function_metadata", ":function_metadata",
"//tensorflow/c:tf_tensor_internal", "//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/experimental/saved_model/core/revived_types:asset",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant", "//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/c/experimental/saved_model/core/revived_types:partially_revived_objects",
"//tensorflow/c/experimental/saved_model/core/revived_types:restored_resource_revival_state",
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function", "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function_revival_state",
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_signature_def_function_revival_state",
"//tensorflow/c/experimental/saved_model/core/revived_types:variable", "//tensorflow/c/experimental/saved_model/core/revived_types:variable",
"//tensorflow/cc/saved_model:loader_util",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
], ],
) )
@ -81,15 +91,24 @@ cc_library(
":signature_def_function_metadata", ":signature_def_function_metadata",
"//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
], ],
) )
cc_library( cc_library(
name = "signature_def_function_metadata", name = "signature_def_function_metadata",
srcs = [
"signature_def_function_metadata.cc",
],
hdrs = [ hdrs = [
"signature_def_function_metadata.h", "signature_def_function_metadata.h",
], ],
deps = [
":tensor_spec",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
) )
cc_library(
@ -138,11 +157,13 @@ cc_library(
":saved_model_api", ":saved_model_api",
":saved_model_utils", ":saved_model_utils",
":signature_def_function", ":signature_def_function",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core/ops:restore_ops", "//tensorflow/c/experimental/saved_model/core/ops:restore_ops",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant", "//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/c/experimental/saved_model/core/revived_types:flat_tensor_function",
"//tensorflow/c/experimental/saved_model/core/revived_types:partially_revived_objects",
"//tensorflow/c/experimental/saved_model/core/revived_types:revived_objects",
"//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible", "//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible",
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function", "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
"//tensorflow/c/experimental/saved_model/core/revived_types:variable", "//tensorflow/c/experimental/saved_model/core/revived_types:variable",
@ -151,7 +172,6 @@ cc_library(
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
@ -213,6 +233,7 @@ tf_cc_test(
"//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core", "//tensorflow/core/common_runtime/eager:core",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
], ],
) )
@ -256,6 +277,20 @@ tf_cc_test(
],
)
cc_library(
name = "tensor_spec",
srcs = [
"tensor_spec.cc",
],
hdrs = [
"tensor_spec.h",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
],
)
tf_cc_test(
name = "tf_concrete_function_loading_test",
srcs = [


@ -43,8 +43,8 @@ class ConcreteFunction {
virtual ~ConcreteFunction() = default;
// This method returns the "Call" Op used to execute the function.
virtual Status GetCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) = 0;
virtual Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) const = 0;
virtual const FunctionMetadata& GetFunctionMetadata() const = 0;
};


@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stringpiece.h"
@ -300,80 +301,70 @@ nodes {
TEST(ObjectGraphTraversalTest, Success) {
SavedObjectGraph object_graph = ParseSavedObjectGraph(kSingleChildFoo);
const SavedObject* obj = internal::FindNodeAtPath("foo", object_graph);
ASSERT_NE(nullptr, obj);
EXPECT_EQ(obj->kind_case(), SavedObject::kUserObject);
EXPECT_EQ(obj->user_object().identifier(), "_generic_user_object");
absl::optional<int> node = internal::FindNodeAtPath("foo", object_graph);
ASSERT_TRUE(node.has_value());
EXPECT_EQ(*node, 1);
}
TEST(ObjectGraphTraversalTest, ObjectNotFound) {
SavedObjectGraph object_graph = ParseSavedObjectGraph(kSingleChildFoo);
const SavedObject* obj = internal::FindNodeAtPath("bar", object_graph);
EXPECT_EQ(nullptr, obj);
absl::optional<int> node = internal::FindNodeAtPath("bar", object_graph);
EXPECT_FALSE(node.has_value());
}
TEST(ObjectGraphTraversalTest, CaseSensitiveMismatch) {
SavedObjectGraph object_graph = ParseSavedObjectGraph(kSingleChildFoo);
const SavedObject* obj = internal::FindNodeAtPath("FOO", object_graph);
EXPECT_EQ(nullptr, obj);
absl::optional<int> node = internal::FindNodeAtPath("FOO", object_graph);
EXPECT_FALSE(node.has_value());
}
TEST(ObjectGraphTraversalTest, NestedObjectFound) {
SavedObjectGraph object_graph =
ParseSavedObjectGraph(kSingleChildFooWithFuncBar);
const SavedObject* obj = internal::FindNodeAtPath("foo.bar", object_graph);
ASSERT_NE(nullptr, obj);
EXPECT_EQ(obj->kind_case(), SavedObject::kFunction);
EXPECT_EQ(obj->function().concrete_functions_size(), 1);
EXPECT_EQ(obj->function().concrete_functions(0), "__inference_my_func_5");
absl::optional<int> node = internal::FindNodeAtPath("foo.bar", object_graph);
ASSERT_TRUE(node.has_value());
EXPECT_EQ(*node, 2);
}
TEST(ObjectGraphTraversalTest, MultiplePathsAliasSameObject) {
SavedObjectGraph object_graph = ParseSavedObjectGraph(kMultiplePathsToChild);
const SavedObject* foo_baz =
internal::FindNodeAtPath("foo.baz", object_graph);
ASSERT_NE(nullptr, foo_baz);
EXPECT_EQ(foo_baz->kind_case(), SavedObject::kUserObject);
EXPECT_EQ(foo_baz->user_object().identifier(), "_generic_user_object");
const SavedObject* bar_wombat =
internal::FindNodeAtPath("bar.wombat", object_graph);
ASSERT_NE(nullptr, bar_wombat);
EXPECT_EQ(bar_wombat->kind_case(), SavedObject::kUserObject);
EXPECT_EQ(bar_wombat->user_object().identifier(), "_generic_user_object");
EXPECT_EQ(foo_baz, bar_wombat);
absl::optional<int> foo_baz_node =
internal::FindNodeAtPath("foo.baz", object_graph);
ASSERT_TRUE(foo_baz_node.has_value());
EXPECT_EQ(*foo_baz_node, 4);
absl::optional<int> bar_wombat_node =
internal::FindNodeAtPath("bar.wombat", object_graph);
ASSERT_TRUE(bar_wombat_node.has_value());
EXPECT_EQ(*bar_wombat_node, 4);
EXPECT_EQ(*foo_baz_node, *bar_wombat_node);
}
TEST(ObjectGraphTraversalTest, CyclesAreOK) {
SavedObjectGraph object_graph =
ParseSavedObjectGraph(kCycleBetweenParentAndChild);
const SavedObject* foo = internal::FindNodeAtPath("foo", object_graph);
ASSERT_NE(nullptr, foo);
EXPECT_EQ(foo->kind_case(), SavedObject::kUserObject);
EXPECT_EQ(foo->user_object().identifier(), "_generic_user_object");
const SavedObject* foo_bar =
internal::FindNodeAtPath("foo.bar", object_graph);
ASSERT_NE(nullptr, foo_bar);
EXPECT_EQ(foo_bar->kind_case(), SavedObject::kUserObject);
EXPECT_EQ(foo_bar->user_object().identifier(), "_generic_user_object");
const SavedObject* foo_bar_parent =
internal::FindNodeAtPath("foo.bar.parent", object_graph);
ASSERT_NE(nullptr, foo_bar_parent);
EXPECT_EQ(foo_bar_parent->kind_case(), SavedObject::kUserObject);
EXPECT_EQ(foo_bar_parent->user_object().identifier(), "_generic_user_object");
const SavedObject* foo_bar_parent_bar =
internal::FindNodeAtPath("foo.bar.parent.bar", object_graph);
ASSERT_NE(nullptr, foo_bar_parent_bar);
EXPECT_EQ(foo_bar_parent_bar->kind_case(), SavedObject::kUserObject);
EXPECT_EQ(foo_bar_parent_bar->user_object().identifier(),
"_generic_user_object");
EXPECT_EQ(foo, foo_bar_parent);
EXPECT_EQ(foo_bar, foo_bar_parent_bar);
absl::optional<int> foo = internal::FindNodeAtPath("foo", object_graph);
ASSERT_TRUE(foo.has_value());
EXPECT_EQ(*foo, 1);
absl::optional<int> foo_bar =
internal::FindNodeAtPath("foo.bar", object_graph);
ASSERT_TRUE(foo_bar.has_value());
EXPECT_EQ(*foo_bar, 3);
absl::optional<int> foo_bar_parent =
internal::FindNodeAtPath("foo.bar.parent", object_graph);
ASSERT_TRUE(foo_bar_parent.has_value());
EXPECT_EQ(*foo_bar_parent, 1);
absl::optional<int> foo_bar_parent_bar =
internal::FindNodeAtPath("foo.bar.parent.bar", object_graph);
ASSERT_TRUE(foo_bar_parent_bar.has_value());
EXPECT_EQ(*foo_bar_parent_bar, 3);
EXPECT_EQ(*foo, *foo_bar_parent);
EXPECT_EQ(*foo_bar, *foo_bar_parent_bar);
}
} // namespace
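These tests reflect a signature change in internal::FindNodeAtPath: it now returns the id of the node in the SavedObjectGraph rather than a pointer to the SavedObject itself. A sketch of the assumed new declaration (the parameter types are inferred from the call sites above, not confirmed by this diff):

// Returns the node id reached by walking `path` from the root of
// `object_graph`, or absl::nullopt if no object exists there. Ids index
// into SavedObjectGraph.nodes; node 0 is the root object.
absl::optional<int> FindNodeAtPath(StringPiece path,
                                   const SavedObjectGraph& object_graph);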


@ -1,3 +1,5 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
# This package contains written convenience helpers for Eager Operations
# used by SavedModel. Once we autogenerate C++ Eager Op wrappers, we can remove these.
load(


@ -1,3 +1,5 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
# This package contains classes corresponding to Revived SavedObjectGraph types
# used by SavedModel. See https://cs.opensource.google/tensorflow/tensorflow/+/c575e2ba93c442121d98d3f125d83fed1339924d:tensorflow/core/protobuf/saved_object_graph.proto;l=56-62
package(
@ -8,6 +10,25 @@ package(
licenses = ["notice"], # Apache 2.0 licenses = ["notice"], # Apache 2.0
) )
cc_library(
name = "asset",
srcs = [
"asset.cc",
],
hdrs = [
"asset.h",
],
deps = [
":tensorhandle_convertible",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/cc/saved_model:constants",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "constant",
srcs = [
@ -28,6 +49,106 @@ cc_library(
],
)
cc_library(
name = "flat_tensor_function",
srcs = [
"flat_tensor_function.cc",
],
hdrs = [
"flat_tensor_function.h",
],
deps = [
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "partially_revived_objects",
srcs = [
"partially_revived_objects.cc",
],
hdrs = [
"partially_revived_objects.h",
],
deps = [
":asset",
":constant",
":restored_resource",
":restored_resource_revival_state",
":revived_objects",
":tf_concrete_function",
":tf_concrete_function_revival_state",
":tf_signature_def_function",
":tf_signature_def_function_revival_state",
":variable",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
"//tensorflow/c/experimental/saved_model/core:tensor_spec",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "restored_resource",
srcs = [
"restored_resource.cc",
],
hdrs = [
"restored_resource.h",
],
deps = [
":tensorhandle_convertible",
":tf_concrete_function",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:lib",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "restored_resource_revival_state",
hdrs = [
"restored_resource_revival_state.h",
],
deps = [
":tf_concrete_function_revival_state",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
],
)
cc_library(
name = "revived_objects",
hdrs = [
"revived_objects.h",
],
deps = [
":asset",
":constant",
":restored_resource",
":tf_concrete_function",
":tf_signature_def_function",
":variable",
"//tensorflow/core:lib",
],
)
cc_library(
name = "variable",
srcs = [
@ -45,6 +166,8 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
], ],
) )
@ -68,7 +191,7 @@ cc_library(
"tf_concrete_function.h", "tf_concrete_function.h",
], ],
deps = [ deps = [
":tensorhandle_convertible", ":flat_tensor_function",
"//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_operation",
@ -81,3 +204,55 @@ cc_library(
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
], ],
) )
cc_library(
name = "tf_concrete_function_revival_state",
hdrs = [
"tf_concrete_function_revival_state.h",
],
deps = [
":tf_concrete_function",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "tf_signature_def_function",
srcs = [
"tf_signature_def_function.cc",
],
hdrs = [
"tf_signature_def_function.h",
],
deps = [
":flat_tensor_function",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core:signature_def_function",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "tf_signature_def_function_revival_state",
hdrs = [
"tf_signature_def_function_revival_state.h",
],
deps = [
":tf_signature_def_function",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:optional",
],
)


@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
#include <string>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/path.h"
namespace tensorflow {
Asset::Asset(ImmediateTensorHandlePtr handle)
: TensorHandleConvertible(std::move(handle)) {}
Status Asset::Create(ImmediateExecutionContext* ctx,
const std::string& saved_model_dir,
const std::string& asset_filename,
std::unique_ptr<Asset>* output) {
std::string abs_path =
io::JoinPath(saved_model_dir, kSavedModelAssetsDirectory, asset_filename);
AbstractTensorPtr tensor(ctx->CreateStringScalar(abs_path));
if (tensor.get() == nullptr) {
return errors::Internal(
"Failed to create scalar string tensor for Asset at path ", abs_path);
}
ImmediateTensorHandlePtr handle(ctx->CreateLocalHandle(tensor.get()));
output->reset(new Asset(std::move(handle)));
return Status();
}
} // namespace tensorflow


@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_
#include <string>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/tensor.pb.h"
namespace tensorflow {
class Asset : public TensorHandleConvertible {
public:
static Status Create(ImmediateExecutionContext* ctx,
const std::string& saved_model_dir,
const std::string& asset_filename,
std::unique_ptr<Asset>* output);
// Asset is movable, but not copyable.
Asset(Asset&& other) = default;
Asset& operator=(Asset&& other) = default;
~Asset() override = default;
private:
explicit Asset(ImmediateTensorHandlePtr handle);
Asset(const Asset&) = delete;
Asset& operator=(const Asset&) = delete;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_
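A minimal usage sketch for Asset::Create, assuming an initialized ImmediateExecutionContext* ctx, a hypothetical SavedModel directory, and that TensorHandleConvertible exposes a handle() accessor:

std::unique_ptr<Asset> asset;
// Resolves to "<saved_model_dir>/assets/vocab.txt" and wraps that absolute
// path in a string scalar tensor handle. Both path components are
// illustrative placeholders.
TF_RETURN_IF_ERROR(
    Asset::Create(ctx, "/tmp/my_saved_model", "vocab.txt", &asset));
ImmediateExecutionTensorHandle* path_handle = asset->handle();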


@ -0,0 +1,91 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
#include <memory>
#include <string>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
FlatTensorFunction::FlatTensorFunction(
const std::string& name, std::vector<ImmediateTensorHandlePtr> captures,
ImmediateExecutionContext* ctx)
: name_(name), captures_(std::move(captures)), ctx_(ctx) {}
FlatTensorFunction::~FlatTensorFunction() {
Status status = ctx_->RemoveFunction(name_);
if (!status.ok()) {
LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
<< status.error_message();
}
}
Status FlatTensorFunction::Create(
const FunctionDef* function_def,
std::vector<ImmediateExecutionTensorHandle*> captures,
ImmediateExecutionContext* ctx, std::unique_ptr<FlatTensorFunction>* out) {
TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def));
std::vector<ImmediateTensorHandlePtr> owned_captures;
owned_captures.reserve(captures.size());
for (ImmediateExecutionTensorHandle* capture : captures) {
capture->Ref();
owned_captures.push_back(ImmediateTensorHandlePtr(capture));
}
out->reset(new FlatTensorFunction(function_def->signature().name(),
std::move(owned_captures), ctx));
return Status();
}
Status FlatTensorFunction::MakeCallOp(
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
out->reset(ctx_->CreateOperation());
// In eager mode, TF2 python executes functions by constructing an op with
// the name of the functiondef:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
// In graph mode, we create a PartitionedCallOp instead:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573
// TODO(bmzhao): After discussing with Allen, we should execute this via a
// PartitionedCallOp for compatibility with "tooling that assumes functions in
// graphs are PartitionedCallOps".
TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
// Adding the user-provided inputs to the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(inputs));
absl::Span<AbstractTensorHandle* const> captures(
reinterpret_cast<AbstractTensorHandle* const*>(captures_.data()),
captures_.size());
// Adding the captures of the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(captures));
return Status();
}
} // namespace tensorflow
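A hedged sketch of how a caller drives MakeCallOp, assuming fn is a FlatTensorFunction with one output; the function's captures are appended by MakeCallOp itself, so only the user-visible inputs are passed:

ImmediateOpPtr call_op;
TF_RETURN_IF_ERROR(fn->MakeCallOp(inputs, &call_op));
AbstractTensorHandle* outputs[1];
int num_retvals = 1;  // Must match the number of results in the FunctionDef.
TF_RETURN_IF_ERROR(call_op->Execute(absl::MakeSpan(outputs), &num_retvals));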
