Merge branch 'master' into 42129-tf.image.crop_and_resize
commit 1816c43041

47 .bazelrc
@@ -5,6 +5,7 @@
# Android options:
# android:
# android_arm:
# android_arm64:
# android_x86:
# android_x86_64:
#
@@ -46,10 +47,6 @@
# using_cuda: CUDA is available to build system.
# cuda: Build with full cuda support.
# rocm: Build with AMD GPU support (rocm).
# sycl: Build with SYCL support.
# sycl_nodouble:
# sycl_asan:
# sycl_trisycl:
# mkl: Enable full mkl support.
# tensorrt: Enable TensorRT support.
# ngraph: Enable ngraph support.
@@ -89,6 +86,7 @@
# release_cpu_linux: Toolchain and CUDA options for Linux CPU builds.
# release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds.
# release_gpu_linux: Toolchain and CUDA options for Linux GPU builds.
# release_gpu_linux_cuda_10_1: Toolchain and CUDA options for CUDA 10.1 Linux GPU builds.
# release_cpu_windows: Toolchain and CUDA options for Windows CPU builds.
# release_gpu_windows: Toolchain and CUDA options for Windows GPU builds.

@@ -161,13 +159,11 @@ build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain
# environment variable "TF_MKL_ROOT" every time before build.
build:mkl --define=build_with_mkl=true --define=enable_mkl=true
build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl --define=build_with_mkl_dnn_v1_only=true
build:mkl -c opt

# config to build OneDNN backend with a user specified threadpool.
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_threadpool --define=build_with_mkl_dnn_v1_only=true
build:mkl_threadpool --define=build_with_mkl_opensource=true
build:mkl_threadpool --define=build_with_mkldnn_threadpool=true
build:mkl_threadpool -c opt
@@ -175,10 +171,15 @@ build:mkl_threadpool -c opt
# Config setting to build with oneDNN and without the binary blob
build:mkl_opensource_only --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_opensource_only --define=build_with_mkl_dnn_v1_only=true
build:mkl_opensource_only --define=build_with_mkl_opensource=true
build:mkl_opensource_only -c opt

# Config setting to build with oneDNN for Arm.
build:mkl_aarch64 --define=build_with_mkl_aarch64=true --define=enable_mkl=true
build:mkl_aarch64 --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_aarch64 --define=build_with_mkl_opensource=true
build:mkl_aarch64 -c opt

# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true
@@ -216,19 +217,6 @@ build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --action_env TF_NEED_ROCM=1

build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain
build:sycl --define=using_sycl=true
build:sycl --action_env TF_NEED_OPENCL_SYCL=1

build:sycl_nodouble --config=sycl
build:sycl_nodouble --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE

build:sycl_nodouble --config=sycl
build:sycl_asan --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address

build:sycl_nodouble --config=sycl
build:sycl_trisycl --define=using_trisycl=true

# Options extracted from configure script
build:ngraph --define=with_ngraph_support=true
build:numa --define=with_numa_support=true
@@ -293,6 +281,7 @@ build:ios --noenable_platform_specific_config
build:android --copt=-w
build:ios --copt=-w
build:linux --copt=-w
build:linux --host_copt=-w
build:macos --copt=-w
build:windows --copt=/w

@@ -334,6 +323,11 @@ build:windows --host_copt=-DWIN32_LEAN_AND_MEAN
build:windows --copt=-DNOGDI
build:windows --host_copt=-DNOGDI

# MSVC (Windows): Standards-conformant preprocessor mode
# See https://docs.microsoft.com/en-us/cpp/preprocessor/preprocessor-experimental-overview
build:windows --copt=/experimental:preprocessor
build:windows --host_copt=/experimental:preprocessor

# Misc build options we need for windows.
build:windows --linkopt=/DEBUG
build:windows --host_linkopt=/DEBUG
@@ -358,6 +352,7 @@ build --config=short_logs
# TODO(gunan): Create a feature in toolchains for avx/avx2 to
# avoid having to define linux/win separately.
build:avx_linux --copt=-mavx
build:avx_linux --host_copt=-mavx
build:avx2_linux --copt=-mavx2
build:native_arch_linux --copt=-march=native
build:avx_win --copt=/arch=AVX
@@ -411,9 +406,12 @@ build:rbe_linux --config=avx_linux
build:rbe_linux --config=short_logs
# TODO(gunan): Check why we need this specified in rbe, but not in other builds.
build:rbe_linux --linkopt=-lrt
build:rbe_linux --host_linkopt=-lrt
build:rbe_linux --linkopt=-lm
build:rbe_linux --host_linkopt=-lm

build:rbe_cpu_linux --config=rbe_linux
build:rbe_cpu_linux --host_crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
build:rbe_cpu_linux --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:cc-toolchain-k8"
build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
@@ -431,6 +429,7 @@ test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/

build:rbe_linux_cuda10.1_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda10.1_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda10.1_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda10.1_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda10.1_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda10.1_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
@@ -447,6 +446,7 @@ build:rbe_linux_cuda10.1_nvcc_py3.8 --config=rbe_linux_cuda10.1_nvcc_base --repo

build:rbe_linux_cuda11.0_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda11.0_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda11.0_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.0_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.0_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda11.0_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
@@ -587,7 +587,7 @@ build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-11.0"
build:release_gpu_common --action_env=TF_CUDA_VERSION="11"
build:release_gpu_common --action_env=TF_CUDNN_VERSION="8"
build:release_gpu_common --action_env=TF_NEED_TENSORRT="1"
build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_37,sm_52,sm_60,sm_61,compute_70"
build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80"
build:release_gpu_common --action_env=TENSORRT_INSTALL_PATH="/usr/local/tensorrt"
build:release_gpu_common --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5"
@@ -603,3 +603,8 @@ build:release_windows_common --announce_rc
build:release_cpu_windows --config=release_windows_common

build:release_gpu_windows --config=release_windows_common

build:release_gpu_linux_cuda_10_1 --config=release_gpu_linux
build:release_gpu_linux_cuda_10_1 --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1"
build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDA_VERSION="10"
build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDNN_VERSION="7"
22 .github/bot_config.yml (vendored)
@@ -12,12 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
#
# THIS IS A GENERATED DOCKERFILE.
#
# This file was assembled from multiple pieces, whose use is documented
# throughout. Please refer to the TensorFlow dockerfiles documentation
# for more information.

# A list of assignees
assignees:
@@ -40,6 +34,22 @@ segfault_memory:
# assignees
filesystem_security_assignee:
  - mihaimaruseac

tflite_micro_path:
  - tensorflow/lite/micro

tflite_micro_comment: >
  Thanks for contributing to TensorFlow Lite Micro.

  To keep this process moving along, we'd like to make sure that you have completed the items on this list:
  * Read the [contributing guidelines for TensorFlow Lite Micro](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/CONTRIBUTING.md)
  * Created a [TF Lite Micro Github issue](https://github.com/tensorflow/tensorflow/issues/new?labels=comp%3Amicro&template=70-tflite-micro-issue.md)
  * Linked to the issue from the PR description

  We would like to have a discussion on the Github issue first to determine the best path forward, and then proceed to the PR review.

# Cuda Comment
cuda_comment: >
  From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries:
@@ -1,4 +1,3 @@
#!/bin/bash
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,14 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
set -e
set -x
# ============================================================================

source tensorflow/tools/ci_build/release/common.sh

# Rename to tensorflow_cpu
for f in $(ls py_test_dir/tensorflow-*cp3*-cp3*m-win_amd64.whl); do
  copy_to_new_project_name "${f}" tensorflow_cpu
  rm "${f}"
done
on:
  workflow_dispatch: # Allow manual triggers
  schedule:
    - cron: 0 4 * * * # 4am UTC is 9pm PDT and 8pm PST
name: Set nightly branch to master HEAD
jobs:
  master-to-nightly:
    runs-on: ubuntu-latest
    steps:
      - uses: zofrex/mirror-branch@v1
        name: Set nightly branch to master HEAD
        with:
          target-branch: 'nightly'
10 ADOPTERS.md
@@ -1,10 +0,0 @@
# TensorFlow Adopters

This page contains a list of people and organizations who are using TensorFlow. If you'd like to be included here, please send a pull request which modifies this file.

We intend to use this list to contact you for surveys, and to find good candidates for invite-only events. We will also point to this list if we are asked who uses TensorFlow.

We will not use any of the information here for promotions or to send other regular communications. You should subscribe to discuss@tensorflow.org for such announcements.
@@ -1,16 +1,15 @@
# Where component owners are known, add them here.

/tensorflow/c/eager @jaingaurav @alextp
/tensorflow/core/common_runtime/eager @jaingaurav @alextp
/tensorflow/c/eager @qqfish @kkimdev
/tensorflow/core/common_runtime/eager @qqfish @kkimdev
/tensorflow/core/debug @caisq
/tensorflow/core/nccl/ @azaks2 @chsigg
/tensorflow/core/platform/windows/ @mrry
/tensorflow/core/platform/windows/ @mihaimaruseac
/tensorflow/lite/experimental/micro @petewarden @advaitjain
/tensorflow/python/autograph/ @mdanatg @kkimdev
/tensorflow/python/debug @caisq
/tensorflow/python/eager @jaingaurav @alextp
/tensorflow/python/eager @rohan100jain @kkimdev
/tensorflow/python/tools/api/generator/ @annarev
/tensorflow/tensorboard/ @jart
/tensorflow/tools/docs/ @markdaoust

/third_party/systemlibs/ @perfinion
38 README.md
@@ -103,23 +103,22 @@ open-source software development:

### Official Builds

Build Type | Status | Artifacts
------------------------ | ------ | ---------
**Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Linux XLA** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
**macOS** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Android** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
**Raspberry Pi 0 and 1** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
**Raspberry Pi 2 and 3** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
**Libtensorflow MacOS CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
**Libtensorflow Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
**Libtensorflow Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
**Libtensorflow Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
**Libtensorflow Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)

Build Type | Status | Artifacts
----------------------------- | ------ | ---------
**Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Linux XLA** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
**macOS** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Android** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
**Raspberry Pi 0 and 1** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
**Raspberry Pi 2 and 3** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
**Libtensorflow MacOS CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
**Libtensorflow Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
**Libtensorflow Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
**Libtensorflow Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
**Libtensorflow Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)

### Community Supported Builds

@@ -145,19 +144,20 @@ Build Type
* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
* [TensorFlow Official Models](https://github.com/tensorflow/models/tree/master/official)
* [TensorFlow Examples](https://github.com/tensorflow/examples)
* [TensorFlow in Practice from Coursera](https://www.coursera.org/specializations/tensorflow-in-practice)
* [DeepLearning.AI TensorFlow Developer Professional Certificate](https://www.coursera.org/specializations/tensorflow-in-practice)
* [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment)
* [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2)
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
* [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp)
* [TensorFlow Codelabs](https://codelabs.developers.google.com/?cat=TensorFlow)
* [TensorFlow Chat Room on StackOverflow (not actively monitored by the TensorFlow team)](https://chat.stackoverflow.com/rooms/216694/tensorflow)
* [TensorFlow Blog](https://blog.tensorflow.org)
* [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
* [TensorFlow Twitter](https://twitter.com/tensorflow)
* [TensorFlow YouTube](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
* [TensorFlow Roadmap](https://www.tensorflow.org/model_optimization/guide/roadmap)
* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
* [TensorBoard Visualization Toolkit](https://github.com/tensorflow/tensorboard)

727 RELEASE.md
@@ -34,9 +34,33 @@
  shape assumptions (note that you can pass shapes with `None` entries for axes
  that are meant to be dynamic). You can also disable the input checking
  entirely by setting `model.input_spec = None`.
* TF pip packages now use CUDA 11 and cuDNN 8.0.2.
* XLA:CPU and XLA:GPU devices are no longer registered by default. Use `TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (these devices are slated for removal).
* `tf.raw_ops.Max` and `tf.raw_ops.Min` no longer accept inputs of type `tf.complex64` or `tf.complex128`, because the behavior of these ops is not well defined for complex types.
* `tf.data.experimental.service.DispatchServer` now takes a config tuple instead of individual arguments. Usages should be updated to `tf.data.experimental.service.DispatchServer(dispatcher_config)` (see the sketch after this list).
* `tf.data.experimental.service.WorkerServer` now takes a config tuple instead of individual arguments. Usages should be updated to `tf.data.experimental.service.WorkerServer(worker_config)`.
* `tf.quantization.quantize_and_dequantize_v2` has been introduced; it updates the gradient definition so that gradients for inputs outside the quantization range are 0. To simulate the V1 behavior of `tf.quantization.quantize_and_dequantize(...)`, use `tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...)`.
* `tf.distribute.Strategy.experimental_make_numpy_dataset` is removed. Please use `tf.data.Dataset.from_tensor_slices` instead.
* `experimental_hints` in `tf.distribute.StrategyExtended.reduce_to`, `tf.distribute.StrategyExtended.batch_reduce_to`, and `tf.distribute.ReplicaContext.all_reduce` is renamed to `options`. `tf.distribute.experimental.CollectiveHints` is renamed to `tf.distribute.experimental.CommunicationOptions`. `tf.distribute.experimental.CollectiveCommunication` is renamed to `tf.distribute.experimental.CommunicationImplementation`.
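A minimal sketch of the new config-object call style for the tf.data service servers, plus the V1-style gradient workaround. The `DispatcherConfig`/`WorkerConfig` helpers are an assumption about the `tf.data.experimental.service` module of this release, and the port/address values are illustrative:

```python
import tensorflow as tf

# The servers now take a single config object instead of individual arguments.
dispatcher = tf.data.experimental.service.DispatchServer(
    tf.data.experimental.service.DispatcherConfig(port=5050))
worker = tf.data.experimental.service.WorkerServer(
    tf.data.experimental.service.WorkerConfig(
        dispatcher_address="localhost:5050"))

# V1-like straight-through gradient for quantize-and-dequantize:
qdq = tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)
y = qdq(tf.random.uniform([4]), input_min=0.0, input_max=1.0)
```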
## Known Caveats

@@ -46,89 +70,180 @@

* <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
* A new module named `tf.experimental.numpy` is added, which is a NumPy-compatible API for writing TF programs. This module provides class `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are provided. Their inter-operation with TF facilities is seamless in most cases. See tensorflow/python/ops/numpy_ops/README.md for details of what is supported and how it differs from NumPy.
* A new module named `tf.experimental.numpy` is added, which is a NumPy-compatible API for writing TF programs. This module provides class `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are provided. Their inter-operation with TF facilities is seamless in most cases. See [tensorflow/python/ops/numpy_ops/README.md](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/numpy_ops/README.md) for details of which operations are supported and how they differ from NumPy. (A sketch follows this list.)
* A major refactoring of the internals of the Keras Functional API has been completed, which should improve the reliability, stability, and performance of constructing Functional models.

* `tf.distribute`:
    * Deprecated the `experimental_distribute_datasets_from_function` method and renamed it to `distribute_datasets_from_function`, as it is no longer experimental.
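A minimal sketch of the NumPy-compatible API described above; `tnp` calls mirror NumPy and, in most cases, mix freely with ordinary TF ops:

```python
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

x = tnp.ones([2, 3], dtype=tnp.float32)  # NumPy-style ndarray creation
y = tnp.add(x, 1.0)                      # NumPy-style op on the wrapped tf.Tensor
z = tf.reduce_sum(y)                     # interoperates with regular TF ops
```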
## Bug Fixes and Other Changes

* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* TF Core:
    * `tf.types.experimental.TensorLike` is a new `Union` type that can be used as a type annotation for variables representing a Tensor or a value that can be converted to a Tensor by `tf.convert_to_tensor`.
    * Calling ops with Python constants or NumPy values is now consistent with `tf.convert_to_tensor` behavior. This avoids operations like `tf.reshape` truncating inputs such as from int64 to int32.
    * Added `tf.sparse.map_values` to apply a function to the `.values` of `SparseTensor` arguments.
    * The Python bitwise operators for `Tensor` (`__and__`, `__or__`, `__xor__`, and `__invert__`) now support non-`bool` arguments and apply the corresponding bitwise ops. `bool` arguments continue to be supported and dispatch to logical ops. This brings them more in line with Python and NumPy behavior.
    * Added `tf.SparseTensor.with_values`. This returns a new SparseTensor with the same sparsity pattern, but with new provided values. It is similar to the `with_values` function of `RaggedTensor`.
    * Added the `StatelessCase` op, and uses it if none of the case branches has stateful ops.
* `tf.data`:
    * Added new `tf.data.experimental.service.register_dataset` and `tf.data.experimental.service.from_dataset_id` APIs to enable one process to register a dataset with the tf.data service, and another process to consume data from the dataset.
    * Added support for tf.data service dispatcher fault tolerance. To enable fault tolerance, configure a `work_dir` when running your dispatcher server and set `dispatcher_fault_tolerance=True`. The dispatcher will store its state to `work_dir`, so that on restart it can continue from its previous state.
    * Added tf.data service support for sharing dataset graphs via a shared filesystem instead of over RPC. This reduces load on the dispatcher, improving performance of distributing datasets. For this to work, the dispatcher's `work_dir` must be accessible from workers. If the worker fails to read from the `work_dir`, it falls back to using RPC for dataset graph transfer.
    * Added an optional `exclude_cols` parameter to CsvDataset. This parameter is the complement of `select_cols`; at most one of these should be specified.
    * We have implemented an optimization which reorders data-discarding transformations such as `take` and `shard` to happen earlier in the dataset when it is safe to do so. The optimization can be disabled via the `experimental_optimization.reorder_data_discarding_ops` dataset option.
    * `tf.data.Options` were previously immutable and can now be overridden.
* `tf.image`:
    * Added deterministic `tf.image.stateless_random_*` functions for each `tf.image.random_*` function. Added a new op `stateless_sample_distorted_bounding_box` which is a deterministic version of the `sample_distorted_bounding_box` op. Given the same seed, these stateless functions/ops produce the same results independent of how many times the function is called, and independent of global seed settings.
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* Security:
    * Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch` ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
    * Fixes three vulnerabilities in conversion to DLPack format ([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191), [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192), [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
    * Fixes two vulnerabilities in `SparseFillEmptyRowsGrad` ([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194), [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
    * Fixes several vulnerabilities in the `RaggedCountSparseOutput` and `SparseCountSparseOutput` operations ([CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196), [CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197), [CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198), [CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199), [CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200), [CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201))
    * Fixes an integer truncation vulnerability in code using the work sharder API ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
    * Fixes a format string vulnerability in `tf.strings.as_string` ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
    * Fixes a segfault raised by calling session-only ops in eager mode ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
    * Fixes a data leak and potential ASLR violation from `tf.raw_ops.StringNGrams` ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
    * Fixes segfaults caused by incomplete `SavedModel` validation ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
    * Fixes a data corruption due to a bug in negative indexing support in TFLite ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
    * Fixes a data corruption due to dimension mismatch in TFLite ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
    * Fixes several vulnerabilities in the TFLite saved model format ([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209), [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210), [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
    * Fixes several vulnerabilities in the TFLite implementation of segment sum ([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212), [CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213), [CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
* TF Core:
    * `tf.types.experimental.TensorLike` is a new `Union` type that can be used as a type annotation for variables representing a Tensor or a value that can be converted to a Tensor by `tf.convert_to_tensor`.
    * Calling ops with Python constants or NumPy values is now consistent with `tf.convert_to_tensor` behavior. This avoids operations like `tf.reshape` truncating inputs such as from int64 to int32.
    * Added `tf.sparse.map_values` to apply a function to the `.values` of `SparseTensor` arguments (see the sketch after this section).
    * The Python bitwise operators for `Tensor` (`__and__`, `__or__`, `__xor__`, and `__invert__`) now support non-`bool` arguments and apply the corresponding bitwise ops. `bool` arguments continue to be supported and dispatch to logical ops. This brings them more in line with Python and NumPy behavior.
    * Added `tf.SparseTensor.with_values`. This returns a new SparseTensor with the same sparsity pattern, but with new provided values. It is similar to the `with_values` function of `RaggedTensor`.
    * Added the `StatelessCase` op, and uses it if none of the case branches has stateful ops.
    * Added `tf.config.experimental.get_memory_usage` to return total memory usage of the device.
    * Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`.
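    A minimal sketch of the two new `SparseTensor` helpers mentioned above; the values and shapes are illustrative:

    ```python
    import tensorflow as tf

    st = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2]],
                                values=[1.0, 2.0],
                                dense_shape=[3, 4])
    doubled = tf.sparse.map_values(tf.multiply, st, 2.0)  # values become [2.0, 4.0]
    zeroed = st.with_values(tf.zeros_like(st.values))     # same sparsity, new values
    ```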
* `tf.data`:
    * tf.data service:
        * Added new `tf.data.experimental.service.register_dataset` and `tf.data.experimental.service.from_dataset_id` APIs to enable one process to register a dataset with the tf.data service, and another process to consume data from the dataset.
        * Added support for dispatcher fault tolerance. To enable fault tolerance, configure a `work_dir` when running your dispatcher server and set `dispatcher_fault_tolerance=True`. The dispatcher will store its state to `work_dir`, so that on restart it can continue from its previous state.
        * Added support for sharing dataset graphs via a shared filesystem instead of over RPC. This reduces load on the dispatcher, improving performance of distributing datasets. For this to work, the dispatcher's `work_dir` must be accessible from workers. If the worker fails to read from the `work_dir`, it falls back to using RPC for dataset graph transfer.
        * Added support for a new "distributed_epoch" processing mode. This processing mode distributes a dataset across all tf.data workers, instead of having each worker process the full dataset. See [the tf.data service docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#understand_processing_mode) to learn more.
    * Added an optional `exclude_cols` parameter to CsvDataset. This parameter is the complement of `select_cols`; at most one of these should be specified.
    * We have implemented an optimization which reorders data-discarding transformations such as `take` and `shard` to happen earlier in the dataset when it is safe to do so. The optimization can be disabled via the `experimental_optimization.reorder_data_discarding_ops` dataset option.
    * `tf.data.Options` were previously immutable and can now be overridden.
    * `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors with a new `output_signature` argument, which allows `from_generator` to produce any type describable by a `tf.TypeSpec` (see the sketch after this section).
    * `tf.data.experimental.AUTOTUNE` is now available in the core API as `tf.data.AUTOTUNE`.
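    A minimal sketch of `from_generator` with the new `output_signature` argument, yielding a ragged value; the shapes are illustrative:

    ```python
    import tensorflow as tf

    def gen():
        yield tf.ragged.constant([[1, 2], [3]])

    ds = tf.data.Dataset.from_generator(
        gen,
        output_signature=tf.RaggedTensorSpec(shape=[2, None], dtype=tf.int32))
    ds = ds.prefetch(tf.data.AUTOTUNE)  # AUTOTUNE is now in the core namespace
    ```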
* `tf.image`:
    * Added deterministic `tf.image.stateless_random_*` functions for each `tf.image.random_*` function. Added a new op `stateless_sample_distorted_bounding_box` which is a deterministic version of the `sample_distorted_bounding_box` op. Given the same seed, these stateless functions/ops produce the same results independent of how many times the function is called, and independent of global seed settings (see the sketch after this item).
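    A minimal sketch of the stateless image ops, using `stateless_random_brightness` as one member of the family; the explicit seed makes the result reproducible:

    ```python
    import tensorflow as tf

    image = tf.zeros([64, 64, 3])
    seed = (1, 2)  # explicit per-call seed
    a = tf.image.stateless_random_brightness(image, max_delta=0.2, seed=seed)
    b = tf.image.stateless_random_brightness(image, max_delta=0.2, seed=seed)
    # a and b are identical: same seed, same result, regardless of global seeds.
    ```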
* `tf.distribute`:
    * <ADD RELEASE NOTES HERE>
* `tf.keras`:
    * Improvements from the functional API refactoring:
        * Functional model construction does not need to maintain a global workspace graph, removing memory leaks especially when building many models or very large models.
        * Functional model construction should be ~8-10% faster on average.
        * Functional models can now contain non-symbolic values in their call inputs inside of the first positional argument.
        * Several classes of TF ops that were not reliably converted to Keras layers during functional API construction should now work, e.g. `tf.image.ssim_multiscale`.
        * Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
    * `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape` as an alternative to accepting a `callable` loss.
    * Added the `beta` hyperparameter to FTRL optimizer classes (Keras and others) to match the FTRL paper (https://research.google.com/pubs/archive/41159.pdf).
    * Added `mobilenet_v3` to the Keras application models.
    * `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for customization of how gradients are aggregated across devices, as well as `gradients_transformers` to allow for custom gradient transformations (such as gradient clipping).
* `tf.function` / AutoGraph:
    * Added the `experimental_follow_type_hints` argument for `tf.function`. When True, the function may use type annotations to optimize the tracing performance.
    * Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
    * AutoGraph now allows creating new symbols inside a TensorFlow loop, if the values of these symbols at an iteration do not depend on the previous iteration. These types of loops must run at least one iteration, and will raise a runtime error otherwise.
* <ADD RELEASE NOTES HERE>
* `tf.keras`:
    * Improvements from the functional API refactoring:
        * Functional model construction does not need to maintain a global workspace graph, removing memory leaks especially when building many models or very large models.
        * Functional model construction should be ~8-10% faster on average.
        * Functional models can now contain non-symbolic values in their call inputs inside of the first positional argument.
        * Several classes of TF ops that were not reliably converted to Keras layers during functional API construction should now work, e.g. `tf.image.ssim_multiscale`.
        * Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
    * `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape` as an alternative to accepting a `callable` loss (see the sketch after this section).
    * Added the `beta` hyperparameter to FTRL optimizer classes (Keras and others) to match the FTRL paper (https://research.google.com/pubs/archive/41159.pdf).
    * Added `mobilenet_v3` to the Keras application models.
    * `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for customization of how gradients are aggregated across devices, as well as `gradients_transformers` to allow for custom gradient transformations (such as gradient clipping).
    * The `steps_per_execution` argument in `compile()` is no longer experimental; if you were passing `experimental_steps_per_execution`, rename it to `steps_per_execution` in your code. This argument controls the number of batches to run during each `tf.function` call when calling `fit()`. Running multiple batches inside a single `tf.function` call can greatly improve performance on TPUs or small models with a large Python overhead.
    * Improvements to Keras preprocessing layers:
        * TextVectorization can now accept a vocabulary list or file as an init arg.
        * Normalization can now accept mean and variance values as init args.
    * In `Attention` and `AdditiveAttention` layers, the `call()` method now accepts a `return_attention_scores` argument. When set to True, the layer returns the attention scores as an additional output argument.
    * Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints with the same implementation as their `tf.losses` equivalents.
    * For Keras models, an individual call of `Model.evaluate` uses no cached data for evaluation, while `Model.fit` uses cached data when the `validation_data` arg is provided, for better performance.
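    A minimal sketch of `Optimizer.minimize` with a loss `Tensor` and a `GradientTape`, instead of a callable loss:

    ```python
    import tensorflow as tf

    opt = tf.keras.optimizers.SGD(learning_rate=0.1)
    var = tf.Variable(2.0)
    with tf.GradientTape() as tape:
        loss = var ** 2
    # Pass the recorded loss together with the tape that produced it.
    opt.minimize(loss, var_list=[var], tape=tape)
    ```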
* `tf.function` / AutoGraph:
    * Added the `experimental_follow_type_hints` argument for `tf.function`. When True, the function may use type annotations to optimize the tracing performance.
    * Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
    * AutoGraph now allows creating new symbols inside a TensorFlow loop, if the values of these symbols at an iteration do not depend on the previous iteration. These types of loops must run at least one iteration, and will raise a runtime error otherwise.

    Example:

@@ -137,45 +252,103 @@
    outputs = train_step(batch)
    tf.print('final outputs', outputs)
    ```

    See tensorflow/python/autograph/g3doc/reference/limitations.md for more info.

* `tf.lite`:
    * `DynamicBuffer::AddJoinedString()` will now add a separator if the first string to be joined is empty.
    * `TFLiteConverter`:
        * Support optional flags `inference_input_type` and `inference_output_type` for full integer quantized models. This allows users to modify the model input and output type to integer types (`tf.int8`, `tf.uint8`) instead of defaulting to float type (`tf.float32`).
    * Deprecate the `Interpreter::UseNNAPI(bool)` C++ API.
        * Prefer using `NnApiDelegate()` and related delegate configuration methods directly.
    * Add NNAPI Delegation support for requantization use cases by converting the operation into a dequantize-quantize pair.
    * <ADD RELEASE NOTES HERE>

    * `TFLiteConverter`:
        * Support optional flags `inference_input_type` and `inference_output_type` for full integer quantized models. This allows users to modify the model input and output type to integer types (`tf.int8`, `tf.uint8`) instead of defaulting to float type (`tf.float32`) (see the sketch after this section).
    * TFLite Profiler for Android is available. See the detailed [guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android).
    * NNAPI
        * Added NNAPI Delegation support for requantization use cases by converting the operation into a dequantize-quantize pair.
        * Removed the deprecated `Interpreter.setUseNNAPI(boolean)` Java API.
            * Use `Interpreter.Options.setUseNNAPI` instead.
        * Deprecate the `Interpreter::UseNNAPI(bool)` C++ API.
            * Use `NnApiDelegate()` and related delegate configuration methods directly.
        * Deprecate the `Interpreter::SetAllowFp16PrecisionForFp32(bool)` C++ API.
            * Prefer controlling this via delegate options, e.g. `tflite::StatefulNnApiDelegate::Options::allow_fp16` or `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
    * `DynamicBuffer::AddJoinedString()` will now add a separator if the first string to be joined is empty.
    * <ADD RELEASE NOTES HERE>
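    A minimal sketch of the converter flags for a full-integer model. The SavedModel path is hypothetical, and a representative dataset is assumed to be configured, as full-integer quantization requires one:

    ```python
    import tensorflow as tf

    def representative_data():
        for _ in range(10):
            yield [tf.random.uniform([1, 4])]

    converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/my_model")
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_data
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8    # new optional flag
    converter.inference_output_type = tf.int8   # new optional flag
    tflite_model = converter.convert()
    ```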
* `tf.random`:
    * <ADD RELEASE NOTES HERE>
* <ADD RELEASE NOTES HERE>
* Math and Linear Algebra:
    * <ADD RELEASE NOTES HERE>
    * Add `tf.math.erfcinv`, the inverse to `tf.math.erfc` (see the sketch after this item).
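    A minimal sketch of the new inverse function; `erfc(erfcinv(x))` should round-trip for `x` in (0, 2):

    ```python
    import tensorflow as tf

    x = tf.constant([0.5, 1.0, 1.5])
    y = tf.math.erfcinv(x)                        # inverse of tf.math.erfc
    tf.debugging.assert_near(tf.math.erfc(y), x)  # round-trips within tolerance
    ```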
* TPU Enhancements:
    * Added support for the `beta` parameter of the FTRL optimizer for TPU embeddings. Users of other TensorFlow platforms can implement equivalent behavior by adjusting the `l2` parameter.
    * <ADD RELEASE NOTES HERE>
    * Added support for the `beta` parameter of the FTRL optimizer for TPU embeddings. Users of other TensorFlow platforms can implement equivalent behavior by adjusting the `l2` parameter.
    * <ADD RELEASE NOTES HERE>
* XLA Support:
    * `xla.experimental.compile` is deprecated; use `tf.function(experimental_compile=True)` instead.
    * <ADD RELEASE NOTES HERE>
    * `xla.experimental.compile` is deprecated; use `tf.function(experimental_compile=True)` instead.
    * Added `tf.function.experimental_get_compiler_ir`, which returns the compiler IR (currently 'hlo' and 'optimized_hlo') for a given function and input.
    * <ADD RELEASE NOTES HERE>
* Tracing and Debugging:
    * <ADD RELEASE NOTES HERE>
* <ADD RELEASE NOTES HERE>
* `tf.train.Checkpoint`:
    * Now accepts a `root` argument in the initialization, which generates a checkpoint with a root object. This allows users to create a `Checkpoint` object that is compatible with Keras `model.save_weights()` and `model.load_weights`. The checkpoint is also compatible with the checkpoint saved in the `variables/` folder in the SavedModel.
    * When restoring, `save_path` can be a path to a SavedModel. The function will automatically find the checkpoint in the SavedModel.
    * Now accepts a `root` argument in the initialization, which generates a checkpoint with a root object. This allows users to create a `Checkpoint` object that is compatible with Keras `model.save_weights()` and `model.load_weights` (see the sketch after this item). The checkpoint is also compatible with the checkpoint saved in the `variables/` folder in the SavedModel.
    * When restoring, `save_path` can be a path to a SavedModel. The function will automatically find the checkpoint in the SavedModel.
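    A minimal sketch of the new `root` argument and its interplay with the Keras weight APIs; the checkpoint path is hypothetical:

    ```python
    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.build(input_shape=(None, 4))

    ckpt = tf.train.Checkpoint(root=model)  # new `root` argument
    path = ckpt.write("/tmp/demo_ckpt")     # hypothetical path
    model.load_weights(path)                # compatible with Keras weight APIs
    ```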
* `tf.nn`:
    * `tf.nn.max_pool2d` now supports explicit padding (see the sketch after this item).
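    A minimal sketch of explicit padding in `tf.nn.max_pool2d`; the padding list gives per-dimension (before, after) pairs in NHWC order:

    ```python
    import tensorflow as tf

    x = tf.reshape(tf.range(16.0), [1, 4, 4, 1])
    y = tf.nn.max_pool2d(x, ksize=2, strides=2,
                         padding=[[0, 0], [1, 1], [1, 1], [0, 0]])
    ```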
* `tf.debugging`:
    * `tf.debugging.assert_shapes()` now works on `SparseTensor`s (#36268; see the sketch after this item).
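    A minimal sketch of `assert_shapes` on a `SparseTensor`; the symbolic dimension names are arbitrary:

    ```python
    import tensorflow as tf

    st = tf.sparse.SparseTensor(indices=[[0, 0]], values=[1.0], dense_shape=[2, 3])
    tf.debugging.assert_shapes([(st, ("N", "M"))])  # now accepts SparseTensors
    ```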
* `tf.print`:
|
||||
|
||||
* Bug fix in `tf.print()` with `OrderedDict` where if an `OrderedDict`
|
||||
didn't have the keys sorted, the keys and values were not being printed
|
||||
in accordance with their correct mapping.
|
||||
|
||||
* `TensorRT`
|
||||
|
||||
* We now issue a warning when the `session_config` parameter for the TF1
|
||||
converter is used or the `rewrite_config_template` field in the TF2
|
||||
converter parameter object is used.
|
||||
|
||||
* Other:
|
||||
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
|
||||
and "denylist" where possible. Please see
|
||||
https://developers.google.com/style/word-list#blacklist for more context.
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
|
||||
and "denylist" where possible. Please see
|
||||
https://developers.google.com/style/word-list#blacklist for more
|
||||
context.
|
||||
* Add `tf.config.experimental.mlir_bridge_rollout` which will help us
|
||||
rollout the new MLIR TPU bridge.
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
## Thanks to our Contributors
|
||||
|
||||
@ -183,45 +356,327 @@ This release contains contributions from many people at Google, as well as:
|
||||
|
||||
stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
|
||||
|
||||
|
||||
# Release 2.3.1
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
|
||||
([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
|
||||
* Fixes three vulnerabilities in conversion to DLPack format
|
||||
([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
|
||||
[CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
|
||||
[CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
|
||||
* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
|
||||
([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
|
||||
[CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
|
||||
* Fixes several vulnerabilities in `RaggedCountSparseOutput` and
|
||||
`SparseCountSparseOutput` operations
|
||||
([CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196),
|
||||
[CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197),
|
||||
[CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198),
|
||||
[CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199),
|
||||
[CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200),
|
||||
[CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201))
|
||||
* Fixes an integer truncation vulnerability in code using the work sharder API
|
||||
([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
|
||||
* Fixes a format string vulnerability in `tf.strings.as_string`
|
||||
([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
|
||||
* Fixes segfault raised by calling session-only ops in eager mode
|
||||
([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
|
||||
* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams`
|
||||
([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
|
||||
* Fixes segfaults caused by incomplete `SavedModel` validation
|
||||
([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
|
||||
* Fixes a data corruption due to a bug in negative indexing support in TFLite
|
||||
([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
|
||||
* Fixes a data corruption due to dimension mismatch in TFLite
|
||||
([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
|
||||
* Fixes several vulnerabilities in TFLite saved model format
|
||||
([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
|
||||
[CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
|
||||
[CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
|
||||
* Fixes several vulnerabilities in TFLite implementation of segment sum
|
||||
([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212),
|
||||
[CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
|
||||
[CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
|
||||
* Updates `sqlite3` to `3.33.00` to handle
|
||||
[CVE-2020-15358](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15358).
|
||||
* Fixes deprecated usage of `collections` API
|
||||
* Removes `scipy` dependency from `setup.py` since TensorFlow does not need it
|
||||
to install the pip package
|
||||
|
||||
|
||||
# Release 2.2.1
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
|
||||
([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
|
||||
* Fixes three vulnerabilities in conversion to DLPack format
|
||||
([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
|
||||
[CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
|
||||
[CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
|
||||
* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
|
||||
([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
|
||||
[CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
|
||||
* Fixes an integer truncation vulnerability in code using the work sharder API
|
||||
([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
|
||||
* Fixes a format string vulnerability in `tf.strings.as_string`
|
||||
([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
|
||||
* Fixes segfault raised by calling session-only ops in eager mode
|
||||
([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
|
||||
* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams`
|
||||
([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
|
||||
* Fixes segfaults caused by incomplete `SavedModel` validation
|
||||
([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
|
||||
* Fixes a data corruption due to a bug in negative indexing support in TFLite
|
||||
([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
|
||||
* Fixes a data corruption due to dimension mismatch in TFLite
|
||||
([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
|
||||
* Fixes several vulnerabilities in TFLite saved model format
|
||||
([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
|
||||
[CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
|
||||
[CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
|
||||
* Fixes several vulnerabilities in TFLite implementation of segment sum
|
||||
([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212),
|
||||
[CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
|
||||
[CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
|
||||
* Updates `sqlite3` to `3.33.00` to handle
|
||||
[CVE-2020-9327](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-9327),
|
||||
[CVE-2020-11655](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11655),
|
||||
[CVE-2020-11656](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11656),
|
||||
[CVE-2020-13434](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13434),
|
||||
[CVE-2020-13435](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13435),
|
||||
[CVE-2020-13630](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13630),
|
||||
[CVE-2020-13631](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13631),
|
||||
[CVE-2020-13871](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13871),
|
||||
and
|
||||
[CVE-2020-15358](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15358).
|
||||
* Fixes deprecated usage of `collections` API
|
||||
* Removes `scipy` dependency from `setup.py` since TensorFlow does not need it
|
||||
to install the pip package
|
||||
|
||||
|
||||
# Release 2.1.2

## Bug Fixes and Other Changes

* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
  ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
* Fixes three vulnerabilities in conversion to DLPack format
  ([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
  [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
  [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
  ([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
  [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
* Fixes an integer truncation vulnerability in code using the work sharder API
  ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
* Fixes a format string vulnerability in `tf.strings.as_string`
  ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
* Fixes segfault raised by calling session-only ops in eager mode
  ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams`
  ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
* Fixes segfaults caused by incomplete `SavedModel` validation
  ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
* Fixes a data corruption due to a bug in negative indexing support in TFLite
  ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
* Fixes a data corruption due to dimension mismatch in TFLite
  ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
* Fixes several vulnerabilities in TFLite saved model format
  ([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
  [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
  [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
* Updates `sqlite3` to `3.33.00` to handle
  [CVE-2020-9327](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-9327),
  [CVE-2020-11655](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11655),
  [CVE-2020-11656](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11656),
  [CVE-2020-13434](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13434),
  [CVE-2020-13435](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13435),
  [CVE-2020-13630](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13630),
  [CVE-2020-13631](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13631),
  [CVE-2020-13871](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13871),
  and
  [CVE-2020-15358](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15358).
* Removes `scipy` dependency from `setup.py` since TensorFlow does not need it
  to install the pip package
* Switches ROCM builds to use ROCM 3.7

# Release 2.0.3

## Bug Fixes and Other Changes

* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
  ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
* Fixes three vulnerabilities in conversion to DLPack format
  ([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
  [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
  [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
  ([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
  [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
* Fixes an integer truncation vulnerability in code using the work sharder API
  ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
* Fixes a format string vulnerability in `tf.strings.as_string`
  ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
* Fixes segfault raised by calling session-only ops in eager mode
  ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams`
  ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
* Fixes segfaults caused by incomplete `SavedModel` validation
  ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
* Fixes a data corruption due to a bug in negative indexing support in TFLite
  ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
* Fixes a data corruption due to dimension mismatch in TFLite
  ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
* Fixes several vulnerabilities in TFLite saved model format
  ([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
  [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
  [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
* Updates `sqlite3` to `3.33.00` to handle
  [CVE-2020-9327](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-9327),
  [CVE-2020-11655](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11655),
  [CVE-2020-11656](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11656),
  [CVE-2020-13434](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13434),
  [CVE-2020-13435](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13435),
  [CVE-2020-13630](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13630),
  [CVE-2020-13631](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13631),
  [CVE-2020-13871](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13871),
  and
  [CVE-2020-15358](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15358).
* Pins `numpy` to 1.18.5 to prevent ABI breakage when compiling code that uses
  both NumPy and TensorFlow headers.

# Release 1.15.4

## Bug Fixes and Other Changes

* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
  ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
* Fixes three vulnerabilities in conversion to DLPack format
  ([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
  [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
  [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
  ([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
  [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
* Fixes an integer truncation vulnerability in code using the work sharder API
  ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
* Fixes a format string vulnerability in `tf.strings.as_string`
  ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
* Fixes segfault raised by calling session-only ops in eager mode
  ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams`
  ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
* Fixes segfaults caused by incomplete `SavedModel` validation
  ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
* Fixes a data corruption due to a bug in negative indexing support in TFLite
  ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
* Fixes a data corruption due to dimension mismatch in TFLite
  ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
* Fixes several vulnerabilities in TFLite saved model format
  ([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
  [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
  [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
* Updates `sqlite3` to `3.33.00` to handle
  [CVE-2020-9327](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-9327),
  [CVE-2020-11655](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11655),
  [CVE-2020-11656](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-11656),
  [CVE-2020-13434](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13434),
  [CVE-2020-13435](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13435),
  [CVE-2020-13630](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13630),
  [CVE-2020-13631](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13631),
  [CVE-2020-13871](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13871),
  and
  [CVE-2020-15358](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15358).
* Fixes #41630 by including `max_seq_length` in CuDNN descriptor cache key
* Pins `numpy` to 1.18.5 to prevent ABI breakage when compiling code that uses
  both NumPy and TensorFlow headers.

# Release 2.3.0

## Major Features and Improvements

* `tf.data` adds two new mechanisms to solve input pipeline bottlenecks and
  save resources:

  * [snapshot](https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot)
  * [tf.data service](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service).

  In addition, check out the detailed
  [guide](https://www.tensorflow.org/guide/data_performance_analysis) for
  analyzing input pipeline performance with TF Profiler. A usage sketch of
  the two mechanisms follows below.

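The following is a minimal sketch (not from the release itself) of how these
two mechanisms attach to a pipeline; the snapshot path and dispatcher address
are illustrative placeholders:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10).map(lambda x: x * 2)

# Persist the preprocessed elements on disk so later epochs (or later jobs)
# can re-read them instead of recomputing the map.
dataset = dataset.apply(tf.data.experimental.snapshot("/tmp/tf_snapshot"))

# Alternatively, offload input processing to a tf.data service dispatcher:
# dataset = dataset.apply(tf.data.experimental.service.distribute(
#     processing_mode="parallel_epochs", service="grpc://dispatcher:5000"))
```
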
* [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy)
  is now a stable API and no longer considered experimental for TensorFlow
  (earlier `tf.distribute.experimental.TPUStrategy`); a short sketch follows
  below.

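A minimal sketch, assuming a reachable TPU worker (the empty `tpu` argument
targets a locally attached TPU; adjust it for your cluster):

```python
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

strategy = tf.distribute.TPUStrategy(resolver)
with strategy.scope():
    # Variables created here are replicated across TPU cores.
    model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
```
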
* [TF Profiler](https://www.tensorflow.org/guide/profiler) introduces two new
  tools: a memory profiler to visualize your model’s memory usage over time
  and a [python tracer](https://www.tensorflow.org/guide/profiler#events)
  which allows you to trace python function calls in your model. Usability
  improvements include better diagnostic messages and
  [profile options](https://tensorflow.org/guide/profiler#collect_performance_data)
  to customize the host and device trace verbosity level (see the sketch
  below).

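A minimal sketch of collecting a trace with the new options; the log
directory is an illustrative placeholder:

```python
import tensorflow as tf

options = tf.profiler.experimental.ProfilerOptions(
    host_tracer_level=2,    # more verbose host-side events
    python_tracer_level=1,  # enable the new Python call tracer
    device_tracer_level=1)
tf.profiler.experimental.start("/tmp/profile_logdir", options=options)
# ... run a few training steps here ...
tf.profiler.experimental.stop()
```
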
* Introduces experimental support for Keras Preprocessing Layers API
  ([`tf.keras.layers.experimental.preprocessing.*`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing?version=nightly))
  to handle data preprocessing operations, with support for composite tensor
  inputs. Please see below for additional details on these layers.

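For instance, a `Normalization` layer can learn its statistics from sample
data and then run inside the model (a minimal sketch, not from the release
notes):

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers.experimental import preprocessing

# adapt() computes the mean and variance of the sample data.
norm = preprocessing.Normalization()
norm.adapt(np.array([[1.0], [2.0], [3.0]], dtype=np.float32))

model = tf.keras.Sequential([norm, tf.keras.layers.Dense(1)])
print(model(np.array([[2.0]], dtype=np.float32)))
```
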
* TFLite now properly supports dynamic shapes during conversion and inference.
  We’ve also added opt-in support on Android and iOS for
  [XNNPACK](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack),
  a highly optimized set of CPU kernels, as well as opt-in support for
  [executing quantized models on the GPU](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu_advanced.md#running-quantized-models-experimental).

* Libtensorflow packages are available in GCS starting this release. We have
  also started to
  [release a nightly version of these packages](https://github.com/tensorflow/tensorflow#official-builds).

* The experimental Python API
  [`tf.debugging.experimental.enable_dump_debug_info()`](https://www.tensorflow.org/api_docs/python/tf/debugging/experimental/enable_dump_debug_info)
  now allows you to instrument a TensorFlow program and dump debugging
  information to a directory on the file system. The directory can be read and
  visualized by a new interactive dashboard in TensorBoard 2.3 called
  [Debugger V2](https://www.tensorflow.org/tensorboard/debugger_v2), which
  reveals the details of the TensorFlow program including graph structures,
  history of op executions at the Python (eager) and intra-graph levels, the
  runtime dtype, shape, and numerical composition of tensors, as well as their
  code locations.

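A minimal sketch (the dump directory is an illustrative placeholder):

```python
import tensorflow as tf

# Instrument the program; the dump can then be opened in TensorBoard's
# Debugger V2 dashboard. circular_buffer_size=-1 keeps all events.
tf.debugging.experimental.enable_dump_debug_info(
    "/tmp/tfdbg2_logdir",
    tensor_debug_mode="FULL_HEALTH",
    circular_buffer_size=-1)
```
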
## Breaking Changes

* Increases the **minimum bazel version** required to build TF to **3.1.0**.
* `tf.data`
  * Makes the following (breaking) changes to `tf.data`:
  * C++ API: `IteratorBase::RestoreInternal`, `IteratorBase::SaveInternal`,
    and `DatasetBase::CheckExternalState` become pure-virtual and subclasses
    are now expected to provide an implementation.
  * The deprecated `DatasetBase::IsStateful` method is removed in favor of
    `DatasetBase::CheckExternalState`.
  * Deprecated overrides of `DatasetBase::MakeIterator` and
    `MakeIteratorFromInputElement` are removed.
  * The signature of `tensorflow::data::IteratorBase::SaveInternal` and
    `tensorflow::data::IteratorBase::SaveInput` has been extended with a
    `SerializationContext` argument to enable overriding the default policy
    for handling external state during iterator checkpointing. This is not a
    backwards compatible change and all subclasses of `IteratorBase` *need to
    be updated* accordingly.
* `tf.keras`
  * Adds a new `BackupAndRestore` callback for handling distributed training
    failures & restarts. Please take a look at this
    [tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
    for details on how to use the callback, and see the sketch after this
    list.
* `tf.image.extract_glimpse` has been updated to correctly process the case
  where `centered=False` and `normalized=False`. This is a breaking change as
  the output is different from (incorrect) previous versions. Note this
  breaking change only impacts `tf.image.extract_glimpse` and
  `tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
  `tf.compat.v1.image.extract_glimpse` does not change. The behavior of the
  existing C++ kernel `ExtractGlimpse` does not change either, so saved models
  using `tf.raw_ops.ExtractGlimpse` will not be impacted.

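A minimal sketch of the new callback (the backup directory is an illustrative
placeholder; in TF 2.3 the callback lives under
`tf.keras.callbacks.experimental`):

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="sgd", loss="mse")

# Back up training state at epoch boundaries so an interrupted multi-worker
# job can resume from the last completed epoch on restart.
backup = tf.keras.callbacks.experimental.BackupAndRestore(
    backup_dir="/tmp/keras_backup")
model.fit(np.random.rand(32, 4), np.random.rand(32, 1),
          epochs=3, callbacks=[backup])
```
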
## Known Caveats

* `tf.lite`

@ -1211,8 +1666,8 @@ If you experience any snags when using TF 2.0, please let us know at the [TF 2.0
    conversion. TensorRT initialization arguments are now passed wrapped in
    a named-tuple, `TrtConversionParams`, rather than as separate arguments
    as in `TrtGraphConverter`.
  * Changed API to optimize TensorRT engines during graph optimization. This
    is now done by calling `converter.build()` where previously
    `is_dynamic_op=False` would be set.
  * `converter.convert()` no longer returns a `tf.function`. Now the
    function must be accessed from the saved model.

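A minimal sketch of the reworked TF-TRT conversion flow; the saved-model
paths and the input function are illustrative placeholders:

```python
import numpy as np
from tensorflow.python.compiler.tensorrt import trt_convert as trt

# Conversion arguments now travel in the TrtConversionParams named-tuple.
params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    precision_mode=trt.TrtPrecisionMode.FP16)
converter = trt.TrtGraphConverterV2(
    input_saved_model_dir="/tmp/saved_model", conversion_params=params)
converter.convert()

def input_fn():
    # Representative inputs used to pre-build the TensorRT engines.
    yield (np.zeros((1, 28, 28, 1), dtype=np.float32),)

converter.build(input_fn=input_fn)  # replaces setting is_dynamic_op=False
converter.save("/tmp/trt_saved_model")
```
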
88
configure.py
@ -38,9 +38,6 @@ _DEFAULT_CUDNN_VERSION = '7'
_DEFAULT_TENSORRT_VERSION = '6'
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0'

_TF_OPENCL_VERSION = '1.2'
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16, 17, 18]

_DEFAULT_PROMPT_ASK_ATTEMPTS = 10

@ -1114,62 +1111,6 @@ def set_host_c_compiler(environ_cp):
  write_action_env_to_bazelrc('HOST_C_COMPILER', host_c_compiler)


def set_computecpp_toolkit_path(environ_cp):
  """Set COMPUTECPP_TOOLKIT_PATH."""

  def toolkit_exists(toolkit_path):
    """Check if a computecpp toolkit path is valid."""
    if is_linux():
      sycl_rt_lib_path = 'lib/libComputeCpp.so'
    else:
      sycl_rt_lib_path = ''

    sycl_rt_lib_path_full = os.path.join(toolkit_path, sycl_rt_lib_path)
    exists = os.path.exists(sycl_rt_lib_path_full)
    if not exists:
      print('Invalid SYCL %s library path. %s cannot be found' %
            (_TF_OPENCL_VERSION, sycl_rt_lib_path_full))
    return exists

  computecpp_toolkit_path = prompt_loop_or_load_from_env(
      environ_cp,
      var_name='COMPUTECPP_TOOLKIT_PATH',
      var_default=_DEFAULT_COMPUTECPP_TOOLKIT_PATH,
      ask_for_var=(
          'Please specify the location where ComputeCpp for SYCL %s is '
          'installed.' % _TF_OPENCL_VERSION),
      check_success=toolkit_exists,
      error_msg='Invalid SYCL compiler path. %s cannot be found.',
      suppress_default_error=True)

  write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH',
                              computecpp_toolkit_path)


def set_trisycl_include_dir(environ_cp):
  """Set TRISYCL_INCLUDE_DIR."""

  ask_trisycl_include_dir = ('Please specify the location of the triSYCL '
                             'include directory. (Use --config=sycl_trisycl '
                             'when building with Bazel) '
                             '[Default is %s]: ') % (
                                 _DEFAULT_TRISYCL_INCLUDE_DIR)

  while True:
    trisycl_include_dir = get_from_env_or_user_or_default(
        environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir,
        _DEFAULT_TRISYCL_INCLUDE_DIR)
    if os.path.exists(trisycl_include_dir):
      break

    print('Invalid triSYCL include directory, %s cannot be found' %
          (trisycl_include_dir))

  # Set TRISYCL_INCLUDE_DIR
  environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir
  write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)


def system_specific_test_config(environ_cp):
  """Add default build and test flags required for TF tests to bazelrc."""
  write_to_bazelrc('test --flaky_test_attempts=3')

@ -1397,8 +1338,6 @@ def main():
  setup_python(environ_cp)

  if is_windows():
    environ_cp['TF_NEED_OPENCL_SYCL'] = '0'
    environ_cp['TF_NEED_COMPUTECPP'] = '0'
    environ_cp['TF_NEED_OPENCL'] = '0'
    environ_cp['TF_CUDA_CLANG'] = '0'
    environ_cp['TF_NEED_TENSORRT'] = '0'

@ -1415,21 +1354,6 @@ def main():
  if environ_cp.get('TF_ENABLE_XLA', '1') == '1':
    write_to_bazelrc('build --config=xla')

  set_action_env_var(
      environ_cp,
      'TF_NEED_OPENCL_SYCL',
      'OpenCL SYCL',
      False,
      bazel_config_name='sycl')
  if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1':
    set_host_cxx_compiler(environ_cp)
    set_host_c_compiler(environ_cp)
    set_action_env_var(environ_cp, 'TF_NEED_COMPUTECPP', 'ComputeCPP', True)
    if environ_cp.get('TF_NEED_COMPUTECPP') == '1':
      set_computecpp_toolkit_path(environ_cp)
    else:
      set_trisycl_include_dir(environ_cp)

  set_action_env_var(
      environ_cp, 'TF_NEED_ROCM', 'ROCm', False, bazel_config_name='rocm')
  if (environ_cp.get('TF_NEED_ROCM') == '1' and

@ -1442,6 +1366,11 @@ def main():
    write_action_env_to_bazelrc('ROCM_PATH', environ_cp.get('ROCM_PATH'))
    write_action_env_to_bazelrc('ROCM_ROOT', environ_cp.get('ROCM_PATH'))

  if ((environ_cp.get('TF_NEED_ROCM') == '1') and
      (environ_cp.get('TF_ENABLE_MLIR_GENERATED_GPU_KERNELS') == '1')):
    write_to_bazelrc(
        'build:rocm --define tensorflow_enable_mlir_generated_gpu_kernels=1')

  environ_cp['TF_NEED_CUDA'] = str(
      int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)))
  if (environ_cp.get('TF_NEED_CUDA') == '1' and

@ -1523,17 +1452,15 @@ def main():
  # use it for the CPU build.
  set_tf_download_clang(environ_cp)

  # SYCL / ROCm / CUDA are mutually exclusive.
  # ROCm / CUDA are mutually exclusive.
  # At most 1 GPU platform can be configured.
  gpu_platform_count = 0
  if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1':
    gpu_platform_count += 1
  if environ_cp.get('TF_NEED_ROCM') == '1':
    gpu_platform_count += 1
  if environ_cp.get('TF_NEED_CUDA') == '1':
    gpu_platform_count += 1
  if gpu_platform_count >= 2:
    raise UserInputError('SYCL / CUDA / ROCm are mutually exclusive. '
    raise UserInputError('CUDA / ROCm are mutually exclusive. '
                         'At most 1 GPU platform can be configured.')

  set_cc_opt_flags(environ_cp)

@ -1558,6 +1485,7 @@ def main():
      'adding "--config=<>" to your build command. See .bazelrc for more '
      'details.')
  config_info_line('mkl', 'Build with MKL support.')
  config_info_line('mkl_aarch64', 'Build with oneDNN support for Aarch64.')
  config_info_line('monolithic', 'Config for mostly static monolithic build.')
  config_info_line('ngraph', 'Build with Intel nGraph support.')
  config_info_line('numa', 'Build with NUMA support.')

@ -497,13 +497,20 @@ config_setting(
    visibility = ["//visibility:public"],
)

# This flag enables experimental MLIR bridge support.
# This flag forcibly enables experimental MLIR bridge support.
config_setting(
    name = "enable_mlir_bridge",
    values = {"define": "enable_mlir_bridge=true"},
    visibility = ["//visibility:public"],
)

# This flag forcibly disables experimental MLIR bridge support.
config_setting(
    name = "disable_mlir_bridge",
    values = {"define": "enable_mlir_bridge=false"},
    visibility = ["//visibility:public"],
)

# This flag enables experimental TPU support
config_setting(
    name = "with_tpu_support",

@ -562,33 +569,17 @@ selects.config_setting_group(
package_group(
    name = "internal",
    packages = [
        "//learning/brain/swift/x10/...",
        "//perftools/accelerators/xprof/api/...",
        "//learning/lib/ami/simple_ml/...",
        "//tensorflow/...",
        "//tensorflow_estimator/python/estimator/...",
        "//tensorflow_models/official/...",
        "//third_party/py/autograph/...",
        "//third_party/swift/tensorflow/x10/...",
        "//third_party/swift/tensorflow_apis/...",
    ],
)

package_group(
    name = "ndarray_tensor_allow_list",
    packages = ["//learning/pathways/..."],
)

# Packages that use composite tensors or dispatch.
# TODO(b/154762408) Remove this package group once it's no longer needed.
# If this is modified, then copy.bara.sky must also be modified.
package_group(name = "composite_tensor_whitelist")
package_group(name = "ndarray_tensor_allow_list")

# Packages that use private types symbols, until they are exported.
# TODO(b/154650521) Remove.
package_group(
    name = "types_whitelist",
    packages = ["//learning/deepmind/tensorflow/replicator/..."],
)
# If this is modified, then copy.bara.sky must also be modified.
package_group(name = "types_whitelist")

# Packages that use StructuredTensors.
# TODO(b/159007891) Remove this package once StructuredTensor is exported.

@ -714,8 +705,12 @@ tf_cc_shared_object(
    soversion = VERSION,
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/c/experimental/filesystem:filesystem_interface",
        "//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
        "//tensorflow/c:kernels_hdrs",
        "//tensorflow/c:ops_hdrs",
        "//tensorflow/cc/saved_model:loader_lite_impl",
        "//tensorflow/core:core_cpu_impl",
        "//tensorflow/core/common_runtime:core_cpu_impl",
        "//tensorflow/core:framework_internal_impl",
        "//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",

@ -138,12 +138,12 @@ if _running_from_pip_package():
  for _s in _site_packages_dirs:
    # Load first party dynamic kernels.
    _main_dir = _os.path.join(_s, 'tensorflow/core/kernels')
    if _fi.file_exists(_main_dir):
    if _os.path.exists(_main_dir):
      _ll.load_library(_main_dir)

    # Load third party dynamic kernels.
    _plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
    if _fi.file_exists(_plugin_dir):
    if _os.path.exists(_plugin_dir):
      _ll.load_library(_plugin_dir)

  # Add module aliases

@ -148,12 +148,12 @@ if _running_from_pip_package():
  for _s in _site_packages_dirs:
    # Load first party dynamic kernels.
    _main_dir = _os.path.join(_s, 'tensorflow/core/kernels')
    if _fi.file_exists(_main_dir):
    if _os.path.exists(_main_dir):
      _ll.load_library(_main_dir)

    # Load third party dynamic kernels.
    _plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
    if _fi.file_exists(_plugin_dir):
    if _os.path.exists(_plugin_dir):
      _ll.load_library(_plugin_dir)

  # Delete modules that should be hidden from dir().

@ -1,6 +1,7 @@
# Description:
# C API for TensorFlow, for use by client language bindings.

load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_test",

@ -9,6 +10,11 @@ load(
    "tf_custom_op_library",
    "tf_kernel_library",
)

# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "filegroup")

# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")

package(

@ -211,6 +217,8 @@ tf_cuda_library(
            "//tensorflow/core:lib_internal",
            "//tensorflow/core/distributed_runtime:server_lib",
            "//tensorflow/core/kernels:logging_ops",
            "//tensorflow/compiler/mlir/tfr:node_expansion_pass",
            "//tensorflow/compiler/mlir/tfr:graph_decompose_pass",
        ],
    }),
    alwayslink = 1,

@ -248,6 +256,30 @@ tf_cuda_library(
    }),
)

cc_library(
    name = "tf_shape",
    srcs = ["tf_shape.cc"],
    hdrs = ["tf_shape.h"],
    copts = tf_copts(),
    visibility = ["//visibility:public"],
    deps = [
        ":c_api_macros",
        ":tf_shape_internal",
        "//tensorflow/core:framework",
    ],
)

cc_library(
    name = "tf_shape_internal",
    hdrs = ["tf_shape_internal.h"],
    copts = tf_copts(),
    visibility = ["//tensorflow:internal"],
    deps = [
        ":conversion_macros",
        "//tensorflow/core:framework",
    ],
)

cc_library(
    name = "tf_status",
    srcs = ["tf_status.cc"],

@ -377,6 +409,7 @@ tf_cuda_library(
        "//tensorflow/c/eager:tfe_op_internal",
        "//tensorflow/c/eager:tfe_tensorhandle_internal",
        "//tensorflow/compiler/jit:flags",
        "//tensorflow/compiler/jit:get_compiler_ir",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",

@ -387,6 +420,7 @@ tf_cuda_library(
        "//tensorflow/core/common_runtime/eager:eager_operation",
        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
        "//tensorflow/core/platform",
        "//tensorflow/core/platform:blocking_counter",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,

@ -477,6 +511,18 @@ tf_cuda_library(
    ],
)

cc_library(
    name = "kernels_hdrs",
    hdrs = ["kernels.h"],
    visibility = ["//tensorflow:internal"],
    deps = [
        ":c_api_internal",
        ":tf_datatype",
        ":tf_status",
        ":tf_tensor",
    ],
)

tf_cuda_library(
    name = "kernels",
    srcs = [

@ -530,6 +576,16 @@ tf_cuda_library(
    alwayslink = 1,
)

cc_library(
    name = "ops_hdrs",
    hdrs = ["ops.h"],
    visibility = ["//tensorflow:internal"],
    deps = [
        ":tf_datatype",
        ":tf_status",
    ],
)

# -----------------------------------------------------------------------------
# Tests

|
@ -2488,6 +2488,48 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
|
||||
TF_Status* status) {
|
||||
using tensorflow::RecordMutation;
|
||||
mutex_lock l(graph->mu);
|
||||
tensorflow::shape_inference::InferenceContext* ic =
|
||||
graph->refiner.GetContext(&new_src.oper->node);
|
||||
|
||||
if (ic->num_outputs() <= new_src.index) {
|
||||
status->status = tensorflow::errors::OutOfRange(
|
||||
"Cannot update edge. Output index [", new_src.index,
|
||||
"] is greater than the number of total outputs [", ic->num_outputs(),
|
||||
"].");
|
||||
return;
|
||||
}
|
||||
tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
|
||||
|
||||
tensorflow::shape_inference::InferenceContext* ic_dst =
|
||||
graph->refiner.GetContext(&dst.oper->node);
|
||||
if (ic_dst->num_inputs() <= dst.index) {
|
||||
status->status = tensorflow::errors::OutOfRange(
|
||||
"Cannot update edge. Input index [", dst.index,
|
||||
"] is greater than the number of total inputs [", ic_dst->num_inputs(),
|
||||
"].");
|
||||
return;
|
||||
}
|
||||
if (!ic_dst->MergeInput(dst.index, shape)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
|
||||
" and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
|
||||
return;
|
||||
}
|
||||
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
|
||||
&dst.oper->node, dst.index);
|
||||
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
// This modification only updates the destination node for
|
||||
// the purposes of running this graph in a session. Thus, we don't
|
||||
// record the source node as being modified.
|
||||
RecordMutation(graph, *dst.oper, "updating input tensor");
|
||||
}
|
||||
}
|
||||
|
||||
// TF_Server functions ----------------------------------------------
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
|
@ -1524,6 +1524,10 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status);
TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp(
    const char* name, TF_Status* status);

// Updates an edge, switching the input/output of a node.
TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src,
                                         TF_Input dst, TF_Status* status);

// --------------------------------------------------------------------------
// In-process TensorFlow server functionality, for use in distributed training.
// A Server instance encapsulates a set of devices and a Session target that

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/net.h"

@ -560,6 +561,21 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
  collective_executor_handle->get()->StartAbort(status->status);
}

TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
                                                            const char* task,
                                                            TF_Status* status) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  auto collective_executor_handle = context->GetCollectiveExecutorHandle();
  tensorflow::Notification done;
  collective_executor_handle->get()->remote_access()->CheckPeerHealth(
      task, [&done, status](const Status& s) {
        status->status = s;
        done.Notify();
      });
  done.WaitForNotification();
}

TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
  TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
  result->num_items = num_items;

@ -231,13 +231,20 @@ TFE_EnableCollectiveOps(TFE_Context* ctx,
                                                   TF_Status* status);

// Aborts all ongoing collectives with the specified status. After abortion,
// subsequent collectives will error with this status immediately.
// subsequent collectives will error with this status immediately. To reset the
// collectives, create a new EagerContext.
//
// This is intended to be used when a peer failure is detected. There's yet no
// way to reset the collectives other than restarting the program.
// This is intended to be used when a peer failure is detected.
TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
                                                  TF_Status* status);

// Checks the health of collective ops peers. Explicit health check is needed in
// multi worker collective ops to detect failures in the cluster. If a peer is
// down, collective ops may hang.
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
                                                            const char* task,
                                                            TF_Status* status);

// Information about the shape of a Tensor and its type.
struct TF_ShapeAndType {
  // Number of dimensions. -1 indicates unknown rank.

@ -1704,66 +1704,5 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) {
  TF_DeleteFunction(func1);
}

// This test only works when the TF build includes XLA compiler. One way to set
// this up is via bazel build option "--define with_xla_support=true".
//
// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to
// something like TENSORFLOW_CAPI_USE_XLA.
#ifdef TENSORFLOW_EAGER_USE_XLA
TEST_F(CApiFunctionTest, StatelessIf_XLA) {
  TF_Function* func;
  const std::string funcName = "BranchFunc";
  DefineFunction(funcName.c_str(), &func);
  TF_GraphCopyFunction(host_graph_, func, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_Operation* feed = Placeholder(host_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_Operation* true_cond = ScalarConst(true, host_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_OperationDescription* desc =
      TF_NewOperation(host_graph_, "StatelessIf", "IfNode");
  TF_AddInput(desc, {true_cond, 0});
  TF_Output inputs[] = {{feed, 0}};
  TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs));
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_SetAttrType(desc, "Tcond", TF_BOOL);
  TF_DataType inputType = TF_INT32;
  TF_SetAttrTypeList(desc, "Tin", &inputType, 1);
  TF_SetAttrTypeList(desc, "Tout", &inputType, 1);
  TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size());
  TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size());
  TF_SetDevice(desc, "/device:XLA_CPU:0");
  auto op = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  ASSERT_NE(op, nullptr);

  // Create a session for this graph.
  CSession csession(host_graph_, s_, /*use_XLA*/ true);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Run the graph.
  csession.SetInputs({{feed, Int32Tensor(17)}});
  csession.SetOutputs({op});
  csession.Run(s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Tensor* out = csession.output_tensor(0);
  ASSERT_TRUE(out != nullptr);
  EXPECT_EQ(TF_INT32, TF_TensorType(out));
  EXPECT_EQ(0, TF_NumDims(out));  // scalar
  ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
  int32* output_contents = static_cast<int32*>(TF_TensorData(out));
  EXPECT_EQ(-17, *output_contents);

  // Clean up
  csession.CloseAndDelete(s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_DeleteFunction(func);
}
#endif  // TENSORFLOW_EAGER_USE_XLA

}  // namespace
}  // namespace tensorflow

@ -634,6 +634,40 @@ TEST(CAPI, Graph) {
  TF_DeleteStatus(s);
}

TEST(CAPI, UpdateEdge) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Make two scalar constants.
  TF_Operation* one = ScalarConst(1, graph, s, "one");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  TF_Operation* two = ScalarConst(2, graph, s, "two");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Add oper.
  TF_Operation* add = Add(one, two, graph, s, "add");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Add another oper to the graph.
  TF_Operation* neg = Neg(add, graph, s, "neg");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  NodeDef node_def_neg;
  ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
  EXPECT_EQ(string("add"), node_def_neg.input(0));

  // Update the edge feeding neg.
  TF_UpdateEdge(graph, TF_Output{one, 0}, TF_Input{neg, 0}, s);

  ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
  EXPECT_EQ(string("one:0"), node_def_neg.input(0));

  // Clean up
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}

/*
TODO(skyewm): this test currently DCHECKs, change to bad status

@ -1,13 +1,23 @@
# Experimental extensions to the C API for eager execution of kernels.

load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load(
    "//tensorflow:tensorflow.bzl",
    "if_libtpu",
    "tf_cc_test",
    "tf_copts",
    "tf_cuda_cc_test",
    "tf_cuda_library",
    "tfe_xla_copts",
)

# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")

# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "filegroup")

# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "internal_tfrt_deps")
load(
    "//tensorflow/core/platform:build_config.bzl",
    "tf_kernel_tests_linkstatic",

@ -31,7 +41,7 @@ tf_cuda_library(
        "c_api_unified_experimental.h",
    ],
    hdrs = ["c_api.h"],
    copts = tf_copts() + tfe_xla_copts(),
    copts = tf_copts(),
    visibility = ["//visibility:public"],
    deps = select({
        "//tensorflow:android": [

@ -72,13 +82,6 @@ tf_cuda_library(
            "//tensorflow/core:protos_all_cc",
            "//tensorflow/core/profiler/lib:traceme",
        ],
    }) + select({
        "//tensorflow:with_xla_support": [
            "//tensorflow/compiler/tf2xla:xla_compiler",
            "//tensorflow/compiler/jit",
            "//tensorflow/compiler/jit:xla_device",
        ],
        "//conditions:default": [],
    }) + [
        "@com_google_absl//absl/memory",
        "//tensorflow/core/common_runtime/eager:eager_operation",

@ -95,7 +98,7 @@ tf_cuda_library(
        "//tensorflow/core/distributed_runtime:server_lib",
        "//tensorflow/core/distributed_runtime:worker_env",
        "//tensorflow/core:gpu_runtime",
    ],
    ] + internal_tfrt_deps(),
    alwayslink = 1,
)

@ -109,11 +112,16 @@ filegroup(
        "c_api_experimental.h",
        "c_api_internal.h",
        "c_api_unified_experimental.h",
        "c_api_unified_experimental_internal.h",
        "dlpack.h",
        "gradients.h",
        "gradients_internal.h",
        "immediate_execution_context.h",
        "immediate_execution_operation.h",
        "immediate_execution_tensor_handle.h",
        "tape.h",
        "tfe_cancellation_manager_internal.h",
        "tfe_context_internal.h",
        "tfe_executor_internal.h",
        "tfe_monitoring_internal.h",
        "tfe_op_attrs_internal.h",

@ -172,27 +180,20 @@ cc_library(
)

cc_library(
    name = "gradients",
    srcs = [
        "gradients.cc",
        "gradients_internal.h",
    ],
    name = "tracing_utils",
    srcs = ["tracing_utils.cc"],
    hdrs = [
        "gradients.h",
        "tracing_utils.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        ":abstract_context",
        ":abstract_operation",
        ":abstract_tensor_handle",
        ":c_api_unified_internal",
        ":tape",
        "//tensorflow/core/common_runtime/eager:attr_builder",
        "//tensorflow/c/experimental/gradients/tape:tape_operation",
        "//tensorflow/core/lib/llvm_rtti",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
        "//tensorflow/core/platform:errors",
    ],
)

@ -228,10 +229,10 @@ tf_cuda_cc_test(
        "gradients_test.cc",
    ],
    args = ["--heap_check=local"],
    extra_copts = tfe_xla_copts(),
    linkstatic = tf_kernel_tests_linkstatic(),
    tags = tf_cuda_tests_tags() + ["nomac"],
    deps = [
        ":abstract_context",
        ":abstract_tensor_handle",
        ":c_api_experimental",
        ":c_api_test_util",

@ -242,7 +243,8 @@ tf_cuda_cc_test(
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/c/experimental/gradients:array_grad",
        "//tensorflow/c/experimental/gradients:math_grad",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/c/experimental/gradients/tape:tape_context",
        "//tensorflow/c/experimental/ops",
        "//tensorflow/cc/profiler",
        "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
        "//tensorflow/core:lib",

@ -256,6 +258,46 @@ tf_cuda_cc_test(
    ],
)

cc_library(
    name = "gradients_util",
    srcs = [
        "gradients_util.cc",
    ],
    hdrs = [
        "gradients_util.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        ":abstract_context",
        ":abstract_operation",
        ":abstract_tensor_handle",
        ":c_api",
        ":c_api_experimental",
        ":c_api_unified_internal",
        ":gradients_internal",
        ":tape",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "//tensorflow/c:c_api",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/c/experimental/ops:math_ops",
        "//tensorflow/c/experimental/ops:nn_ops",
        "//tensorflow/cc/profiler",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/lib/llvm_rtti",
    ] + if_libtpu(
        if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
        if_true = [],
    ),
)

cc_library(
    name = "mnist_gradients_testutil",
    srcs = [

@ -272,17 +314,93 @@ cc_library(
        ":c_api_experimental",
        ":c_api_unified_internal",
        ":gradients_internal",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/c:tf_tensor",
        ":gradients_util",
        ":tape",
        "//tensorflow/c/experimental/gradients/tape:tape_context",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/c/experimental/ops:math_ops",
        "//tensorflow/c/experimental/ops:nn_ops",
        "//tensorflow/core/lib/llvm_rtti",
        "//tensorflow/core/platform:status",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "gradient_checker",
    srcs = [
        "gradient_checker.cc",
    ],
    hdrs = [
        "gradient_checker.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        ":abstract_tensor_handle",
        ":c_api_experimental",
        ":c_api_unified_internal",
        ":gradients_internal",
        ":gradients_util",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "//tensorflow/c:c_api",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/c/experimental/gradients:math_grad",
        "//tensorflow/c/experimental/gradients:nn_grad",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/c/experimental/ops:math_ops",
        "//tensorflow/c/experimental/ops:nn_ops",
        "//tensorflow/cc/profiler",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/lib/llvm_rtti",
    ] + if_libtpu(
        if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
        if_true = [],
    ),
)

tf_cuda_cc_test(
    name = "gradient_checker_test",
    size = "small",
    srcs = [
        "gradient_checker_test.cc",
    ],
    args = ["--heap_check=local"],
    linkstatic = tf_kernel_tests_linkstatic(),
    tags = tf_cuda_tests_tags() + ["nomac"],
    deps = [
        ":abstract_tensor_handle",
        ":c_api_experimental",
        ":c_api_test_util",
        ":c_api_unified_internal",
        ":gradient_checker",
        ":gradients_internal",
        ":gradients_util",
        ":mnist_gradients_testutil",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_test_util",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/c/experimental/gradients:math_grad",
        "//tensorflow/c/experimental/gradients:nn_grad",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/c/experimental/ops:math_ops",
        "//tensorflow/c/experimental/ops:nn_ops",
        "//tensorflow/cc/profiler",
        "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/lib/llvm_rtti",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

tf_cuda_cc_test(
    name = "mnist_gradients_test",
    size = "small",

@ -290,19 +408,16 @@ tf_cuda_cc_test(
        "mnist_gradients_test.cc",
    ],
    args = ["--heap_check=local"],
    extra_copts = tfe_xla_copts(),
    linkstatic = tf_kernel_tests_linkstatic(),
    tags = tf_cuda_tests_tags() + [
        "nomac",
        "notap",  # TODO(b/166150182): Enable
        "no_oss",  # TODO(b/166150182): Enable
    ],
    deps = [
        ":abstract_tensor_handle",
        ":c_api_experimental",
        ":c_api_test_util",
        ":c_api_unified_internal",
        ":gradients_internal",
        ":gradients_util",
        ":mnist_gradients_testutil",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_test_util",

@ -526,6 +641,19 @@ cc_library(
    ],
)

cc_header_only_library(
    name = "tfe_tensorhandle_internal_hdrs_only",
    extra_deps = [
        "@com_google_absl//absl/strings",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        ":tfe_tensorhandle_internal",
    ],
)

tf_cuda_library(
    name = "c_api_test_util",
    testonly = 1,

@ -539,6 +667,8 @@ tf_cuda_library(
        ":c_api",
        ":c_api_experimental",
        "//tensorflow/c:c_test_util",
        "//tensorflow/c:tf_datatype",
        "//tensorflow/c:tf_tensor",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",

@ -553,7 +683,6 @@ tf_cuda_cc_test(
        "c_api_debug_test.cc",
        "c_api_test.cc",
    ],
    extra_copts = tfe_xla_copts(),
    tags = [
        "noguitar",  # TODO(b/155445984): flaky
        #"guitar",

@ -608,7 +737,6 @@ tf_cuda_cc_test(
    ],
    # TODO(b/136478427): Figure out how to correctly shut the server down
    args = ["--heap_check=local"],
    extra_copts = tfe_xla_copts(),
    tags = [
        "no_windows",
    ],

@ -641,7 +769,6 @@ tf_cuda_cc_test(
    ],
    # TODO(b/136478427): Figure out how to correctly shut the server down
    args = ["--heap_check=local"],
    extra_copts = tfe_xla_copts(),
    tags = [
        "no_windows",
    ],

@ -660,7 +787,6 @@ tf_cuda_cc_test(
    ],
    # TODO(b/136478427): Figure out how to correctly shut the server down
    args = ["--heap_check=local"],
    extra_copts = tfe_xla_copts(),
    tags = [
        "no_windows",
        "noasan",  # leaks gRPC server instances

@ -694,7 +820,6 @@ tf_cuda_cc_test(
    ],
    # TODO(b/136478427): Figure out how to correctly shut the server down
    args = ["--heap_check=local"],
    extra_copts = tfe_xla_copts(),
    tags = [
        "no_windows",
    ],

@ -729,7 +854,7 @@ tf_cuda_library(
        "c_api_experimental.h",
        "c_api_unified_experimental.h",
    ],
    copts = tf_copts() + tfe_xla_copts(),
    copts = tf_copts(),
    visibility = ["//visibility:public"],
    deps = select({
        "//tensorflow:android": [

@ -801,7 +926,6 @@ tf_cuda_cc_test(
        "c_api_experimental_test.cc",
    ],
    args = ["--heap_check=local"],
    extra_copts = tfe_xla_copts(),
    linkstatic = tf_kernel_tests_linkstatic(),
    tags = tf_cuda_tests_tags() + ["nomac"],
    deps = [

@ -814,6 +938,7 @@ tf_cuda_cc_test(
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/platform:status",
        "@com_google_absl//absl/strings",
    ],
)

@ -825,7 +950,6 @@ tf_cuda_cc_test(
        "c_api_unified_experimental_test.cc",
    ],
    args = ["--heap_check=local"],
    extra_copts = tfe_xla_copts(),
    linkstatic = tf_kernel_tests_linkstatic(),
    tags = tf_cuda_tests_tags() + ["nomac"],
    deps = [

@ -834,6 +958,7 @@ tf_cuda_cc_test(
        ":c_api_test_util",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_test_util",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/cc/profiler",
        "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
        "//tensorflow/core:lib",

@ -943,7 +1068,13 @@ filegroup(
        "c_api_unified_experimental_eager.cc",
        "c_api_unified_experimental_graph.cc",
        "c_api_unified_experimental_internal.h",
        "gradient_checker.cc",
        "gradient_checker.h",
        "gradients.cc",  # Uses RTTI.
        "gradients_util.cc",
        "gradients_util.h",
        "tracing_utils.h",
        "tracing_utils.cc",
        "*test*",
        "*dlpack*",
    ],


@ -32,7 +32,7 @@ namespace tensorflow {
// environment, a traced representation etc.
class AbstractContext {
 protected:
  enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt };
  enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape };
  explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
  virtual ~AbstractContext() {}

@ -30,7 +30,7 @@ namespace tensorflow {
// tracing or immediate execution mode.
class AbstractOperation {
 protected:
  enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt };
  enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt, kTape };
  explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
  virtual ~AbstractOperation() {}
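The new kTape enumerator extends the closed set of kinds these abstract interfaces use for LLVM-style isa/dyn_cast-style classification. A minimal sketch of such a kind check, assuming the usual getKind() accessor these classes declare elsewhere in the same headers (the accessor is not part of this hunk):

#include "tensorflow/c/eager/abstract_context.h"

// Hypothetical helper: detect the new tape-backed context kind.
bool IsTapeContext(const tensorflow::AbstractContext* ctx) {
  return ctx->getKind() == tensorflow::AbstractContext::kTape;
}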

@ -39,7 +39,7 @@ limitations under the License.
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_tensor_internal.h"
#ifdef PLATFORM_GOOGLE
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
#endif
#include "tensorflow/core/common_runtime/device.h"
@ -51,9 +51,6 @@ limitations under the License.
#include "tensorflow/core/protobuf/device_filters.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#endif  // TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@ -629,21 +626,30 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
          "targets will fail.";
    }
  } else {
    // The master's context_view_id will be incremented by one in
    // the UpdateRemoteMaster call later. We want all new workers and
    // existing workers to also have the updated context_view_id, so
    // we must set their context_view_id to the existing master's
    // context_view_id + 1.
    sg.Update(CreateRemoteContexts(
        ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
        server_def, remote_eager_workers.get(), context->Executor().Async(),
        context->LazyCopyFunctionRemoteInputs(), base_request));
    if (sg.ok()) {
      // Create remote contexts on the newly added workers only if the master
      // has collected all device information from them (i.e., the
      // GetAllRemoteDevices call returns successfully). Note that in rare cases
      // GetAllRemoteDevices can still fail even with RPCs configured to wait
      // until the remote workers become alive. If the master creates remote
      // contexts on the workers whose devices are still not collected, those
      // workers will be treated as existing workers subsequently, so the master
      // will never get devices from them even with retrying UpdateServerDef.
      sg.Update(CreateRemoteContexts(
          ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
          server_def, remote_eager_workers.get(), context->Executor().Async(),
          context->LazyCopyFunctionRemoteInputs(), base_request));
    }
    if (!existing_workers.empty()) {
      if (VLOG_IS_ON(1)) {
        for (const string& w : existing_workers) {
          VLOG(1) << "Updating cluster with existing worker " << w;
        }
      }
      // The master's context_view_id will be incremented by one in the
      // UpdateRemoteMaster call later. We want existing workers to also have
      // the updated context_view_id, so we must set their context_view_id to
      // the master's current context_view_id + 1.
      sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers,
                                     removed_workers, context_id,
                                     context_view_id + 1, server_def,
@ -723,7 +729,7 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }

TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
  if (opts->use_tfrt) {
#ifdef PLATFORM_GOOGLE
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
    return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
#else
    status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
@ -745,10 +751,8 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
      opts->session_options.options,
      static_cast<tensorflow::ContextDevicePlacementPolicy>(
          opts->device_placement_policy),
      static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
      opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
      /*device_mgr_owned*/ true, r,
      tensorflow::GetDefaultCustomKernelCreator()));
      /*device_mgr_owned*/ true, r));
}

void TFE_DeleteContext(TFE_Context* ctx) {
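The constructor call above loses its mirroring-policy and custom-kernel-creator arguments, but the public entry point is unchanged. A minimal sketch of reaching it through the C API, using only calls that appear elsewhere in this diff:

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/platform/logging.h"

// Create and destroy an eager context through the public C API.
void CreateAndDestroyContext() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, 0);  // synchronous execution
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}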
@ -851,20 +855,9 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
#else   // !defined(IS_MOBILE_PLATFORM)
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  tensorflow::GrpcServer* grpc_server =
      static_cast<tensorflow::GrpcServer*>(context->GetServer());

  std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
  status->status = grpc_server->master_env()->worker_cache->GetEagerClientCache(
      &remote_eager_workers);
  if (!status->status.ok()) {
    LOG(ERROR) << "Failed to get client cache for remote workers.";
    return false;
  }

  // TODO(yuefengz): support partially specified `worker_name`.
  tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
  status->status = remote_eager_workers->GetClient(worker_name, &eager_client);
  status->status = context->GetClient(worker_name, &eager_client);
  if (!status->status.ok()) {
    return false;
  }
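After this change the client lookup goes through the eager context rather than a hand-unwrapped gRPC server. A caller-side sketch; the full signature (context, worker name, status) is inferred from the hunk header and body above:

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/platform/logging.h"

// Returns true if the named remote worker (e.g.
// "/job:worker/replica:0/task:1") responds to a liveness probe.
bool WorkerIsAlive(TFE_Context* ctx, const char* worker_name) {
  TF_Status* status = TF_NewStatus();
  bool alive = TFE_ContextCheckAlive(ctx, worker_name, status);
  if (TF_GetCode(status) != TF_OK) {
    LOG(ERROR) << TF_Message(status);
    alive = false;
  }
  TF_DeleteStatus(status);
  return alive;
}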
@ -911,9 +904,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,

void TFE_ContextSetThreadLocalDevicePlacementPolicy(
    TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  context->SetThreadLocalDevicePlacementPolicy(
  tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy(
      static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
}

@ -922,10 +913,8 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
// safe to call this function from the async EagerExecutor threads.
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
    TFE_Context* ctx) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  return static_cast<TFE_ContextDevicePlacementPolicy>(
      context->GetDevicePlacementPolicy());
      tensorflow::unwrap(ctx)->GetDevicePlacementPolicy());
}

TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
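Both calls now route through the unwrapped context directly. A minimal round-trip sketch from the caller's side:

#include "tensorflow/c/eager/c_api.h"

// Set the thread-local placement policy and read it back.
void UseSilentPlacement(TFE_Context* ctx) {
  TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx,
                                                 TFE_DEVICE_PLACEMENT_SILENT);
  TFE_ContextDevicePlacementPolicy policy =
      TFE_ContextGetDevicePlacementPolicy(ctx);
  (void)policy;  // expected: TFE_DEVICE_PLACEMENT_SILENT
}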
@ -1149,26 +1138,23 @@ void TFE_DeleteOp(TFE_Op* op) {
  tensorflow::unwrap(op)->Release();
}

const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) {
  return tensorflow::unwrap(op)->Name().c_str();
}

TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) {
  return tensorflow::wrap(
      &(OperationFromInterface(tensorflow::unwrap(op))->EagerContext()));
}

void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
  status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
}

const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status) {
  return tensorflow::unwrap(op)->DeviceName().c_str();
}

void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
#ifdef TENSORFLOW_EAGER_USE_XLA
  tensorflow::Status s = tensorflow::unwrap(op)->SetUseXla(enable);
  if (!s.ok()) {
    LOG(ERROR) << "Could not enable XLA compilation for op: " << s;
  }
#else
  LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
                  "built with XLA support.";
#endif  // TENSORFLOW_EAGER_USE_XLA
}

void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
  status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
}
@ -1181,6 +1167,15 @@ void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
                                           static_cast<size_t>(num_inputs)});
}

extern int TFE_OpGetFlatInputCount(const TFE_Op* op, TF_Status* status) {
  return tensorflow::unwrap(op)->GetInputs().size();
}

extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, int index,
                                            TF_Status* status) {
  return tensorflow::wrap(tensorflow::unwrap(op)->GetInputs()[index]);
}

TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
                              unsigned char* is_list, TF_Status* status) {
  TF_AttrType ret = TF_ATTR_INT;
@ -1430,21 +1425,15 @@ void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
}

unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  return context->FindFunctionDef(name) != nullptr;
  return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr;
}

void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  context->SetShouldStoreGraphs(true);
  tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
}

void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  context->SetShouldStoreGraphs(false);
  tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
}

}  // extern "C"
@ -1486,7 +1475,7 @@ void TFE_ContextEndStep(TFE_Context* ctx) {
  tensorflow::unwrap(ctx)->EndStep();
}

const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op) {
  return tensorflow::wrap(
      &OperationFromInterface(tensorflow::unwrap(op))->Attrs());
}
@ -1551,8 +1540,67 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
      TFE_OpSetAttrFunction(op, attr_name, func_op);
      TFE_DeleteOp(func_op);
    } break;
    case tensorflow::AttrValue::kList:
      TF_FALLTHROUGH_INTENDED;
    case tensorflow::AttrValue::kList: {
      // String
      if (const int s_size = default_value.list().s_size()) {
        absl::InlinedVector<const void*, 4> values_vector;
        absl::InlinedVector<size_t, 4> lengths_vector;
        for (int i = 0; i < s_size; ++i) {
          const string& v = default_value.list().s(i);
          values_vector.push_back(v.data());
          lengths_vector.push_back(v.size());
        }
        TFE_OpSetAttrStringList(op, attr_name, values_vector.data(),
                                lengths_vector.data(), s_size);
      }

      // Int
      if (const int i_size = default_value.list().i_size()) {
        absl::InlinedVector<int64_t, 4> i_vector;
        for (int i = 0; i < i_size; ++i) {
          i_vector.push_back(default_value.list().i(i));
        }
        TFE_OpSetAttrIntList(op, attr_name, i_vector.data(), i_size);
      }
      // Float
      if (const int f_size = default_value.list().f_size()) {
        absl::InlinedVector<float, 4> f_vector;
        for (int i = 0; i < f_size; ++i) {
          f_vector.push_back(default_value.list().f(i));
        }
        TFE_OpSetAttrFloatList(op, attr_name, f_vector.data(), f_size);
      }
      // Bool
      if (const int b_size = default_value.list().b_size()) {
        absl::InlinedVector<unsigned char, 4> b_vector;
        for (int i = 0; i < b_size; i++) {
          b_vector.push_back(default_value.list().b(i));
        }
        TFE_OpSetAttrBoolList(op, attr_name, b_vector.data(), b_size);
      }
      // Type
      if (const int type_size = default_value.list().type_size()) {
        absl::InlinedVector<unsigned int, 4> type_vector;
        for (int i = 0; i < type_size; ++i) {
          type_vector.push_back(default_value.list().type(i));
        }
        TFE_OpSetAttrTypeList(
            op, attr_name,
            reinterpret_cast<const TF_DataType*>(type_vector.data()),
            type_size);
      }

      // Rest are not supported.
      if (default_value.list().shape_size() > 0 ||
          default_value.list().func_size() > 0 ||
          default_value.list().tensor_size() > 0) {
        TF_SetStatus(
            status, TF_UNIMPLEMENTED,
            tensorflow::strings::StrCat("Unable to set attr for default value: ",
                                        default_value.DebugString())
                .data());
      }
    } break;
    case tensorflow::AttrValue::kTensor:
      TF_FALLTHROUGH_INTENDED;
    case tensorflow::AttrValue::kPlaceholder:
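The new kList branch fans out to the per-type list setters of the C API; the TestTFE_SetOpAttrs test later in this diff drives it end to end through AttrValue protos. As a direct sketch, the same int-list setter can also be called without an AttrValue at all:

#include "tensorflow/c/eager/c_api.h"

// Set AvgPool's "ksize" attribute directly via the int-list setter that
// the kList branch above dispatches to.
void SetKsize(TFE_Op* op) {
  const int64_t ksize[4] = {1, 1, 1, 1};
  TFE_OpSetAttrIntList(op, "ksize", ksize, 4);
}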
@ -1612,19 +1660,12 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
    return status.status;
  }

  tensorflow::Status Execute(tensorflow::EagerOperation* op,
  tensorflow::Status Execute(const tensorflow::EagerOperation* op,
                             tensorflow::TensorHandle** retvals,
                             int* num_retvals) override {
    std::vector<TFE_TensorHandle*> inputs;
    inputs.reserve(op->Inputs().size());
    for (int i = 0; i < op->Inputs().size(); ++i) {
      op->Inputs()[i]->Ref();
      inputs.push_back(tensorflow::wrap(op->Inputs()[i]));
    }
    std::vector<TFE_TensorHandle*> outputs(*num_retvals);
    TF_Status status;
    device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
                    wrap(&op->Attrs()), num_retvals, outputs.data(), &status,
    device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
                    info_);
    if (status.status.ok()) {
      for (int i = 0; i < *num_retvals; ++i) {
@ -1634,10 +1675,6 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
        TFE_DeleteTensorHandle(outputs[i]);
      }
    }

    for (auto inp : inputs) {
      TFE_DeleteTensorHandle(inp);
    }
    return status.status;
  }

@ -74,7 +74,7 @@ typedef enum TFE_ContextDevicePlacementPolicy {
  // Placement policy which silently copies int32 tensors but not other dtypes.
  TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
} TFE_ContextDevicePlacementPolicy;
// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h)
// LINT.ThenChange(//tensorflow/c/eager/immediate_execution_context.h)

// Sets the default execution mode (sync/async). Note that this can be
// overridden per thread using TFE_ContextSetExecutorForThread.
@ -248,22 +248,22 @@ typedef struct TFE_Op TFE_Op;
TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx,
                                        const char* op_or_function_name,
                                        TF_Status* status);

TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op);

// Returns the op or function name `op` will execute.
//
// The returned string remains valid throughout the lifetime of 'op'.
TF_CAPI_EXPORT extern const char* TFE_OpGetName(const TFE_Op* op,
                                                TF_Status* status);
TF_CAPI_EXPORT extern TFE_Context* TFE_OpGetContext(const TFE_Op* op,
                                                    TF_Status* status);

TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name,
                                           TF_Status* status);
// The returned string remains valid throughout the lifetime of 'op'.
TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(TFE_Op* op,
TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(const TFE_Op* op,
                                                  TF_Status* status);

// When 'enable' is set to 1, and if TensorFlow library is built with XLA
// support, a subsequent TFE_Execute() call on `op` will run the op via XLA.
//
// If the library is not built with XLA support, this call would be a no-op.
TF_CAPI_EXPORT extern void TFE_OpSetXLACompilation(TFE_Op* op,
                                                   unsigned char enable);

TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input,
                                          TF_Status* status);

@ -272,6 +272,23 @@ TF_CAPI_EXPORT extern void TFE_OpAddInputList(TFE_Op* op,
                                              int num_inputs,
                                              TF_Status* status);

// Fetches the current number of inputs attached to `op`.
//
// Does not use the operation's definition to determine how many inputs should
// be attached. It is intended for use with TFE_OpGetFlatInput to inspect an
// already-finalized operation.
//
// Note that TFE_OpGetFlatInputCount and TFE_OpGetFlatInput operate on a flat
// sequence of inputs, unlike TFE_OpGetInputLength (for getting the length of a
// particular named input list, which may only be part of the op's inputs).
TF_CAPI_EXPORT extern int TFE_OpGetFlatInputCount(const TFE_Op* op,
                                                  TF_Status* status);
// Returns a borrowed reference to one of `op`'s inputs. Use
// `TFE_TensorHandleCopySharingTensor` to make a new reference.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op,
                                                           int index,
                                                           TF_Status* status);

TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op,
                                                    const char* attr_name,
                                                    unsigned char* is_list,
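A short sketch of the intended use of the two new declarations: walking the flat inputs of a finalized op. The CloneOp helper later in this diff does exactly this to rebuild an op from scratch.

#include "tensorflow/c/eager/c_api.h"

// Iterate over every input already attached to `op`.
void VisitFlatInputs(const TFE_Op* op, TF_Status* status) {
  int n = TFE_OpGetFlatInputCount(op, status);
  if (TF_GetCode(status) != TF_OK) return;
  for (int i = 0; i < n; ++i) {
    // Borrowed reference; use TFE_TensorHandleCopySharingTensor to keep it.
    TFE_TensorHandle* input = TFE_OpGetFlatInput(op, i, status);
    if (TF_GetCode(status) != TF_OK) return;
    (void)input;
  }
}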

@ -22,9 +22,6 @@ limitations under the License.
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/platform/status.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/jit/xla_device.h"
#endif  // TENSORFLOW_EAGER_USE_XLA

using tensorflow::string;

@ -64,87 +61,6 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
    return nullptr;
  }

#ifdef TENSORFLOW_EAGER_USE_XLA
  auto* device = absl::get<tensorflow::Device*>(handle->device());

  // If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
  auto* xla_device = dynamic_cast<tensorflow::XlaDevice*>(device);
  if (xla_device != nullptr) {
    tensorflow::XlaDevice::PaddedShapeFn shape_fn =
        xla_device->metadata().padded_shape_fn();
    xla::Shape padded_shape;
    status->status = shape_fn(*tensor, &padded_shape);
    if (!status->status.ok()) {
      return nullptr;
    }
    if (VLOG_IS_ON(3)) {
      std::vector<tensorflow::int64> shape_to_log =
          TensorShapeAsVector(*handle, &status->status);
      if (!status->status.ok()) {
        // Ignore the status here as we are simply logging.
        status->status = tensorflow::Status::OK();
      } else {
        VLOG(3) << "Fully padded shape of ["
                << absl::StrJoin(shape_to_log, ", ") << "] is "
                << padded_shape.DebugString();
      }
    }

    if (padded_shape.IsTuple()) {
      if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) {
        // Currently, the only case of XlaTensor containing a tuple shape is to
        // represent 64 bit ints, doubles, and complex numbers (we don't support
        // 64bit complex numbers).
        status->status = tensorflow::errors::InvalidArgument(
            "XlaTensors should only contain tuples of size 2. Shape: ",
            padded_shape.DebugString());
        return nullptr;
      }

      // shape0 is not a const& because we will assign it to padded_shape below.
      // It is illegal to assign a part of a message to itself.
      xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0);
      const xla::Shape& shape1 =
          xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
      if (shape0.IsTuple() || shape1.IsTuple()) {
        status->status = tensorflow::errors::InvalidArgument(
            "XlaTensors should not contain nested tuples. Shape: ",
            padded_shape.DebugString());
        return nullptr;
      }
      if (!xla::ShapeUtil::Equal(shape0, shape1)) {
        status->status = tensorflow::errors::InvalidArgument(
            "Subshapes of XlaTensors should be the same. Shape: ",
            padded_shape.DebugString());
        return nullptr;
      }

      // Since the only case we handle here are two equal subshapes, we
      // simply return one of them. The caller will interpret it as this
      // shape directly storing the 64bit types. This approximation is good
      // enough for this API's debugging use case.
      padded_shape = shape0;
    }

    int rank = padded_shape.dimensions_size();
    std::vector<tensorflow::int64> dev_dims;
    dev_dims.reserve(rank);
    if (rank == 1) {
      // Rank 1 tensors might not have padded_shape.layout.minor_to_major set,
      dev_dims.push_back(padded_shape.dimensions(0));
    } else {
      for (int i = rank - 1; i >= 0; --i) {
        tensorflow::int64 dim_index = padded_shape.layout().minor_to_major(i);
        dev_dims.push_back(padded_shape.dimensions(dim_index));
      }
    }
    status->status = tensorflow::Status::OK();
    return new TFE_TensorDebugInfo(dev_dims);
  }
#endif  // TENSORFLOW_EAGER_USE_XLA

  // If the tensor is not an XLA tensor, the device shape is
  // the same as regular tensor shape.
  std::vector<tensorflow::int64> dev_dims =
      TensorShapeAsVector(*handle, &status->status);
  if (!status->status.ok()) {
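With the XLA branch gone, the debug info always reflects the regular tensor shape. A usage sketch, assuming the companion TFE_TensorDebugInfo accessors that accompany this function in the experimental C API (they are not shown in this hunk):

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/logging.h"

// Log the on-device dimensions reported for a tensor handle.
void LogDeviceDims(TFE_TensorHandle* h, TF_Status* status) {
  TFE_TensorDebugInfo* info = TFE_TensorHandleTensorDebugInfo(h, status);
  if (TF_GetCode(status) != TF_OK) return;
  for (int i = 0; i < TFE_TensorDebugInfoOnDeviceNumDims(info); ++i) {
    LOG(INFO) << "dim " << i << ": "
              << TFE_TensorDebugInfoOnDeviceDim(info, i);
  }
  TFE_DeleteTensorDebugInfo(info);
}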

@ -121,25 +121,6 @@ string AddVariablesFunction() {
  return def.SerializeAsString();
}

void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
  TF_Status* status = TF_NewStatus();
  TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_OpAddInput(op, var_handle, status);
  TFE_TensorHandle* is_initialized[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(op, &is_initialized[0], &num_retvals, status);
  CHECK_EQ(1, num_retvals);
  TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
  bool initialized = false;
  memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t));
  EXPECT_EQ(initialized, true);
  TF_DeleteTensor(t);
  TFE_DeleteTensorHandle(is_initialized[0]);
  TFE_DeleteOp(op);
  delete status;
}

void TestFunctionWithPackedInput(const bool remote) {
  tensorflow::ServerDef server_def = GetServerDef(3);

@ -182,9 +163,8 @@ void TestFunctionWithPackedInput(const bool remote) {

  // Add a sync point in order to make sure that variables have been initialized
  // before the function execution starts.
  // TODO(b/155789951): Remove once b/155789951 is fixed.
  VarIsInitialized(ctx, h1);
  VarIsInitialized(ctx, h2);
  TFE_ContextAsyncWait(ctx, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Pack 3 variable handles into one TFE_TensorHandle.
  // When remote is false, function device is placed on task0. Handle types are
@ -396,6 +376,8 @@ TEST(CAPI, DistributedFunctionGraphPassOnlyOnce) {

  TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
  EXPECT_NE(var_handle, nullptr);
  TFE_ContextAsyncWait(ctx, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  const string function_def = VariableAddFunction();
  TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
@ -517,6 +499,8 @@ void TestDistributedFunctionCancellation(bool inject_error) {

  TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
  EXPECT_NE(var_handle, nullptr);
  TFE_ContextAsyncWait(ctx, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  const string function_def = inject_error ? VariableAddFunctionWithGraphError()
                                           : VariableAddFunction();
@ -561,7 +545,9 @@ TEST(CAPI, DistributedFunctionNoError) {
  TestDistributedFunctionCancellation(false);
}

TEST(CAPI, DistributedFunctionCancelledOnError) {
// TODO(b/170399182): Update test once an alternative to using the function
// optimization hook is in place.
TEST(CAPI, DISABLED_DistributedFunctionCancelledOnError) {
  TestDistributedFunctionCancellation(true);
}
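The deleted VarIsInitialized round trips are replaced by a single sync point on the context. A minimal sketch of that pattern:

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/platform/logging.h"

// Block until all pending async nodes have finished, so that variable
// initialization is visible before the next function executes.
void SyncBeforeExecution(TFE_Context* ctx) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextAsyncWait(ctx, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
}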

@ -49,15 +49,11 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
}

void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  context->SetShouldStoreGraphs(true);
  tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
}

void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  context->SetShouldStoreGraphs(false);
  tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
}

uint64_t TFE_GetContextId(TFE_Context* ctx) {
@ -486,29 +482,6 @@ TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
      static_cast<void*>(sampler->sampler->GetCell(label1, label2)));
}

void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
                                          TFE_ContextMirroringPolicy policy) {
  options->mirroring_policy = policy;
}

void TFE_ContextSetThreadLocalMirroringPolicy(
    TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  context->SetThreadLocalMirroringPolicy(
      static_cast<tensorflow::ContextMirroringPolicy>(policy));
}

// Note: this function looks up a thread local policy. So it should be called in
// the appropriate client thread. In particular, in async mode, it may not be
// safe to call this function from the async EagerExecutor threads.
extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
    TFE_Context* ctx) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  return static_cast<TFE_ContextMirroringPolicy>(context->GetMirroringPolicy());
}

void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options,
                                               bool lazy_copy) {
  options->lazy_remote_inputs_copy = lazy_copy;
@ -567,22 +540,16 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
}

void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  context->SetExecutorForThread(executor->executor());
  tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor());
}

TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  return new TFE_Executor(&context->Executor());
  return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor());
}

void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
      context->HostCPU()->parsed_name());
      tensorflow::unwrap(ctx)->HostCPUParsedName());
  auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
  void* data = tensorflow::port::Malloc(str.length());
  str.copy(static_cast<char*>(data), str.length(), 0);
@ -595,9 +562,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {

void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
                               TF_Buffer* buf, TF_Status* status) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  auto* function_def = context->FindFunctionDef(function_name);
  auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name);
  if (function_def == nullptr) {
    status->status = tensorflow::errors::NotFound(
        "Unable to find FunctionDef with name: ", function_name);
@ -666,14 +631,26 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,

void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
                                       TF_Status* status) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  context->SetAllowSoftPlacement(enable);
  tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable);
}

void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
                                      TF_Status* status) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  context->SetLogDevicePlacement(enable);
  tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
}

const char* TFE_TensorHandleDeviceType(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  return tensorflow::unwrap(h)->DeviceType(&status->status);
}

int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return -1;
  }
  return tensorflow::unwrap(h)->DeviceId(&status->status);
}

@ -265,33 +265,6 @@ TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2(
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
    TFE_MonitoringSampler2* sampler, const char* label1, const char* label2);

// LINT.IfChange
// Note: Keep in sync with internal copy of enum in eager/context.h.
typedef enum TFE_ContextMirroringPolicy {
  // Do not maintain mirrors in a TensorHandle, instead make new TensorHandle
  // copies with their own lifetime.
  TFE_MIRRORING_NONE = 0,
  // Mirroring any remote tensor handles, associating them with the lifetime of
  // the local TensorHandle.
  TFE_MIRRORING_ALL = 1,
} TFE_ContextMirroringPolicy;
// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h)

TF_CAPI_EXPORT extern void TFE_ContextOptionsSetMirroringPolicy(
    TFE_ContextOptions*, TFE_ContextMirroringPolicy);

// Sets a thread-local mirroring policy. After this call, other calls to
// TFE_Execute in the same thread will use the mirroring policy specified here
// instead of the mirroring policy used to construct the context. This has no
// effect on the mirroring policy used by other program threads.
TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalMirroringPolicy(
    TFE_Context*, TFE_ContextMirroringPolicy);

// Returns the mirroring policy to be used by this context in the current
// thread.
TF_CAPI_EXPORT extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
    TFE_Context*);

// Sets whether to copy the remote inputs of a function lazily.
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy(
    TFE_ContextOptions*, bool lazy_copy);
@ -441,7 +414,7 @@ typedef struct TFE_OpAttrs TFE_OpAttrs;

// Fetch a reference to `op`'s attributes. The returned reference is only valid
// while `op` is alive.
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op);
TF_CAPI_EXPORT extern const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op);

// Add attributes in `attrs` to `op`.
//
// Does not overwrite or update existing attributes, but adds new ones.
@ -462,7 +435,11 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
                                                   size_t proto_len,
                                                   TF_Status* status);

#define TFE_CUSTOM_DEVICE_VERSION 2
// TODO(b/166642410): It would be nice, for custom devices and for other users,
// to have a non-string representation of devices (TF_Device) extracted from
// tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc.

#define TFE_CUSTOM_DEVICE_VERSION 3

// Struct to be filled in
typedef struct TFE_CustomDevice {
@ -481,9 +458,16 @@ typedef struct TFE_CustomDevice {
                              void* device_info);

  // Method to execute an operation.
  void (*execute)(TFE_Context* context, int num_inputs,
                  TFE_TensorHandle** inputs, const char* operation_name,
                  const TFE_OpAttrs* attributes, int* num_outputs,
  //
  // Arguments provide enough information to reconstruct the original `TFE_Op`,
  // or construct a transformed version, by inspecting the passed `op`.
  //
  // TFE_OpGetDevice(op) records the original placement of the operation. It may
  // be an empty string if no device was explicitly requested, but will
  // otherwise be the name of this custom device. Ops are placed onto a custom
  // device if any of their inputs are on that custom device, but custom devices
  // are free to set a bad status in order to require explicit placement.
  void (*execute)(const TFE_Op* op, int* num_outputs,
                  TFE_TensorHandle** outputs, TF_Status* s, void* device_info);

  // Method to delete a device.
@ -569,6 +553,14 @@ TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
                                                     unsigned char enable,
                                                     TF_Status* status);

// Returns the device type of the operation that produced `h`.
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType(
    TFE_TensorHandle* h, TF_Status* status);

// Returns the device ID of the operation that produced `h`.
TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h,
                                                   TF_Status* status);

#ifdef __cplusplus
} /* end extern "C" */
#endif
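A sketch of what a version-3 execute callback looks like against the new struct: the single const TFE_Op* replaces the context/inputs/name/attrs parameters of version 2, and everything is recovered through the op accessors added in this change. All names here are hypothetical:

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

// Hypothetical custom-device execute callback for TFE_CUSTOM_DEVICE_VERSION 3.
void MyDeviceExecute(const TFE_Op* op, int* num_outputs,
                     TFE_TensorHandle** outputs, TF_Status* s,
                     void* device_info) {
  const char* name = TFE_OpGetName(op, s);
  if (TF_GetCode(s) != TF_OK) return;
  int num_inputs = TFE_OpGetFlatInputCount(op, s);
  if (TF_GetCode(s) != TF_OK) return;
  // Rebuild or transform the op with TFE_OpGetContext, TFE_OpGetAttrs and
  // TFE_OpGetFlatInput, then populate `outputs` and set `*num_outputs`.
  (void)name;
  (void)num_inputs;
}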

@ -316,86 +316,6 @@ TEST(CAPI, Function_ident_CPU) {
  TF_DeleteStatus(status);
}

#ifdef TENSORFLOW_EAGER_USE_XLA
TEST(CAPI, Function_ident_XLA_CPU) {
  // First create a simple identity function.
  TF_Graph* function_graph = TF_NewGraph();
  TF_OperationDescription* arg_descr =
      TF_NewOperation(function_graph, "Placeholder", "arg");
  TF_SetAttrType(arg_descr, "dtype", TF_INT32);
  TF_Status* status = TF_NewStatus();
  TF_Operation* arg = TF_FinishOperation(arg_descr, status);
  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  TF_OperationDescription* id_descr =
      TF_NewOperation(function_graph, "Identity", "id");
  TF_SetAttrType(id_descr, "T", TF_INT32);
  TF_AddInput(id_descr, {arg, 0});
  TF_Operation* id = TF_FinishOperation(id_descr, status);
  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  TF_Output input{arg, 0};
  TF_Output output{id, 0};
  TF_Function* fn =
      TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
                         &output, nullptr, nullptr, "test", status);
  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  TF_DeleteGraph(function_graph);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  TFE_DeleteContextOptions(opts);
  TFE_ContextAddFunction(ctx, fn, status);
  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  TF_DeleteFunction(fn);

  for (bool async : {false, true, false}) {
    TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
    TFE_Executor* executor = TFE_NewExecutor(async);
    TFE_ContextSetExecutorForThread(ctx, executor);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    ASSERT_TRUE(TF_GetCode(status) == TF_OK);
    TF_Tensor* t =
        TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
    *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
    TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
    TF_DeleteTensor(t);

    TFE_Op* op = TFE_NewOp(ctx, "ident", status);
    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
    TFE_OpAddInput(op, h, status);
    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);

    // Now run it via XLA.
    TFE_OpSetXLACompilation(op, true);

    std::vector<TFE_TensorHandle*> result;
    result.push_back(nullptr);
    int num_retvals = 1;
    TFE_Execute(op, result.data(), &num_retvals, status);
    TFE_DeleteOp(op);
    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
    ASSERT_EQ(num_retvals, 1);

    TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
    EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
    TFE_ContextSetExecutorForThread(ctx, old_executor);
    TFE_ExecutorWaitForAllPendingNodes(executor, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteExecutor(executor);
    TFE_DeleteExecutor(old_executor);
    TFE_DeleteTensorHandle(h);
    TF_DeleteTensor(r);
    TFE_DeleteTensorHandle(result[0]);
  }
  TFE_ContextRemoveFunction(ctx, "ident", status);
  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  TFE_DeleteContext(ctx);
  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  TF_DeleteStatus(status);
}
#endif  // TENSORFLOW_EAGER_USE_XLA

void Executor_MatMul_CPU(bool async) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -491,5 +411,109 @@ TEST(CAPI, TensorHandleOnDeviceMemory) {
  TF_DeleteStatus(status);
}

TEST(CAPI, TensorHandleNullptr) {
  TFE_TensorHandle* h = nullptr;
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  const char* device_type = TFE_TensorHandleDeviceType(h, status.get());
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
  ASSERT_EQ(device_type, nullptr);
  ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));

  TF_SetStatus(status.get(), TF_OK, "");

  int device_id = TFE_TensorHandleDeviceID(h, status.get());
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
  ASSERT_EQ(device_id, -1);
  ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
}

TEST(CAPI, TensorHandleDevices) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  TFE_DeleteContextOptions(opts);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
  const char* device_type = TFE_TensorHandleDeviceType(hcpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
  int device_id = TFE_TensorHandleDeviceID(hcpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_EQ(0, device_id) << device_id;

  // Disable the test if no GPU is present.
  string gpu_device_name;
  if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
    TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
        hcpu, ctx, gpu_device_name.c_str(), status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    TFE_Op* shape_op = ShapeOp(ctx, hgpu);
    TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    TFE_TensorHandle* retvals[1];
    int num_retvals = 1;
    TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    device_type = TFE_TensorHandleDeviceType(retvals[0], status.get());
    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
    ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type;

    device_id = TFE_TensorHandleDeviceID(retvals[0], status.get());
    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
    ASSERT_EQ(0, device_id) << device_id;

    TFE_DeleteOp(shape_op);
    TFE_DeleteTensorHandle(retvals[0]);
    TFE_DeleteTensorHandle(hgpu);
  }

  TFE_DeleteTensorHandle(hcpu);
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteExecutor(executor);
  TFE_DeleteContext(ctx);
}

TEST(CAPI, TensorHandleDefaults) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  TFE_DeleteContextOptions(opts);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  TFE_TensorHandle* h_default = TestMatrixTensorHandle(ctx);
  const char* device_type = TFE_TensorHandleDeviceType(h_default, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
  int device_id = TFE_TensorHandleDeviceID(h_default, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_EQ(0, device_id) << device_id;

  TFE_TensorHandle* h_cpu = TFE_TensorHandleCopyToDevice(
      h_default, ctx, "/device:CPU:0", status.get());
  const char* device_type_cpu = TFE_TensorHandleDeviceType(h_cpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(absl::StrContains(device_type_cpu, "CPU")) << device_type_cpu;
  int device_id_cpu = TFE_TensorHandleDeviceID(h_cpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_EQ(0, device_id_cpu) << device_id_cpu;

  TFE_DeleteTensorHandle(h_default);
  TFE_DeleteTensorHandle(h_cpu);
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteExecutor(executor);
  TFE_DeleteContext(ctx);
}

}  // namespace
}  // namespace tensorflow

@ -32,7 +32,6 @@ struct TFE_ContextOptions {
  bool async = false;
  TFE_ContextDevicePlacementPolicy device_placement_policy{
      TFE_DEVICE_PLACEMENT_SILENT};
  TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE};
  // If true, lazily copy the remote inputs of a function to the target devices.
  bool lazy_remote_inputs_copy = true;
  // If true, use TFRT backend

@ -20,6 +20,7 @@ limitations under the License.
#include <string>

// clang-format off
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/platform/platform.h"
// clang-format on

@ -876,89 +877,6 @@ TEST(CAPI, Execute_Min_CPU) {
  TF_DeleteStatus(status);
}

#ifdef TENSORFLOW_EAGER_USE_XLA
void Execute_MatMul_XLA_CPU(bool async) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
  TFE_Op* matmul = MatMulOp(ctx, m, m);

  TFE_OpSetXLACompilation(matmul, true);

  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  // Running a primitive TF operator via XLA is not yet supported.
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_DeleteOp(matmul);
  TFE_DeleteTensorHandle(m);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  EXPECT_EQ(1, num_retvals);

  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  TFE_DeleteTensorHandle(retvals[0]);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  float product[4] = {0};
  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  EXPECT_EQ(7, product[0]);
  EXPECT_EQ(10, product[1]);
  EXPECT_EQ(15, product[2]);
  EXPECT_EQ(22, product[3]);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}
TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); }
TEST(CAPI, Execute_MatMul_XLA_CPUAsync) { Execute_MatMul_XLA_CPU(true); }

void Execute_Min_XLA_CPU(bool async) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
  TFE_Op* minOp = MinOp(ctx, input, axis);

  TFE_OpSetXLACompilation(minOp, true);

  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(minOp, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteOp(minOp);
  TFE_DeleteTensorHandle(input);
  TFE_DeleteTensorHandle(axis);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);

  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  TFE_DeleteTensorHandle(retvals[0]);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  float output[2] = {0};
  EXPECT_EQ(sizeof(output), TF_TensorByteSize(t));
  memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  EXPECT_EQ(1, output[0]);
  EXPECT_EQ(3, output[1]);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}
TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); }
TEST(CAPI, Execute_Min_XLA_CPUAsync) { Execute_Min_XLA_CPU(true); }
#endif  // TENSORFLOW_EAGER_USE_XLA

void ExecuteWithTracing(bool async) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -1274,6 +1192,68 @@ TEST(CAPI, StringAttributes) {
  TF_DeleteStatus(status);
}

// Same test as above, except it uses SetOpAttrValueScalar to set attrs.
TEST(CAPI, TestTFE_SetOpAttrs) {
  // Test that TFE_OpSetAttrString doesn't hold on to the value after it
  // returns.
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  std::vector<int64_t> dims(4, 1);
  TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_Tensor* tensor =
      TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float));
  float tensor_data[] = {1};
  memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor));
  TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, tensor_handle, status);
  TF_DeleteTensor(tensor);
  TFE_DeleteTensorHandle(tensor_handle);

  tensorflow::AttrValue i_list_values;
  for (int i = 0; i < 4; ++i) {
    i_list_values.mutable_list()->add_i(1);
  }
  SetOpAttrValueScalar(ctx, op, i_list_values, "ksize", status);
  SetOpAttrValueScalar(ctx, op, i_list_values, "strides", status);

  tensorflow::AttrValue padding_value;
  *padding_value.mutable_s() = "VALID";
  tensorflow::SetOpAttrValueScalar(ctx, op, padding_value, "padding", status);

  tensorflow::AttrValue data_format_value;
  *data_format_value.mutable_s() = "NHWC";
  tensorflow::SetOpAttrValueScalar(ctx, op, data_format_value, "data_format",
                                   status);

  TFE_OpSetAttrType(op, "T", TF_FLOAT);

  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(op, &retvals[0], &num_retvals, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);

  tensor = TFE_TensorHandleResolve(retvals[0], status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ(4, TF_TensorByteSize(tensor));
  TF_DeleteTensor(tensor);
  TFE_DeleteTensorHandle(retvals[0]);

  TFE_DeleteOp(op);

  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}

TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
@ -1620,4 +1600,91 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
  TFE_DeleteContext(ctx);
}

// Needs to work with a const TFE_Op since custom devices should not modify the
// op they are called with.
TFE_Op* CloneOp(const TFE_Op* other) {
  TF_Status* status = TF_NewStatus();
  TFE_Context* context = TFE_OpGetContext(other, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  const char* op_name = TFE_OpGetName(other, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Op* ret = TFE_NewOp(context, op_name, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  const char* device = TFE_OpGetDevice(other, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpSetDevice(ret, device, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddAttrs(ret, TFE_OpGetAttrs(other));
  int num_inputs = TFE_OpGetFlatInputCount(other, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  for (int input_index = 0; input_index < num_inputs; ++input_index) {
    TFE_TensorHandle* input = TFE_OpGetFlatInput(other, input_index, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_OpAddInput(ret, input, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  }
  TF_DeleteStatus(status);
  return ret;
}

TEST(CAPI, TestTFE_OpRecreation) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  // Clone an op with attributes and a device set.
  TFE_Op* original_var_op = TFE_NewOp(ctx, "VarHandleOp", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpSetAttrType(original_var_op, "dtype", TF_INT64);
  TFE_OpSetAttrShape(original_var_op, "shape", {}, 0, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ("", std::string(TFE_OpGetDevice(original_var_op, status)));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpSetDevice(original_var_op,
                  "/job:localhost/replica:0/task:0/device:CPU:0", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Op* cloned = CloneOp(original_var_op);

  EXPECT_EQ("/job:localhost/replica:0/task:0/device:CPU:0",
            std::string(TFE_OpGetDevice(cloned, status)));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ("VarHandleOp", std::string(TFE_OpGetName(cloned, status)));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  int num_retvals = 1;
  TFE_TensorHandle* ret;
  TFE_Execute(cloned, &ret, &num_retvals, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteTensorHandle(ret);

  // Clone an op with inputs and no device set.
  TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
  TFE_Op* original_identity = TFE_NewOp(ctx, "IdentityN", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_TensorHandle* inputs[] = {input1, input2};
  TFE_OpAddInputList(original_identity, inputs, 2, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Op* cloned_identity = CloneOp(original_identity);
  EXPECT_EQ("", std::string(TFE_OpGetDevice(cloned_identity, status)));
  TFE_TensorHandle* identity_ret[] = {nullptr, nullptr};
  num_retvals = 2;
  TFE_Execute(cloned_identity, identity_ret, &num_retvals, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_DeleteTensorHandle(input1);
  TFE_DeleteTensorHandle(input2);
  TFE_DeleteTensorHandle(identity_ret[0]);
  TFE_DeleteTensorHandle(identity_ret[1]);

  TFE_DeleteOp(cloned_identity);
  TFE_DeleteOp(original_identity);
  TFE_DeleteOp(original_var_op);
  TFE_DeleteOp(cloned);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}

}  // namespace
@ -17,12 +17,16 @@ limitations under the License.

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/cluster.pb.h"

using tensorflow::string;
using tensorflow::tstring;

TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) {
  float data[] = {value};
@ -36,6 +40,19 @@ TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) {
  return th;
}

TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx,
                                         const tensorflow::tstring& value) {
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_STRING, nullptr, 0, status);
  tstring* data = static_cast<tstring*>(TF_TensorData(t));
  *data = value;
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}

TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value) {
  int data[] = {value};
  TF_Status* status = TF_NewStatus();
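The new tstring overload writes the value directly into the TF_STRING host tensor's tstring payload rather than serializing it. A small round-trip sketch, assuming a live TFE_Context* ctx from a test fixture (TFE_TensorHandleResolve is the usual way to read the value back):

TFE_TensorHandle* th = TestScalarTensorHandle(ctx, tensorflow::tstring("hello"));
TF_Status* status = TF_NewStatus();
TF_Tensor* t = TFE_TensorHandleResolve(th, status);  // copies back to host
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* data = static_cast<tensorflow::tstring*>(TF_TensorData(t));
CHECK_EQ("hello", std::string(*data));
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(th);
TF_DeleteStatus(status);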
@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

@ -28,6 +29,10 @@ TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value);
// Return a tensor handle containing a bool scalar
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value);

// Return a tensor handle containing a tstring scalar
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx,
                                         const tensorflow::tstring& value);

// Return a tensor handle containing a 2x2 matrix of doubles
TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx);
@ -39,7 +39,7 @@ static FactoriesMap& GetFactories() {
  return *factories;
}

static const char* default_factory = "<unset>";
static tracing::FactoryFunction default_factory;

void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
  assert((!GetFactories().count(name)) ||
@ -48,15 +48,15 @@ void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
  GetFactories()[name] = factory;
}

void SetDefaultTracingEngine(const char* name) { default_factory = name; }

static TracingContext* CreateTracingExecutionContext(const char* fn_name,
                                                     TF_Status* s) {
  auto entry = GetFactories().find(default_factory);
  if (entry != GetFactories().end()) return entry->second(fn_name, s);
Status SetDefaultTracingEngine(const char* name) {
  auto entry = GetFactories().find(name);
  if (entry != GetFactories().end()) {
    default_factory = GetFactories().find(name)->second;
    return Status::OK();
  }
  string msg = absl::StrCat(
      "No tracing engine factory has been registered with the key '",
      default_factory, "' (available: ");
      "No tracing engine factory has been registered with the key '", name,
      "' (available: ");
  // Ensure deterministic (sorted) order in the error message
  std::set<string> factories_sorted;
  for (const auto& factory : GetFactories())
@ -68,7 +68,16 @@ static TracingContext* CreateTracingExecutionContext(const char* fn_name,
  }
  msg += ")";

  TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
  return errors::InvalidArgument(msg.c_str());
}

static TracingContext* CreateTracingExecutionContext(const char* fn_name,
                                                     TF_Status* s) {
  if (default_factory) {
    return default_factory(fn_name, s);
  }
  Set_TF_Status_from_Status(
      s, errors::FailedPrecondition("default_factory is nullptr"));
  return nullptr;
}

@ -99,8 +108,8 @@ using tensorflow::tracing::TracingContext;
using tensorflow::tracing::TracingOperation;
using tensorflow::tracing::TracingTensorHandle;

void TF_SetTracingImplementation(const char* name) {
  SetDefaultTracingEngine(name);
void TF_SetTracingImplementation(const char* name, TF_Status* s) {
  Set_TF_Status_from_Status(s, SetDefaultTracingEngine(name));
}

// Creates a new TensorFlow function, it is an execution context attached to a

@ -52,7 +52,7 @@ typedef struct TF_AbstractFunction TF_AbstractFunction;
// This allows the client to swap the implementation of the tracing engine.
// Any future call to TF_CreateFunction will use the implementation defined
// here.
void TF_SetTracingImplementation(const char* name);
void TF_SetTracingImplementation(const char* name, TF_Status*);

// Creates a new TensorFlow function. A Function is an execution context, and as
// such it can trace operations through TF_ExecuteOperation. After completing

@ -365,9 +365,10 @@ class GraphContext : public TracingContext {
    }

    auto s = TF_NewStatus();
    func->func = TF_GraphToFunction(
        graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
        graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
    func->func = TF_GraphToFunction(graph_.get(), name_.data(), 0, -1, nullptr,
                                    inputs_.size(), inputs_.data(),
                                    graph_outputs.size(), graph_outputs.data(),
                                    nullptr, nullptr, name_.data(), s);
    TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
    TF_DeleteStatus(s);
    *f = func.release();
@ -391,7 +392,7 @@ class GraphContext : public TracingContext {
 private:
  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
  std::vector<TF_Output> inputs_;
  const char* name_;
  string name_;
};

static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
@ -401,7 +402,7 @@ static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
// Register the tracing implemented in this file as the default tracing engine.
static bool register_tracing = [] {
  RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
  SetDefaultTracingEngine("graphdef");
  SetDefaultTracingEngine("graphdef").IgnoreError();
  return true;
}();

@ -120,7 +120,7 @@ class TracingContext : public AbstractContext {
};

typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
void SetDefaultTracingEngine(const char* name);
Status SetDefaultTracingEngine(const char* name);
void RegisterTracingEngineFactory(const ::tensorflow::string& name,
                                  FactoryFunction factory);
}  // namespace tracing
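Since SetDefaultTracingEngine now resolves the factory eagerly and reports a missing key through Status rather than deferring the lookup to function-creation time, C API callers can probe for an engine up front. A minimal sketch of the new surface:

TF_Status* s = TF_NewStatus();
TF_SetTracingImplementation("graphdef", s);
if (TF_GetCode(s) != TF_OK) {
  // No factory registered under this key; the message lists the available
  // factories in sorted order.
  LOG(ERROR) << TF_Message(s);
}
TF_DeleteStatus(s);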
@ -22,10 +22,15 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"

using tensorflow::Status;
using tensorflow::string;
using tensorflow::TF_StatusPtr;

namespace tensorflow {
namespace {
@ -37,7 +42,10 @@ class UnifiedCAPI
    : public ::testing::TestWithParam<std::tuple<const char*, bool>> {
 protected:
  void SetUp() override {
    TF_SetTracingImplementation(std::get<0>(GetParam()));
    TF_StatusPtr status(TF_NewStatus());
    TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
    Status s = StatusFromTF_Status(status.get());
    CHECK_EQ(errors::OK, s.code()) << s.error_message();
  }
};
@ -36,7 +36,8 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
  bool arrived = false;
  bool executed = false;
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
  RegisterLoggingDevice(context, name, /*strict_scope_placement=*/true,
                        &arrived, &executed, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_TensorHandle* hcpu = TestMatrixTensorHandle(context);
  ASSERT_FALSE(arrived);
@ -73,7 +74,8 @@ TEST(CUSTOM_DEVICE, ResetOperation) {
  bool executed = false;
  const char* custom_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed,
  RegisterLoggingDevice(context.get(), custom_device_name,
                        /*strict_scope_placement=*/true, &arrived, &executed,
                        status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

@ -103,7 +105,8 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
  bool arrived = false;
  bool executed = false;
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
  RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
                        &arrived, &executed, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a variable handle placed on the custom device.
@ -187,7 +190,8 @@ TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) {
  bool arrived = false;
  bool executed = false;
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
  RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/false,
                        &arrived, &executed, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a variable handle placed on the custom device.
@ -264,10 +268,12 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) {
  const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
  bool arrived = false;
  bool executed = false;
  RegisterLoggingDevice(context.get(), custom0, &arrived, &executed,
  RegisterLoggingDevice(context.get(), custom0,
                        /*strict_scope_placement=*/false, &arrived, &executed,
                        status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  RegisterLoggingDevice(context.get(), custom1, &arrived, &executed,
  RegisterLoggingDevice(context.get(), custom1,
                        /*strict_scope_placement=*/true, &arrived, &executed,
                        status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

@ -314,14 +320,34 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) {
  ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
  ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));

  // Custom device: mix of custom/physical fails.
  // Custom device: mix of custom/physical places the op on the custom device.
  matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
  num_retvals = 1;
  executed = false;
  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
  ASSERT_NE(TF_OK, TF_GetCode(status.get()));
  ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
  ASSERT_TRUE(
      absl::StrContains(TF_Message(status.get()), "[]"));  // kVariantDeviceNull
  EXPECT_TRUE(executed);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_DeleteTensorHandle(retval);

  // Explicit placement still forces the op onto the requested device
  matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
  TFE_OpSetDevice(matmul.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
                  status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  num_retvals = 1;
  executed = false;
  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
  EXPECT_FALSE(executed);
  ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);

  // Custom devices can refuse to do type-based dispatch (as hcustom1 is
  // configured to do)
  matmul.reset(MatMulOp(context.get(), hcustom1.get(), hcpu.get()));
  num_retvals = 1;
  executed = false;
  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
  EXPECT_FALSE(executed);
  ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
}

TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
@ -334,21 +360,24 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  bool arrived = false;
  bool executed = false;
  RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", &arrived, &executed,
  RegisterLoggingDevice(context.get(), "/device:CUSTOM:0",
                        /*strict_scope_placement=*/true, &arrived, &executed,
                        status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
      << TF_Message(status.get());

  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
  RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
                        &arrived, &executed, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
      << TF_Message(status.get());

  RegisterLoggingDevice(context.get(),
                        "/job:localhost/replica:0/task:0/device:CPU:0",
  RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
                        &arrived, &executed, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
      << TF_Message(status.get());

  RegisterLoggingDevice(
      context.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
      /*strict_scope_placement=*/true, &arrived, &executed, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
      << TF_Message(status.get());
}
@ -33,6 +33,9 @@ struct LoggingDevice {
  bool* arrived_flag;
  // Set to true whenever an operation is executed
  bool* executed_flag;
  // If true, only explicit op placements are accepted. If false, uses
  // type-based dispatch.
  bool strict_scope_placement;
};

struct LoggedTensor {
@ -84,18 +87,35 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
  return nullptr;
}

void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
                          TFE_TensorHandle** inputs, const char* operation_name,
                          const TFE_OpAttrs* attributes, int* num_outputs,
void LoggingDeviceExecute(const TFE_Op* original_op, int* num_outputs,
                          TFE_TensorHandle** outputs, TF_Status* s,
                          void* device_info) {
  const char* requested_placement = TFE_OpGetDevice(original_op, s);
  if (TF_GetCode(s) != TF_OK) return;

  LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
  if (dev->strict_scope_placement && *requested_placement == '\0') {
    TF_SetStatus(s, TF_INTERNAL,
                 "Ops must be placed on the device explicitly, or their inputs "
                 "first copied to other devices.");
    return;
  }
  TFE_Context* context = TFE_OpGetContext(original_op, s);
  if (TF_GetCode(s) != TF_OK) return;
  const char* operation_name = TFE_OpGetName(original_op, s);
  if (TF_GetCode(s) != TF_OK) return;
  const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);

  TFE_Op* op(TFE_NewOp(context, operation_name, s));
  if (TF_GetCode(s) != TF_OK) return;
  TFE_OpAddAttrs(op, attributes);
  TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
  if (TF_GetCode(s) != TF_OK) return;
  int num_inputs = TFE_OpGetFlatInputCount(original_op, s);
  if (TF_GetCode(s) != TF_OK) return;
  for (int j = 0; j < num_inputs; ++j) {
    TFE_TensorHandle* input = inputs[j];
    TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, j, s);
    if (TF_GetCode(s) != TF_OK) return;
    const char* input_device = TFE_TensorHandleDeviceName(input, s);
    if (TF_GetCode(s) != TF_OK) return;
    if (dev->device_name == input_device) {
@ -131,8 +151,8 @@ void DeleteLoggingDevice(void* device_info) {
}  // namespace

void RegisterLoggingDevice(TFE_Context* context, const char* name,
                           bool* arrived_flag, bool* executed_flag,
                           TF_Status* status) {
                           bool strict_scope_placement, bool* arrived_flag,
                           bool* executed_flag, TF_Status* status) {
  TFE_CustomDevice custom_device;
  custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
  custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
@ -143,6 +163,7 @@ void RegisterLoggingDevice(TFE_Context* context, const char* name,
  device->executed_flag = executed_flag;
  device->device_name = name;
  device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
  device->strict_scope_placement = strict_scope_placement;
  TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}

@ -168,5 +189,6 @@ void AllocateLoggingDevice(const char* name, bool* arrived_flag,
  logging_device->device_name = name;
  logging_device->underlying_device =
      "/job:localhost/replica:0/task:0/device:CPU:0";
  logging_device->strict_scope_placement = true;
  *device_info = reinterpret_cast<void*>(logging_device);
}

@ -25,8 +25,8 @@ limitations under the License.
#include "tensorflow/c/tf_status.h"

void RegisterLoggingDevice(TFE_Context* context, const char* name,
                           bool* arrived_flag, bool* executed_flag,
                           TF_Status* status);
                           bool strict_scope_placement, bool* arrived_flag,
                           bool* executed_flag, TF_Status* status);
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
                           bool* executed_flag, TFE_CustomDevice** device,
                           void** device_info);
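With the added strict_scope_placement flag, a test chooses per device between explicit-placement-only semantics and type-based dispatch. A sketch, assuming an existing TFE_Context* context and the testutil helpers above:

bool arrived = false;
bool executed = false;
TF_Status* status = TF_NewStatus();
// Strict: executing an op on this device without TFE_OpSetDevice fails
// with TF_INTERNAL (see LoggingDeviceExecute above).
RegisterLoggingDevice(context, "/job:localhost/replica:0/task:0/device:CUSTOM:0",
                      /*strict_scope_placement=*/true, &arrived, &executed,
                      status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Lenient: ops whose inputs live on the device are dispatched to it.
RegisterLoggingDevice(context, "/job:localhost/replica:0/task:0/device:CUSTOM:1",
                      /*strict_scope_placement=*/false, &arrived, &executed,
                      status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);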
@ -109,7 +109,8 @@ DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
// Gets DLPack's DLContext from eager tensor handle.
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
  DLContext ctx;
  const char* device_name = tensorflow::unwrap(h)->DeviceName(&status->status);
  const char* device_name =
      tensorflow::unwrap(h)->BackingDeviceName(&status->status);
  DeviceNameUtils::ParsedName parsed_name;
  tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
  std::string device_type = parsed_name.type;
@ -248,21 +249,36 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
}

void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
  auto tf_dlm_context = GetDlContext(h, status);
  if (!status->status.ok()) {
    return nullptr;
  }

  auto* tf_dlm_data = TFE_TensorHandleDevicePointer(h, status);
  if (!status->status.ok()) {
    return nullptr;
  }

  const Tensor* tensor = GetTensorFromHandle(h, status);
  TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
  TensorReference tensor_ref(*tensor);  // This will call buf_->Ref()

  auto tf_dlm_type = GetDlDataType(data_type, status);
  if (!status->status.ok()) {
    return nullptr;
  }

  TensorReference tensor_ref(*tensor);  // This will call buf_->Ref()
  auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
  tf_dlm_tensor_ctx->reference = tensor_ref;

  DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
  dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
  dlm_tensor->deleter = &DLManagedTensorDeleter;
  dlm_tensor->dl_tensor.ctx = GetDlContext(h, status);
  dlm_tensor->dl_tensor.ctx = tf_dlm_context;
  int ndim = tensor->dims();
  dlm_tensor->dl_tensor.ndim = ndim;
  dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
  dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status);
  dlm_tensor->dl_tensor.data = tf_dlm_data;
  dlm_tensor->dl_tensor.dtype = tf_dlm_type;

  std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
  std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
@ -275,13 +291,14 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
    (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
  }

  dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
  dlm_tensor->dl_tensor.shape = shape_arr->data();
  // There are two ways to represent compact row-major data
  // 1) nullptr indicates tensor is compact and row-majored.
  // 2) fill in the strides array as the real case for compact row-major data.
  // Here we choose option 2, since some frameworks didn't handle the strides
  // argument properly.
  dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
  dlm_tensor->dl_tensor.strides = stride_arr->data();

  dlm_tensor->dl_tensor.byte_offset =
      0;  // TF doesn't handle the strides and byte_offsets here
  return static_cast<void*>(dlm_tensor);
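A usage sketch for the reordered conversion path: all fallible lookups (context, device pointer, dtype) now run before the TensorReference is taken, so an early failure no longer leaks a reference. Assuming a valid TFE_TensorHandle* h:

TF_Status* status = TF_NewStatus();
void* dlm = TFE_HandleToDLPack(h, status);  // DLManagedTensor* on success
if (TF_GetCode(status) == TF_OK) {
  // ... hand `dlm` to a DLPack consumer; whoever ends up owning it calls:
  TFE_CallDLManagedTensorDeleter(dlm);  // drops the TensorReference
}
TF_DeleteStatus(status);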
201
tensorflow/c/eager/gradient_checker.cc
Normal file
@ -0,0 +1,201 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradient_checker.h"

#include <memory>

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace gradients {

using namespace std;

// ================== Helper functions =================

// Fills data with values [start,end) with given step size.
void Range(vector<int>* data, int start, int end, int step = 1) {
  for (int i = start; i < end; i += step) {
    (*data)[i] = i;
  }
}

// Returns AbstractTensorHandlePtr containing [0, ..., n-1].
AbstractTensorHandlePtr GetRangeTensorHandleUtil(AbstractContext* ctx, int n) {
  vector<int> vals(n);
  int64_t vals_shape[] = {n};
  Range(&vals, 0, n);
  AbstractTensorHandlePtr r =
      GetTensorHandleUtilInt(ctx, vals.data(), vals_shape, 1);
  return r;
}

// Fills out_dims with the dimensions of the given tensor.
void GetDims(const TF_Tensor* t, int64_t* out_dims) {
  int num_dims = TF_NumDims(t);
  for (int i = 0; i < num_dims; i++) {
    out_dims[i] = TF_Dim(t, i);
  }
}

// Runs model as is if output is a scalar,
// else sums the output tensor before returning.
Status RunAndMaybeSum(AbstractContext* ctx, Model forward,
                      absl::Span<AbstractTensorHandle*> inputs,
                      absl::Span<AbstractTensorHandle*> outputs,
                      bool use_function) {
  GradientRegistry registry;
  std::vector<AbstractTensorHandle*> model_outputs(1);

  // Run the model.
  TF_RETURN_IF_ERROR(RunModel(forward, ctx, inputs,
                              absl::MakeSpan(model_outputs), use_function,
                              registry));
  AbstractTensorHandle* model_out = model_outputs[0];

  TF_Tensor* model_out_tensor;
  TF_RETURN_IF_ERROR(GetValue(model_out, &model_out_tensor));
  int num_dims_out = TF_NumDims(model_out_tensor);

  // If the output is a scalar, then return the scalar output
  if (num_dims_out == 0) {
    outputs[0] = model_out;
    return Status::OK();
  }

  // Else, reduce sum the output to get a scalar

  // Will sum all dimensions, so get a Tensor containing [0,...,num_dims_out-1].
  AbstractTensorHandlePtr sum_dims =
      GetRangeTensorHandleUtil(ctx, num_dims_out);

  // Reduce sum the output on all dimensions.
  std::vector<AbstractTensorHandle*> sum_inputs(2);
  sum_inputs[0] = model_out;
  sum_inputs[1] = sum_dims.get();

  TF_RETURN_IF_ERROR(ops::Sum(ctx, absl::MakeSpan(sum_inputs),
                              absl::MakeSpan(model_outputs), "sum_output"));
  outputs[0] = model_outputs[0];
  return Status::OK();
}
// ========================= End Helper Functions==============================

Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
                         absl::Span<AbstractTensorHandle*> inputs,
                         int input_index, bool use_function,
                         AbstractTensorHandle** numerical_grad) {
  AbstractTensorHandle* theta =
      inputs[input_index];  // parameter we are grad checking

  // Convert from AbstractTensor to TF_Tensor.
  TF_Tensor* theta_tensor;
  TF_RETURN_IF_ERROR(GetValue(theta, &theta_tensor));

  // Get number of elements and fill data.
  int num_elems = TF_TensorElementCount(theta_tensor);
  vector<float> theta_data(num_elems);
  memcpy(theta_data.data(), TF_TensorData(theta_tensor),
         TF_TensorByteSize(theta_tensor));

  // Initialize space for the numerical gradient.
  vector<float> dtheta_approx(num_elems);

  // Get theta shape and store in theta_dims.
  int num_dims = TF_NumDims(theta_tensor);
  vector<int64_t> theta_dims(num_dims);
  GetDims(theta_tensor, theta_dims.data());

  // Initialize auxiliary data structures.
  vector<float> thetaPlus_data(num_elems);
  vector<float> thetaMinus_data(num_elems);
  std::vector<AbstractTensorHandle*> f_outputs(1);

  // Numerical Grad Check
  for (int i = 0; i < num_elems; i++) {
    // Get relative epsilon value
    float epsilon =
        std::abs(theta_data[i] * 1e-4 + 1e-4);  // add 1e-4 to prevent div by 0
    AbstractTensorHandlePtr two_eps =
        GetScalarTensorHandleUtil(ctx, 2 * epsilon);

    // Initialize theta[i] + epsilon.
    memcpy(thetaPlus_data.data(), TF_TensorData(theta_tensor),
           TF_TensorByteSize(theta_tensor));
    thetaPlus_data[i] += epsilon;
    AbstractTensorHandlePtr thetaPlus = GetTensorHandleUtilFloat(
        ctx, thetaPlus_data.data(), theta_dims.data(), num_dims);

    // Initialize theta[i] - epsilon.
    memcpy(&thetaMinus_data[0], TF_TensorData(theta_tensor),
           TF_TensorByteSize(theta_tensor));
    thetaMinus_data[i] -= epsilon;
    AbstractTensorHandlePtr thetaMinus = GetTensorHandleUtilFloat(
        ctx, thetaMinus_data.data(), theta_dims.data(), num_dims);

    // Get f(theta + eps):
    inputs[input_index] = thetaPlus.get();
    TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, inputs,
                                      absl::MakeSpan(f_outputs), use_function));
    AbstractTensorHandle* fPlus = f_outputs[0];

    // Get f(theta - eps):
    inputs[input_index] = thetaMinus.get();
    TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, inputs,
                                      absl::MakeSpan(f_outputs), use_function));
    AbstractTensorHandle* fMinus = f_outputs[0];

    // Take Difference of both estimates: (f(theta + eps) - f(theta - eps)).
    TF_RETURN_IF_ERROR(
        ops::Sub(ctx, {fPlus, fMinus}, absl::MakeSpan(f_outputs), "sub_top"));
    AbstractTensorHandle* fDiff = f_outputs[0];

    // Calculate using the difference quotient definition:
    // (f(theta + eps) - f(theta - eps)) / (2 * eps).
    TF_RETURN_IF_ERROR(ops::DivNoNan(ctx, {fDiff, two_eps.get()},
                                     absl::MakeSpan(f_outputs),
                                     "diff_quotient"));
    AbstractTensorHandle* diff_quotient = f_outputs[0];

    TF_Tensor* grad_tensor;
    TF_RETURN_IF_ERROR(GetValue(diff_quotient, &grad_tensor));
    float grad_data[1];
    memcpy(&grad_data[0], TF_TensorData(grad_tensor),
           TF_TensorByteSize(grad_tensor));

    dtheta_approx[i] = grad_data[0];
  }

  // Populate *numerical_grad with the data from dtheta_approx.
  TF_RETURN_IF_ERROR(TensorHandleWithDimsFloat(
      ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad));
  return Status::OK();
}

}  // namespace gradients
}  // namespace tensorflow
53
tensorflow/c/eager/gradient_checker.h
Normal file
@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace gradients {

/* Returns numerical grad inside `dtheta_approx` given `forward` model and
 * parameter specified by `input_index`.
 *
 * I.e. if y = <output of the forward model> and w = inputs[input_index],
 * this will calculate dy/dw numerically.
 *
 * `use_function` indicates whether to use graph mode (true) or eager (false).
 *
 * `numerical_grad` is the pointer to the AbstractTensorHandle* which will
 * hold the numerical gradient data at the end of the function.
 */
Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
                         absl::Span<AbstractTensorHandle*> inputs,
                         int input_index, bool use_function,
                         AbstractTensorHandle** numerical_grad);

}  // namespace gradients
}  // namespace tensorflow
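For reference, CalcNumericalGrad evaluates the symmetric difference quotient elementwise, df/dtheta_i ~= (f(theta + eps*e_i) - f(theta - eps*e_i)) / (2*eps), with eps scaled relative to theta_i (plus a 1e-4 floor), summing non-scalar outputs first. A usage sketch mirroring the tests below, assuming a forward Model such as MatMulModel and a populated inputs vector:

AbstractTensorHandle* numerical_grad = nullptr;
Status s = CalcNumericalGrad(ctx, MatMulModel, absl::MakeSpan(inputs),
                             /*input_index=*/0, /*use_function=*/false,
                             &numerical_grad);
if (!s.ok()) LOG(FATAL) << s.error_message();
// numerical_grad now holds d(sum(MatMul(A, B)))/dA, elementwise.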
265
tensorflow/c/eager/gradient_checker_test.cc
Normal file
@ -0,0 +1,265 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradient_checker.h"

#include <memory>

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace gradients {
namespace internal {
namespace {

class GradientCheckerTest
    : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
 protected:
  void SetUp() override {
    TF_StatusPtr status(TF_NewStatus());
    TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
    Status s = StatusFromTF_Status(status.get());
    CHECK_EQ(errors::OK, s.code()) << s.error_message();
  }
};

Status RegisterGradients(GradientRegistry* registry) {
  TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
  TF_RETURN_IF_ERROR(
      registry->Register("SparseSoftmaxCrossEntropyWithLogits",
                         SparseSoftmaxCrossEntropyWithLogitsRegisterer));
  return Status::OK();
}

TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t A_dims[] = {2, 2};
  float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f};
  int64_t B_dims[] = {2, 2};
  int num_dims = 2;

  AbstractTensorHandlePtr A =
      GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
  AbstractTensorHandlePtr B =
      GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);

  std::vector<AbstractTensorHandle*> inputs;
  inputs.push_back(A.get());
  inputs.push_back(B.get());

  AbstractTensorHandle* grad_approx;
  Status s = CalcNumericalGrad(
      ctx.get(), MatMulModel, absl::MakeSpan(inputs), /*input_index=*/0,
      /*use_function=*/!std::get<2>(GetParam()), &grad_approx);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  TF_Tensor* gt;
  s = GetValue(grad_approx, &gt);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt));

  float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f};
  float tolerance = 1e-2;
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(expected_dA[j], result_data[j], tolerance);
  }
  TF_DeleteTensor(gt);
}

TEST_P(GradientCheckerTest, TestGradCheckMul) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
    Status s = ScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }

  AbstractTensorHandlePtr y;
  {
    AbstractTensorHandle* y_raw = nullptr;
    Status s = ScalarTensorHandle(ctx.get(), 7.0f, &y_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    y.reset(y_raw);
  }

  // Will perform z = x*y.
  // dz/dx = y

  std::vector<AbstractTensorHandle*> inputs;
  inputs.push_back(x.get());
  inputs.push_back(y.get());
  AbstractTensorHandle* g;

  Status s = CalcNumericalGrad(ctx.get(), MulModel, absl::MakeSpan(inputs),
                               /*input_index=*/0,
                               /*use_function=*/!std::get<2>(GetParam()), &g);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  TF_Tensor* gt;
  s = GetValue(g, &gt);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  float result_data[1] = {0};
  memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt));

  ASSERT_NEAR(result_data[0], 7.0f, /*abs_error=*/1e-2);
  TF_DeleteTensor(gt);
}

TEST_P(GradientCheckerTest, TestGradCheckSoftmax) {
  bool use_function = !std::get<2>(GetParam());
  if (use_function) {
    // TODO(b/168850692): Enable this.
    GTEST_SKIP() << "Can't take gradient of "
                    "SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
  }
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  /** Test to show how to use this API with analytical gradients:
   *
   * We have `SoftmaxLossGradModel`, which is a wrapper for the
   * Softmax analytical gradient found in c/experimental/nn_grads.
   *
   * We will use the GradientChecker by applying finite differences
   * to the forward pass wrapped in `SoftmaxModel` and verify that
   * both the analytical and numerical gradients are relatively
   * close.
   *
   */

  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = scores
  float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, 1.0f};
  int64_t X_dims[] = {3, 3};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);

  // y = labels
  int y_vals[] = {1, 0, 1};
  int64_t y_dims[] = {3};
  num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
  AbstractTensorHandlePtr y =
      GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);

  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  std::vector<AbstractTensorHandle*> inputs;
  inputs.push_back(X.get());
  inputs.push_back(y.get());

  // Run analytical gradient and get its data.
  std::vector<AbstractTensorHandle*> outputs(2);
  s = RunModel(SoftmaxLossGradModel, ctx.get(), absl::MakeSpan(inputs),
               absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  TF_Tensor* dX_tensor;
  s = GetValue(outputs[0], &dX_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float danalytical[9] = {0};  // Contains data from analytical gradient.
  memcpy(&danalytical[0], TF_TensorData(dX_tensor),
         TF_TensorByteSize(dX_tensor));

  // Run numerical gradient approximation using the GradientChecker API.
  AbstractTensorHandle* g;  // Will contain numerical approximation data.
  s = CalcNumericalGrad(ctx.get(), SoftmaxModel, absl::MakeSpan(inputs),
                        /*input_index=*/0,
                        /*use_function=*/!std::get<2>(GetParam()), &g);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  TF_Tensor* gt;
  s = GetValue(g, &gt);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  float dnumerical[9] = {0};
  memcpy(&dnumerical[0], TF_TensorData(gt), TF_TensorByteSize(gt));

  // Now compare the two implementations:
  for (int j = 0; j < 9; j++) {
    ASSERT_NEAR(dnumerical[j], danalytical[j], /*abs_error=*/1e-2);
  }

  // Only Unref() first output as 2nd is nullptr grad for labels
  outputs[0]->Unref();
  TF_DeleteTensor(dX_tensor);
  TF_DeleteTensor(gt);
}

#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, GradientCheckerTest,
    ::testing::Combine(::testing::Values("graphdef"),
                       /*tfrt*/ ::testing::Values(false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, GradientCheckerTest,
    ::testing::Combine(::testing::Values("graphdef"),
                       /*tfrt*/ ::testing::Values(false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#endif
}  // namespace
}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow
@ -122,14 +122,12 @@ int64 ToId(AbstractTensorHandle* t) {
  return static_cast<int64>(reinterpret_cast<uintptr_t>(t));
}

TapeTensor::TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx)
    : handle_(handle), ctx_(ctx) {
TapeTensor::TapeTensor(AbstractTensorHandle* handle) : handle_(handle) {
  handle_->Ref();
}
TapeTensor::TapeTensor(const TapeTensor& other) {
  handle_ = other.handle_;
  handle_->Ref();
  ctx_ = other.ctx_;
}
TapeTensor::~TapeTensor() { handle_->Unref(); }

@ -138,33 +136,7 @@ tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); }
tensorflow::DataType TapeTensor::GetDType() const {
  return handle_->DataType();
}

AbstractTensorHandle* TapeTensor::OnesLike() const {
  AbstractOperationPtr op(ctx_->CreateOperation());
  Status s = op->Reset("OnesLike", /*raw_device_name=*/nullptr);
  if (!s.ok()) {
    return nullptr;
  }
  if (isa<tracing::TracingOperation>(op.get())) {
    s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
        absl::StrCat("OnesLike", ToId(handle_)).c_str());
    if (!s.ok()) {
      return nullptr;
    }
  }
  s = op->AddInput(handle_);
  if (!s.ok()) {
    return nullptr;
  }
  int num_outputs = 1;
  // TODO(srbs): Figure out who is in charge of releasing this.
  std::vector<AbstractTensorHandle*> outputs(num_outputs);
  s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
  if (!s.ok()) {
    return nullptr;
  }
  return outputs[0];
}
AbstractTensorHandle* TapeTensor::GetHandle() const { return handle_; }

AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; }

@ -219,6 +191,23 @@ Status TapeVSpace::CallBackwardFunction(
      &ctx, incoming_gradients, result);
}

Status TapeVSpace::BuildOnesLike(const TapeTensor& t,
                                 AbstractTensorHandle** result) const {
  AbstractOperationPtr op(ctx_->CreateOperation());
  TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));
  if (isa<tracing::TracingOperation>(op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
        absl::StrCat("OnesLike", ToId(t.GetHandle())).c_str()));
  }
  TF_RETURN_IF_ERROR(op->AddInput(t.GetHandle()));
  int num_outputs = 1;
  std::vector<AbstractTensorHandle*> outputs(num_outputs);
  TF_RETURN_IF_ERROR(
      op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
  *result = outputs[0];
  return Status::OK();
}

// Looks up the ID of a Gradient.
int64 TapeVSpace::TensorId(AbstractTensorHandle* tensor) const {
  return ToId(tensor);
@ -226,7 +215,7 @@ int64 TapeVSpace::TensorId(AbstractTensorHandle* tensor) const {

// Converts a Gradient to a TapeTensor.
TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const {
  return TapeTensor(g, ctx_);
  return TapeTensor(g);
}

void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {}
@ -242,6 +231,7 @@ namespace internal {
Status Reset(AbstractOperation* op_, const char* op,
             const char* raw_device_name, ForwardOperation* forward_op_) {
  forward_op_->op_name = op;
  forward_op_->attrs.Reset(op);
  return op_->Reset(op, raw_device_name);
}
Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input,
@ -418,9 +408,14 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx,
    // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
    forward_op_->outputs.push_back(retvals[i]);
  }
  // TODO(b/166669239): This is needed to support AttrBuilder::Get for string
  // attributes. Number type attrs and DataType attrs work fine without this.
  // Consider getting rid of this and making the behavior between number types
  // and string consistent.
  forward_op_->attrs.BuildNodeDef();
  std::vector<TapeTensor> tape_tensors;
  for (auto t : retvals) {
    tape_tensors.push_back(TapeTensor(t, ctx));
    tape_tensors.push_back(TapeTensor(t));
  }
  tape->RecordOperation(
      op_->Name(), tape_tensors, input_ids, input_dtypes,

@ -80,7 +80,6 @@ struct ForwardOperation {
  std::vector<AbstractTensorHandle*> inputs;
  std::vector<AbstractTensorHandle*> outputs;
  AttrBuilder attrs;
  AbstractContext* ctx;
};

// Interface for building default zeros gradients for op outputs which are
@ -181,10 +180,6 @@ int64 ToId(AbstractTensorHandle* t);
// allow us to trace the data dependencies between operations and hence compute
// gradients.
//
// This also implements `OnesLike` to create the default
// incoming gradients for tensors which do not already have an incoming
// gradient.
//
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
// of default zeros grads is handled by the `DefaultGradientFunction` registered
// for each op.
@ -193,20 +188,19 @@ int64 ToId(AbstractTensorHandle* t);
// TODO(srbs): Should ZerosLike check-fail instead of returning nullptr?
class TapeTensor {
 public:
  TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx);
  explicit TapeTensor(AbstractTensorHandle* handle);
  TapeTensor(const TapeTensor& other);
  ~TapeTensor();

  tensorflow::int64 GetID() const;
  tensorflow::DataType GetDType() const;

  AbstractTensorHandle* OnesLike() const;
  AbstractTensorHandle* ZerosLike() const;

  AbstractTensorHandle* GetHandle() const;

 private:
  AbstractTensorHandle* handle_;
  // The context where OnesLike ops are to be created.
  AbstractContext* ctx_;
};

// Vector space for actually computing gradients. Implements methods for calling
@ -234,6 +228,10 @@ class TapeVSpace
      gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
      std::vector<AbstractTensorHandle*>* result) const override;

  // Builds a tensor filled with ones with the same shape and dtype as `t`.
  Status BuildOnesLike(const TapeTensor& t,
                       AbstractTensorHandle** result) const override;

  // Looks up the ID of a Gradient.
  int64 TensorId(AbstractTensorHandle* tensor) const override;
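Because OnesLike construction moved off TapeTensor (which no longer carries a context) onto the vspace, default incoming gradients are now built where the context lives. A sketch of the resulting relationship, using only names from this diff:

TapeVSpace vspace(ctx);               // ctx: AbstractContext*
TapeTensor t(output_handle);          // explicit ctor; takes a ref on the handle
AbstractTensorHandle* ones = nullptr;
Status s = vspace.BuildOnesLike(t, &ones);  // runs OnesLike on vspace's context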
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
@ -26,7 +27,9 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/gradients_internal.h"
|
||||
#include "tensorflow/c/experimental/gradients/array_grad.h"
|
||||
#include "tensorflow/c/experimental/gradients/math_grad.h"
|
||||
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
|
||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||
#include "tensorflow/c/experimental/ops/math_ops.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
@ -38,84 +41,32 @@ namespace gradients {
|
||||
namespace internal {
|
||||
namespace {
|
||||
using std::vector;
|
||||
using tensorflow::TF_StatusPtr;
|
||||
using tracing::TracingOperation;
|
||||
|
||||
class CppGradients
|
||||
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
TF_SetTracingImplementation(std::get<0>(GetParam()));
|
||||
TF_StatusPtr status(TF_NewStatus());
|
||||
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
|
||||
Status s = StatusFromTF_Status(status.get());
|
||||
CHECK_EQ(errors::OK, s.code()) << s.error_message();
|
||||
}
|
||||
};
|
||||
|
||||
Status RegisterGradients(GradientRegistry* registry) {
|
||||
TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
|
||||
// TODO(srbs): Rename ops::Add to ops::AddV2 and AddRegister to
|
||||
// AddV2Registerer.
|
||||
TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Computes `inputs[0] + inputs[1]` and records it on the tape.
|
||||
Status Add(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr add_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(add_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
|
||||
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
|
||||
int num_retvals = 1;
|
||||
return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
// Computes `exp(inputs[0])` and records it on the tape.
|
||||
Status Exp(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr exp_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(exp_op.get(), "Exp", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(exp_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<TracingOperation>(exp_op.get())->SetOpName("my_exp"));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(AddInput(exp_op.get(), inputs[0], &forward_op));
|
||||
int num_retvals = 1;
|
||||
return Execute(exp_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
// Computes `IdentityN(inputs)` and records it on the tape.
|
||||
Status IdentityN(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr identity_n_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(Reset(identity_n_op.get(), "IdentityN",
|
||||
/*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(identity_n_op.get())) {
|
||||
TF_RETURN_IF_ERROR(dyn_cast<TracingOperation>(identity_n_op.get())
|
||||
->SetOpName("my_identity_n"));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(AddInputList(identity_n_op.get(), inputs, &forward_op));
|
||||
int num_retvals = outputs.size();
|
||||
return Execute(identity_n_op.get(), ctx, outputs, &num_retvals, &forward_op,
|
||||
tape, registry);
|
||||
}

// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
@ -128,8 +79,10 @@ Status AddGradModel(AbstractContext* ctx,
  tape->Watch(ToId(inputs[0]));  // Watch x.
  tape->Watch(ToId(inputs[1]));  // Watch y.
  std::vector<AbstractTensorHandle*> add_outputs(1);
  TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
                         registry));  // Compute x+y.
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs,
                              absl::MakeSpan(add_outputs),
                              "Add"));  // Compute x+y.
  std::unordered_map<tensorflow::int64, TapeTensor>
      source_tensors_that_are_targets;

@ -160,8 +113,9 @@ Status ExpGradModel(AbstractContext* ctx,
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(ToId(inputs[0]));  // Watch x.
  std::vector<AbstractTensorHandle*> exp_outputs(1);
  TF_RETURN_IF_ERROR(Exp(ctx, tape, inputs, absl::MakeSpan(exp_outputs),
                         registry));  // Compute exp(x).
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(
      ops::Exp(tape_ctx.get(), inputs, absl::MakeSpan(exp_outputs), "Exp"));
  std::unordered_map<tensorflow::int64, TapeTensor>
      source_tensors_that_are_targets;

@ -179,6 +133,37 @@ Status ExpGradModel(AbstractContext* ctx,
  return Status::OK();
}

// Computes
// y = sqrt(inputs[0])
// return grad(y, {inputs[0]})
Status SqrtGradModel(AbstractContext* ctx,
                     absl::Span<AbstractTensorHandle* const> inputs,
                     absl::Span<AbstractTensorHandle*> outputs,
                     const GradientRegistry& registry) {
  TapeVSpace vspace(ctx);
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(ToId(inputs[0]));  // Watch x.
  std::vector<AbstractTensorHandle*> sqrt_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(
      ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt"));
  std::unordered_map<tensorflow::int64, TapeTensor>
      source_tensors_that_are_targets;

  std::vector<AbstractTensorHandle*> out_grads;
  TF_RETURN_IF_ERROR(tape->ComputeGradient(
      vspace, /*target_tensor_ids=*/{ToId(sqrt_outputs[0])},
      /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
      /*output_gradients=*/{}, &out_grads,
      /*build_default_zeros_grads=*/false));
  for (auto sqrt_output : sqrt_outputs) {
    sqrt_output->Unref();
  }
  outputs[0] = out_grads[0];
  delete tape;
  return Status::OK();
}
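
// Worked value for the expectation in TestSqrtGrad further below: for
// y = sqrt(x) the tape replays dy/dx = 1 / (2 * sqrt(x)), which evaluates to
// 0.5 at x = 1.0.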

// Computes
// ignored, y = IdentityN(inputs[0], inputs[1])
// return grad(y, {inputs[0], inputs[1]})
@ -193,8 +178,9 @@ Status IdentityNGradModel(AbstractContext* ctx,
  tape->Watch(ToId(inputs[1]));

  vector<AbstractTensorHandle*> identity_n_outputs(2);
  TF_RETURN_IF_ERROR(IdentityN(ctx, tape, inputs,
                               absl::MakeSpan(identity_n_outputs), registry));
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(ops::IdentityN(
      tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN"));

  std::unordered_map<tensorflow::int64, TapeTensor>
      source_tensors_that_are_targets;
@ -214,6 +200,73 @@ Status IdentityNGradModel(AbstractContext* ctx,
  return Status::OK();
}

// Computes
// y = - inputs[0]
// return grad(y, {inputs[0]})
Status NegGradModel(AbstractContext* ctx,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    absl::Span<AbstractTensorHandle*> outputs,
                    const GradientRegistry& registry) {
  TapeVSpace vspace(ctx);
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(ToId(inputs[0]));

  std::vector<AbstractTensorHandle*> neg_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(
      ops::Neg(tape_ctx.get(), inputs, absl::MakeSpan(neg_outputs), "Neg"));

  std::unordered_map<tensorflow::int64, TapeTensor>
      source_tensors_that_are_targets;
  std::vector<AbstractTensorHandle*> out_grads;
  TF_RETURN_IF_ERROR(tape->ComputeGradient(
      vspace, /*target_tensor_ids=*/{ToId(neg_outputs[0])},
      /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
      /*output_gradients=*/{}, &out_grads,
      /*build_default_zeros_grads=*/false));
  for (auto neg_output : neg_outputs) {
    neg_output->Unref();
  }
  outputs[0] = out_grads[0];
  delete tape;
  return Status::OK();
}

// Computes
// y = inputs[0] - inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status SubGradModel(AbstractContext* ctx,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    absl::Span<AbstractTensorHandle*> outputs,
                    const GradientRegistry& registry) {
  TapeVSpace vspace(ctx);
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(ToId(inputs[0]));  // Watch x.
  tape->Watch(ToId(inputs[1]));  // Watch y.
  std::vector<AbstractTensorHandle*> sub_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs,
                              absl::MakeSpan(sub_outputs),
                              "Sub"));  // Compute x-y.
  std::unordered_map<tensorflow::int64, TapeTensor>
      source_tensors_that_are_targets;

  std::vector<AbstractTensorHandle*> out_grads;
  TF_RETURN_IF_ERROR(tape->ComputeGradient(
      vspace, /*target_tensor_ids=*/{ToId(sub_outputs[0])},
      /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
      source_tensors_that_are_targets,
      /*output_gradients=*/{}, &out_grads,
      /*build_default_zeros_grads=*/false));
  for (auto sub_output : sub_outputs) {
    sub_output->Unref();
  }
  outputs[0] = out_grads[0];
  outputs[1] = out_grads[1];
  delete tape;
  return Status::OK();
}
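
// Worked values for the expectations in TestSubGrad further below: for
// y = x - y the recorded gradients are d(x-y)/dx = +1 and d(x-y)/dy = -1,
// which is what the test asserts on out_grads[0] and out_grads[1].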

AbstractContext* BuildFunction(const char* fn_name) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
@ -448,6 +501,50 @@ TEST_P(CppGradients, TestExpGrad) {
  result_tensor = nullptr;
}

TEST_P(CppGradients, TestSqrtGrad) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
    Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }

  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Pseudo-code:
  //
  // tape.watch(x)
  // y = sqrt(x)
  // outputs = tape.gradient(y, x)
  std::vector<AbstractTensorHandle*> outputs(1);
  s = RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  TF_Tensor* result_tensor;
  s = getValue(outputs[0], &result_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
  EXPECT_NEAR(*result_value, 0.5, 0.001);
  outputs[0]->Unref();
  TF_DeleteTensor(result_tensor);
  result_tensor = nullptr;
}

TEST_P(CppGradients, TestIdentityNGrad) {
  // Pseudo-code:
  //
@ -507,6 +604,161 @@ TEST_P(CppGradients, TestIdentityNGrad) {
  result_tensor = nullptr;
}

TEST_P(CppGradients, TestNegGrad) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
    Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }

  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Pseudo-code:
  //
  // tape.watch(x)
  // y = - x
  // outputs = tape.gradient(y, x)
  std::vector<AbstractTensorHandle*> outputs(1);
  s = RunModel(NegGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  TF_Tensor* result_tensor;
  s = getValue(outputs[0], &result_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
  EXPECT_EQ(*result_value, -1.0);
  outputs[0]->Unref();
  TF_DeleteTensor(result_tensor);
  result_tensor = nullptr;
}

TEST_P(CppGradients, TestSubGrad) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
    Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }

  AbstractTensorHandlePtr y;
  {
    AbstractTensorHandle* y_raw = nullptr;
    Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    y.reset(y_raw);
  }

  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Pseudo-code:
  //
  // tape.watch(x)
  // tape.watch(y)
  // y = x - y
  // outputs = tape.gradient(y, [x, y])
  std::vector<AbstractTensorHandle*> outputs(2);
  s = RunModel(SubGradModel, ctx.get(), {x.get(), y.get()},
               absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  TF_Tensor* result_tensor;
  s = getValue(outputs[0], &result_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
  EXPECT_EQ(*result_value, 1.0);
  outputs[0]->Unref();
  TF_DeleteTensor(result_tensor);
  result_tensor = nullptr;

  s = getValue(outputs[1], &result_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  result_value = static_cast<float*>(TF_TensorData(result_tensor));
  EXPECT_EQ(*result_value, -1.0);
  outputs[1]->Unref();
  TF_DeleteTensor(result_tensor);
}

TEST_P(CppGradients, TestSetAttrString) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  AbstractTensorHandlePtr t;
  {
    AbstractTensorHandle* x_raw = nullptr;
    Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    t.reset(x_raw);
  }

  AbstractOperationPtr check_numerics_op(ctx->CreateOperation());
  ForwardOperation forward_op;
  Status s = Reset(check_numerics_op.get(), "CheckNumerics",
                   /*raw_device_name=*/nullptr, &forward_op);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  if (isa<TracingOperation>(check_numerics_op.get())) {
    s = dyn_cast<TracingOperation>(check_numerics_op.get())
            ->SetOpName("check_numerics");
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  }
  s = AddInput(check_numerics_op.get(), t.get(), &forward_op);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  string message = "This is the way!";
  s = SetAttrString(check_numerics_op.get(), "message", message.data(),
                    message.length(), &forward_op);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  int num_retvals = 1;
  std::vector<AbstractTensorHandle*> outputs(1);
  GradientRegistry registry;
  std::unique_ptr<Tape> tape(new Tape(/*persistent=*/false));
  s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs),
              &num_retvals, &forward_op, tape.get(), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  string read_message;
  s = forward_op.attrs.Get("message", &read_message);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  ASSERT_EQ(read_message, message);
}

// TODO(b/164171226): Enable this test with tfrt after AddInputList is
// supported. It is needed for IdentityN.
#ifdef PLATFORM_GOOGLE

317  tensorflow/c/eager/gradients_util.cc  Normal file
@ -0,0 +1,317 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradients_util.h"

#include <memory>

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace gradients {

using namespace std;

Status ScalarTensorHandleHelper(TFE_Context* ctx, float value,
                                TFE_TensorHandle** result) {
  float data[] = {value};
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_Tensor* t =
      TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status.get());
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
  *result = th;
  TF_DeleteTensor(t);
  return StatusFromTF_Status(status.get());
}

Status TensorHandleWithDimsFloatHelper(TFE_Context* ctx, float data[],
                                       int64_t dims[], int num_dims,
                                       TFE_TensorHandle** result) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_Tensor* t =
      TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status.get());
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
  *result = th;
  TF_DeleteTensor(t);
  return StatusFromTF_Status(status.get());
}

Status TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[],
                                     int64_t dims[], int num_dims,
                                     TFE_TensorHandle** result) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_Tensor* t =
      TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status.get());
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
  *result = th;
  TF_DeleteTensor(t);
  return StatusFromTF_Status(status.get());
}

// Get a scalar TensorHandle with given value
Status ScalarTensorHandle(AbstractContext* ctx, float value,
                          AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager;
  TF_RETURN_IF_ERROR(ScalarTensorHandleHelper(eager_ctx, value, &input_eager));
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return StatusFromTF_Status(status.get());
}

// Get a TensorHandle with given float values and dimensions
Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
                                 int64_t dims[], int num_dims,
                                 AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager;
  TF_RETURN_IF_ERROR(TensorHandleWithDimsFloatHelper(eager_ctx, data, dims,
                                                     num_dims, &input_eager));
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return StatusFromTF_Status(status.get());
}

// Get a TensorHandle with given int values and dimensions
Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[],
                               int num_dims, AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager;
  TF_RETURN_IF_ERROR(TensorHandleWithDimsIntHelper(eager_ctx, data, dims,
                                                   num_dims, &input_eager));
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return StatusFromTF_Status(status.get());
}
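
// A minimal usage sketch for the helpers above, assuming a live
// AbstractContext* `ctx`; the local arrays `vals` and `dims` are
// hypothetical.
//
//   float vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
//   int64_t dims[] = {2, 2};  // A 2x2 matrix.
//   AbstractTensorHandle* a = nullptr;
//   Status s = TensorHandleWithDimsFloat(ctx, vals, dims, /*num_dims=*/2, &a);
//   if (s.ok()) {
//     // ... use `a`, then a->Unref() when done.
//   }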

Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_TensorHandle* result_t =
      TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  *result_tensor = TFE_TensorHandleResolve(result_t, status.get());
  return StatusFromTF_Status(status.get());
}

AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
                                                 float vals[], int64_t dims[],
                                                 int num_dims) {
  AbstractTensorHandlePtr A;
  AbstractTensorHandle* a_raw = nullptr;
  Status s = TensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
  if (s.ok()) {
    A.reset(a_raw);
  }
  return A;
}

AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
                                               int64_t dims[], int num_dims) {
  AbstractTensorHandlePtr A;
  AbstractTensorHandle* a_raw = nullptr;
  Status s = TensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
  if (s.ok()) {
    A.reset(a_raw);
  }
  return A;
}

AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
                                                  float val) {
  AbstractTensorHandlePtr y;
  AbstractTensorHandle* y_raw = nullptr;
  Status s = ScalarTensorHandle(ctx, val, &y_raw);
  if (s.ok()) {
    y.reset(y_raw);
  }
  return y;
}

Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads,
                     vector<AbstractTensorHandle*>& weights,
                     AbstractTensorHandle* learning_rate) {
  /* Update weights one by one using gradient update rule:
   *
   *    w -= lr * grad[w]
   *
   * NOTE: assuming learning rate is positive
   */

  int num_grads = grads.size();
  vector<AbstractTensorHandle*> temp_outputs(1);
  std::string update_str;

  // Negate learning rate for gradient descent
  TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate},
                              absl::MakeSpan(temp_outputs),
                              "neg_lr"));  // Compute -lr
  learning_rate = temp_outputs[0];

  for (int i = 0; i < num_grads; i++) {
    // Compute dW = -lr * grad(w[i])
    update_str = "update_mul_" + std::to_string(i);
    TF_RETURN_IF_ERROR(ops::Mul(ctx, {learning_rate, grads[i]},
                                absl::MakeSpan(temp_outputs),
                                update_str.c_str()));

    AbstractTensorHandle* dW = temp_outputs[0];

    // Compute temp = weights[i] + dW
    update_str = "update_add_" + std::to_string(i);
    TF_RETURN_IF_ERROR(ops::Add(ctx, {weights[i], dW},
                                absl::MakeSpan(temp_outputs),
                                update_str.c_str()));

    // Update the weights
    weights[i] = temp_outputs[0];
  }

  return Status::OK();
}
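
// A minimal sketch of one gradient-descent step with UpdateWeights, assuming
// a live AbstractContext* `ctx`, weight handles `W1`/`W2`, and gradient
// handles `dW1`/`dW2` already exist; all of those names are hypothetical.
//
//   vector<AbstractTensorHandle*> grads = {dW1, dW2};
//   vector<AbstractTensorHandle*> weights = {W1, W2};
//   AbstractTensorHandle* lr = nullptr;
//   TF_RETURN_IF_ERROR(ScalarTensorHandle(ctx, 0.01f, &lr));
//   TF_RETURN_IF_ERROR(UpdateWeights(ctx, grads, weights, lr));
//   // weights[i] now refers to w_i - 0.01 * dW_i.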

AbstractContext* BuildFunction(const char* fn_name) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
  return unwrap(graph_ctx);
}

Status CreateParamsForInputs(AbstractContext* ctx,
                             absl::Span<AbstractTensorHandle* const> inputs,
                             vector<AbstractTensorHandle*>* params) {
  tracing::TracingTensorHandle* handle = nullptr;
  for (auto input : inputs) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
        input->DataType(), &handle));
    params->emplace_back(handle);
  }
  return Status::OK();
}

Status RunModel(Model model, AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, bool use_function,
                const GradientRegistry& registry) {
  if (use_function) {
    const char* fn_name = "test_fn";
    std::unique_ptr<AbstractFunction> scoped_func;
    // Returning null tensors from a tf.function is not supported, so we keep
    // track of the indices in the model's outputs that are nullptr in this
    // set. The FunctionDef only outputs the non-null tensors. We later pad the
    // function op outputs to have nullptrs at the `null_indices`.
    absl::flat_hash_set<int> null_indices;
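    // For example: if the model produces {t0, nullptr, t2}, the traced
    // FunctionDef returns only {t0, t2}, `null_indices` becomes {1}, and the
    // padding loop after Execute() re-expands the results to
    // {t0, nullptr, t2}.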
    {
      AbstractContextPtr func_ctx(BuildFunction(fn_name));
      vector<AbstractTensorHandle*> func_inputs;
      func_inputs.reserve(inputs.size());
      TF_RETURN_IF_ERROR(
          CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
      vector<AbstractTensorHandle*> model_outputs;
      model_outputs.resize(outputs.size());
      TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
                               absl::MakeSpan(model_outputs), registry));
      for (auto func_input : func_inputs) {
        func_input->Unref();
      }
      AbstractFunction* func = nullptr;
      OutputList output_list;
      output_list.expected_num_outputs = 0;
      output_list.outputs.reserve(outputs.size());
      for (int i = 0; i < model_outputs.size(); i++) {
        if (model_outputs[i]) {
          output_list.outputs.emplace_back(model_outputs[i]);
          output_list.expected_num_outputs += 1;
        } else {
          null_indices.insert(i);
        }
      }
      TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
                             ->Finalize(&output_list, &func));
      scoped_func.reset(func);
      for (auto output : output_list.outputs) {
        output->Unref();
      }
      TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
    }

    AbstractOperationPtr fn_op(ctx->CreateOperation());
    TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
    for (auto input : inputs) {
      TF_RETURN_IF_ERROR(fn_op->AddInput(input));
    }
    int retvals = outputs.size() - null_indices.size();
    vector<AbstractTensorHandle*> fn_outputs(retvals);
    TF_RETURN_IF_ERROR(fn_op->Execute(
        absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
        &retvals));
    int skipped_indices = 0;
    for (int i = 0; i < outputs.size(); i++) {
      if (!null_indices.contains(i)) {
        outputs[i] = fn_outputs[i - skipped_indices];
      } else {
        skipped_indices += 1;
      }
    }
    TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
    return Status::OK();
  } else {
    return model(ctx, inputs, outputs, registry);
  }
}

Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetTfrt(opts, use_tfrt);
  *ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_DeleteContextOptions(opts);
  return Status::OK();
}
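
// A minimal sketch of the intended call pattern for the two helpers above,
// assuming `MyModel` is a function with the Model signature and `x` is an
// existing tensor handle; both names are hypothetical.
//
//   AbstractContext* raw_ctx = nullptr;
//   TF_RETURN_IF_ERROR(
//       BuildImmediateExecutionContext(/*use_tfrt=*/false, &raw_ctx));
//   AbstractContextPtr ctx(raw_ctx);
//   GradientRegistry registry;
//   std::vector<AbstractTensorHandle*> outputs(1);
//   TF_RETURN_IF_ERROR(RunModel(MyModel, ctx.get(), {x},
//                               absl::MakeSpan(outputs),
//                               /*use_function=*/true, registry));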

}  // namespace gradients
}  // namespace tensorflow
88  tensorflow/c/eager/gradients_util.h  Normal file
@ -0,0 +1,88 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>

#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace gradients {

// Get a scalar TensorHandle with given value
Status ScalarTensorHandle(AbstractContext* ctx, float value,
                          AbstractTensorHandle** tensor);

// Get a TensorHandle with given float values and dimensions
Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
                                 int64_t dims[], int num_dims,
                                 AbstractTensorHandle** tensor);

// Get a TensorHandle with given int values and dimensions
Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[],
                               int num_dims, AbstractTensorHandle** tensor);

// Places data from `t` into *result_tensor.
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor);

// Util function that wraps an AbstractTensorHandle* with given data and dims.
AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
                                                 float vals[], int64_t dims[],
                                                 int num_dims);

// Util function that wraps an AbstractTensorHandle* with given data and dims.
AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
                                               int64_t dims[], int num_dims);

// Util function that wraps an AbstractTensorHandle* with given data.
AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
                                                  float val);

// Performs gradient update for each weight using given learning rate.
Status UpdateWeights(AbstractContext* ctx,
                     std::vector<AbstractTensorHandle*>& grads,
                     std::vector<AbstractTensorHandle*>& weights,
                     AbstractTensorHandle* learning_rate);

using Model = std::function<Status(
    AbstractContext*, absl::Span<AbstractTensorHandle* const>,
    absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;

// Runs given model in either graph or eager mode depending on value of
// use_function.
Status RunModel(Model model, AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, bool use_function,
                const GradientRegistry& registry);

// Builds context and returns inside *ctx.
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);

}  // namespace gradients
}  // namespace tensorflow
@ -29,8 +29,25 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
class EagerExecutor;

// LINT.IfChange
// Note: Keep in sync with exported copy of enum in eager/c_api.h.
enum ContextDevicePlacementPolicy {
  // Running operations with input tensors on the wrong device will fail.
  DEVICE_PLACEMENT_EXPLICIT = 0,
  // Copy the tensor to the right device but log a warning.
  DEVICE_PLACEMENT_WARN = 1,
  // Silently copy the tensor, which has a performance cost since the operation
  // will be blocked till the copy completes. This is the default policy.
  DEVICE_PLACEMENT_SILENT = 2,
  // Placement policy which silently copies int32 tensors but not other dtypes.
  DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
};
// LINT.ThenChange(//tensorflow/c/eager/c_api.h)
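
// A minimal sketch of selecting a policy through the
// ImmediateExecutionContext interface declared below, assuming `ctx` points
// at a live context:
//
//   ctx->SetThreadLocalDevicePlacementPolicy(DEVICE_PLACEMENT_WARN);
//   // Ops issued on this thread now copy mis-placed inputs and log a warning.
//   ContextDevicePlacementPolicy policy = ctx->GetDevicePlacementPolicy();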

// Abstract interface to a context.
//
@ -81,14 +98,6 @@ class ImmediateExecutionContext : public AbstractContext {
  // List attributes of available devices
  virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;

  virtual void ClearCachesAndThreadExecutors() = 0;

  // Initialize the step resource container for a training step. This is used
  // in current TF runtime. For tfrt, it is used by fallback op handler.
  virtual void StartStep() = 0;
  // Destroy the step resource container for a training step.
  virtual void EndStep() = 0;

  // Block until all pending nodes are finished.
  virtual Status AsyncWait() = 0;

@ -97,11 +106,52 @@ class ImmediateExecutionContext : public AbstractContext {
  // already exists.
  virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;

  // Find and return an added function by its name.
  virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;

  // Return the ParsedName of Host CPU device.
  virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;

  // Configure soft device placement policy.
  virtual void SetAllowSoftPlacement(bool enable) = 0;

  // Configure device placement policy logging.
  virtual void SetLogDevicePlacement(bool enable) = 0;

  // Sets the device placement policy for the current thread.
  virtual void SetThreadLocalDevicePlacementPolicy(
      ContextDevicePlacementPolicy policy) = 0;
  // Returns the device placement policy for the current thread.
  virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;

  // For LLVM style RTTI.
  static bool classof(const AbstractContext* ptr) {
    return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
  }

  //===--------------------------------------------------------------------===//
  // Following are legacy features in TF Eager Runtime.
  // TODO(tf-runtime): Figure out a way to deprecate following features after
  // migrated to TFRT.
  //===--------------------------------------------------------------------===//
  // Clear pending nodes in thread executors and kernel caches.
  virtual void ClearCachesAndThreadExecutors() = 0;

  // Initialize the step resource container for a training step. This is used
  // in current TF runtime. For tfrt, it is used by fallback op handler.
  virtual void StartStep() = 0;
  // Destroy the step resource container for a training step.
  virtual void EndStep() = 0;

  // Return the Eager Executor for current thread. Please note that Eager
  // Executor is only used in current TF but not in TFRT.
  virtual EagerExecutor& Executor() = 0;
  // Update the Eager Executor for current thread.
  virtual void SetExecutorForThread(EagerExecutor* executor) = 0;

  // Configure graph collection in RunMetadata.
  virtual void SetShouldStoreGraphs(bool value) = 0;

 protected:
  explicit ImmediateExecutionContext(AbstractContextKind kind)
      : AbstractContext(kind) {}

@ -47,9 +47,6 @@ class ImmediateExecutionOperation : public AbstractOperation {
  virtual Status InputLength(const char* input_name, int* length) = 0;
  virtual Status OutputLength(const char* output_name, int* length) = 0;

  // Experimental
  virtual Status SetUseXla(bool enable) = 0;

  // Set stack trace to be used for potential async error reporting.
  virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0;

@ -44,6 +44,10 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
  virtual const char* DeviceName(Status* status) const = 0;
  // Returns the device where the tensor was placed.
  virtual const char* BackingDeviceName(Status* status) const = 0;
  // Returns the device type which created the handle.
  virtual const char* DeviceType(Status* status) const = 0;
  // Returns the device ID which created the handle.
  virtual int DeviceId(Status* status) const = 0;
  // Returns a tensor for the handle. If tensor is remote, it will be copied.
  virtual AbstractTensorInterface* Resolve(Status* status) = 0;

@ -14,11 +14,11 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
@ -33,12 +33,16 @@ namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
using tensorflow::TF_StatusPtr;

class CppGradients
    : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
 protected:
  void SetUp() override {
    TF_SetTracingImplementation(std::get<0>(GetParam()));
    TF_StatusPtr status(TF_NewStatus());
    TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
    Status s = StatusFromTF_Status(status.get());
    CHECK_EQ(errors::OK, s.code()) << s.error_message();
  }
};

@ -49,89 +53,10 @@ Status RegisterGradients(GradientRegistry* registry) {
  TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
  TF_RETURN_IF_ERROR(
      registry->Register("SparseSoftmaxCrossEntropyWithLogits",
                         SparseSoftmaxCrossEntropyLossRegisterer));
                         SparseSoftmaxCrossEntropyWithLogitsRegisterer));
  return Status::OK();
}

// ========================= Test Util Functions ==============================

// Get a scalar TensorHandle with given value
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
                              AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return Status::OK();
}

// Get a Matrix TensorHandle with given float values and dimensions
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
                                     int64_t dims[], int num_dims,
                                     AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager =
      TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return Status::OK();
}

// Get a Matrix TensorHandle with given int values and dimensions
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
                                   int64_t dims[], int num_dims,
                                   AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager =
      TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return Status::OK();
}

Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_TensorHandle* result_t =
      TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  *result_tensor = TFE_TensorHandleResolve(result_t, status.get());
  return Status::OK();
}

AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
                                                 float vals[], int64_t dims[],
                                                 int num_dims) {
  AbstractTensorHandlePtr A;
  AbstractTensorHandle* a_raw = nullptr;
  Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
  A.reset(a_raw);
  return A;
}

AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
                                               int64_t dims[], int num_dims) {
  AbstractTensorHandlePtr A;
  AbstractTensorHandle* a_raw = nullptr;
  Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
  A.reset(a_raw);
  return A;
}

// =========================== Start Tests ================================

TEST_P(CppGradients, TestMatMulGrad) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
@ -465,6 +390,12 @@ TEST_P(CppGradients, TestReluGrad) {
}

TEST_P(CppGradients, TestSoftmaxLossGrad) {
  bool use_function = !std::get<2>(GetParam());
  if (use_function) {
    // TODO(b/168850692): Enable this.
    GTEST_SKIP() << "Can't take gradient of "
                    "SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
  }
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

@ -533,6 +464,12 @@ TEST_P(CppGradients, TestSoftmaxLossGrad) {
}

TEST_P(CppGradients, TestMNISTGrad) {
  bool use_function = !std::get<2>(GetParam());
  if (use_function) {
    // TODO(b/168850692): Enable this.
    GTEST_SKIP() << "Can't take gradient of "
                    "SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
  }
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
@ -603,7 +540,6 @@ TEST_P(CppGradients, TestMNISTGrad) {
                 TF_TensorByteSize(dW1_tensor));

  float expected_dW1[4] = {0.0f, 3.2f, 0.0f, 4.8f};
  ;  // dLoss
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_dW1[j], tolerance);
  }
@ -643,7 +579,7 @@ TEST_P(CppGradients, TestScalarMul) {
  AbstractTensorHandlePtr eta;
  {
    AbstractTensorHandle* x_raw = nullptr;
    Status s = TestScalarTensorHandle(ctx.get(), 1.5f, &x_raw);
    Status s = ScalarTensorHandle(ctx.get(), 1.5f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    eta.reset(x_raw);
  }
@ -681,6 +617,12 @@ TEST_P(CppGradients, TestScalarMul) {
}

TEST_P(CppGradients, TestMNIST_Training) {
  bool use_function = !std::get<2>(GetParam());
  if (use_function) {
    // TODO(b/168850692): Enable this.
    GTEST_SKIP() << "Can't take gradient of "
                    "SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
  }
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

@ -733,7 +675,7 @@ TEST_P(CppGradients, TestMNIST_Training) {

  // Set learning rate to be 1e-1
  AbstractTensorHandle* learning_rate = nullptr;
  s = TestScalarTensorHandle(ctx.get(), 1e-1, &learning_rate);
  s = ScalarTensorHandle(ctx.get(), 1e-1, &learning_rate);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Train
@ -765,13 +707,13 @@ TEST_P(CppGradients, TestMNIST_Training) {
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, CppGradients,
    ::testing::Combine(::testing::Values("graphdef"),
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       /*tfrt*/ ::testing::Values(false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, CppGradients,
    ::testing::Combine(::testing::Values("graphdef"),
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       /*tfrt*/ ::testing::Values(false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#endif
@ -24,136 +24,19 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
#include "tensorflow/c/eager/gradients_internal.h"
|
||||
#include "tensorflow/c/eager/gradients_util.h"
|
||||
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
|
||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||
#include "tensorflow/c/experimental/ops/math_ops.h"
|
||||
#include "tensorflow/c/experimental/ops/nn_ops.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
namespace internal {
|
||||
|
||||
using std::vector;
|
||||
using tracing::TracingOperation;
|
||||
|
||||
// ========================== Tape Ops ==============================
|
||||
|
||||
// Computes `inputs[0] + inputs[1]` and records it on the tape.
|
||||
Status Add(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr add_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(add_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
|
||||
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
|
||||
int num_retvals = 1;
|
||||
return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
|
||||
Status MatMul(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
bool transpose_a, bool transpose_b,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr matmul_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul",
|
||||
/*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(matmul_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<TracingOperation>(matmul_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op));
|
||||
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op));
|
||||
TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
|
||||
matmul_op.get(), "transpose_a", transpose_a, &forward_op));
|
||||
TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
|
||||
matmul_op.get(), "transpose_b", transpose_b, &forward_op));
|
||||
|
||||
int num_retvals = 1;
|
||||
return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
Status Mul(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr mul_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(mul_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op));
|
||||
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op));
|
||||
|
||||
int num_retvals = 1;
|
||||
return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
// Computes `Relu(inputs[0])` and records it on the tape.
|
||||
Status Relu(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr relu_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(relu_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<TracingOperation>(relu_op.get())->SetOpName(name));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op));
|
||||
int num_retvals = 1;
|
||||
return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
|
||||
// tape.
|
||||
Status SparseSoftmaxCrossEntropyLoss(
|
||||
AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractTensorHandle* scores = inputs[0];
|
||||
AbstractTensorHandle* labels = inputs[1];
|
||||
|
||||
AbstractOperationPtr sm_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits",
|
||||
/*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(sm_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<TracingOperation>(sm_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op));
|
||||
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op));
|
||||
|
||||
int num_retvals = 2; // returns loss values and backprop
|
||||
return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
//===================== Test Models to run =========================
|
||||
|
||||
@ -169,8 +52,9 @@ Status AddGradModel(AbstractContext* ctx,
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> add_outputs(1);
|
||||
TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
|
||||
registry)); // Compute x+y.
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
@ -202,9 +86,11 @@ Status MatMulGradModel(AbstractContext* ctx,
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
vector<AbstractTensorHandle*> mm_outputs(1);
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs),
|
||||
"matmul0", /*transpose_a=*/false,
|
||||
/*transpose_b=*/false, registry)); // Compute x*y.
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(mm_outputs), "matmul0",
|
||||
/*transpose_a=*/false,
|
||||
/*transpose_b=*/false)); // Compute x*y.
|
||||
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
@ -238,8 +124,9 @@ Status MNISTForwardModel(AbstractContext* ctx,
|
||||
* hidden_layer = tf.nn.relu(mm_out_1)
|
||||
* scores = tf.matmul(hidden_layer,W2)
|
||||
* softmax =
|
||||
* tf.nn.sparse_softmax_cross_entropy_with_logits(scores,y_labels) return
|
||||
* scores, softmax
|
||||
* tf.nn.sparse_softmax_cross_entropy_with_logits(scores,
|
||||
* y_labels)
|
||||
* return scores, softmax
|
||||
*
|
||||
* Use this convention for inputs:
|
||||
*
|
||||
@ -257,24 +144,27 @@ Status MNISTForwardModel(AbstractContext* ctx,
|
||||
tape->Watch(ToId(W2)); // Watch W2.
|
||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
|
||||
"matmul0", /*transpose_a=*/false,
|
||||
/*transpose_b=*/false, registry)); // Compute X*W1
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
|
||||
absl::MakeSpan(temp_outputs), "matmul0",
|
||||
/*transpose_a=*/false,
|
||||
/*transpose_b=*/false)); // Compute X*W1
|
||||
|
||||
TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]},
|
||||
absl::MakeSpan(temp_outputs), "relu",
|
||||
registry)); // Compute Relu(X*W1)
|
||||
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {temp_outputs[0]},
|
||||
absl::MakeSpan(temp_outputs),
|
||||
"relu")); // Compute Relu(X*W1)
|
||||
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2},
|
||||
absl::MakeSpan(temp_outputs), "matmul1",
|
||||
/*transpose_a=*/false, /*transpose_b=*/false,
|
||||
registry)); // Compute W2*Relu(X*W1)
|
||||
TF_RETURN_IF_ERROR(ops::MatMul(
|
||||
tape_ctx.get(), {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs),
|
||||
"matmul1",
|
||||
/*transpose_a=*/false, /*transpose_b=*/false)); // Compute W2*Relu(X*W1)
|
||||
|
||||
AbstractTensorHandle* scores = temp_outputs[0];
|
||||
|
||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
|
||||
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
|
||||
"softmax_loss", registry)); // Compute Softmax(Scores,labels)
|
||||
temp_outputs.resize(2);
|
||||
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
|
||||
tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
|
||||
"softmax_loss")); // Compute Softmax(Scores,labels)
|
||||
|
||||
AbstractTensorHandle* loss_vals = temp_outputs[0];
|
||||
|
||||
@ -297,9 +187,11 @@ Status MatMulTransposeModel(AbstractContext* ctx,
|
||||
tape->Watch(ToId(W1));
|
||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
|
||||
"matmul0", /*transpose_a=*/true,
|
||||
/*transpose_b=*/false, registry)); // Compute X*W1
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
|
||||
absl::MakeSpan(temp_outputs), "matmul0",
|
||||
/*transpose_a=*/true,
|
||||
/*transpose_b=*/false)); // Compute X*W1
|
||||
|
||||
outputs[0] = temp_outputs[0];
|
||||
|
||||
@ -315,8 +207,10 @@ Status ReluGradModel(AbstractContext* ctx,
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch X
|
||||
vector<AbstractTensorHandle*> relu_outputs(1);
|
||||
TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs),
|
||||
"relu0", registry)); // Relu(X)
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(relu_outputs),
|
||||
"relu0")); // Relu(X)
|
||||
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
@ -346,8 +240,9 @@ Status SoftmaxLossGradModel(AbstractContext* ctx,
  tape->Watch(ToId(inputs[0]));  // Watch scores.
  tape->Watch(ToId(inputs[1]));  // Watch labels.
  vector<AbstractTensorHandle*> sm_outputs(2);
  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
      ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry));
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
      tape_ctx.get(), inputs, absl::MakeSpan(sm_outputs), "softmax0"));

  std::unordered_map<tensorflow::int64, TapeTensor>
      source_tensors_that_are_targets;
@ -381,29 +276,30 @@ Status MNISTGradModel(AbstractContext* ctx,
  tape->Watch(ToId(W1));  // Watch W1.
  tape->Watch(ToId(W2));  // Watch W2.
  vector<AbstractTensorHandle*> temp_outputs(1);
  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
                            "matmul0", /*transpose_a=*/false,
                            /*transpose_b=*/false, registry));  // Compute X*W1
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
                                 absl::MakeSpan(temp_outputs), "matmul0",
                                 /*transpose_a=*/false,
                                 /*transpose_b=*/false));  // Compute X*W1

  AbstractTensorHandle* mm = temp_outputs[0];

  TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm},
                          absl::MakeSpan(temp_outputs),  // Relu(X*W1)
                          "relu0", registry));
  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {mm},
                               absl::MakeSpan(temp_outputs),  // Relu(X*W1)
                               "relu0"));

  AbstractTensorHandle* hidden = temp_outputs[0];

  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2},
                            absl::MakeSpan(temp_outputs), "matmul1",
                            /*transpose_a=*/false, /*transpose_b=*/false,
                            registry));  // W2*Relu(X*W1)
  TF_RETURN_IF_ERROR(ops::MatMul(
      tape_ctx.get(), {hidden, W2}, absl::MakeSpan(temp_outputs), "matmul1",
      /*transpose_a=*/false, /*transpose_b=*/false));  // W2*Relu(X*W1)

  AbstractTensorHandle* scores = temp_outputs[0];

  temp_outputs.resize(2);
  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
      ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
      "softmaxloss", registry));  // Compute Softmax(scores, labels)
  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
      tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
      "softmaxloss"));  // Compute Softmax(scores, labels)

  AbstractTensorHandle* loss = temp_outputs[0];
@ -440,8 +336,10 @@ Status ScalarMulModel(AbstractContext* ctx,
  auto tape = new Tape(/*persistent=*/false);
  vector<AbstractTensorHandle*> temp_outputs(1);

  TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs),
                         "scalarMul0", registry));  // Compute eta*A
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {eta, A},
                              absl::MakeSpan(temp_outputs),
                              "scalarMul0"));  // Compute eta*A

  outputs[0] = temp_outputs[0];
@ -449,146 +347,69 @@ Status ScalarMulModel(AbstractContext* ctx,
  return Status::OK();
}

Status MatMulModel(AbstractContext* ctx,
                   absl::Span<AbstractTensorHandle* const> inputs,
                   absl::Span<AbstractTensorHandle*> outputs,
                   const GradientRegistry& registry) {
  AbstractTensorHandle* X = inputs[0];
  AbstractTensorHandle* W1 = inputs[1];

  TapeVSpace vspace(ctx);
  auto tape = new Tape(/*persistent=*/false);
  std::vector<AbstractTensorHandle*> temp_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
                                 absl::MakeSpan(temp_outputs), "matmul0",
                                 /*transpose_a=*/false,
                                 /*transpose_b=*/false));  // Compute X*W1

  outputs[0] = temp_outputs[0];
  delete tape;
  return Status::OK();
}

Status MulModel(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs,
                const GradientRegistry& registry) {
  AbstractTensorHandle* x = inputs[0];
  AbstractTensorHandle* y = inputs[1];

  TapeVSpace vspace(ctx);
  auto tape = new Tape(/*persistent=*/false);
  std::vector<AbstractTensorHandle*> temp_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {x, y},
                              absl::MakeSpan(temp_outputs),
                              "mul0"));  // Compute x*y

  outputs[0] = temp_outputs[0];
  delete tape;
  return Status::OK();
}

Status SoftmaxModel(AbstractContext* ctx,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    absl::Span<AbstractTensorHandle*> outputs,
                    const GradientRegistry& registry) {
  AbstractTensorHandle* x = inputs[0];
  AbstractTensorHandle* labels = inputs[1];

  TapeVSpace vspace(ctx);
  auto tape = new Tape(/*persistent=*/false);
  std::vector<AbstractTensorHandle*> temp_outputs(2);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
      tape_ctx.get(), {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss"));

  outputs[0] = temp_outputs[0];  // loss values

  delete tape;
  return Status::OK();
}

// ============================= End Models ================================

Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads,
                     vector<AbstractTensorHandle*>& weights,
                     AbstractTensorHandle* learning_rate) {
  /* Update the weights one by one using the gradient update rule:
   *
   *    w -= lr * grad[w]
   *
   * NOTE: assumes the learning rate is positive.
   */

  Status s;
  int num_grads = grads.size();
  vector<AbstractTensorHandle*> temp_outputs(1);
  std::string update_str;

  // Negate the learning rate for gradient descent.
  TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate},
                              absl::MakeSpan(temp_outputs),
                              "neg_lr"));  // Compute -lr
  learning_rate = temp_outputs[0];

  for (int i = 0; i < num_grads; i++) {
    // Compute dW = -lr * grad(w[i])
    update_str = "update_mul_" + std::to_string(i);
    s = ops::Mul(ctx, {learning_rate, grads[i]}, absl::MakeSpan(temp_outputs),
                 update_str.c_str());

    AbstractTensorHandle* dW = temp_outputs[0];

    // Compute temp = weights[i] + dW
    update_str = "update_add_" + std::to_string(i);
    s = ops::Add(ctx, {weights[i], dW}, absl::MakeSpan(temp_outputs),
                 update_str.c_str());

    // Update the weights.
    weights[i] = temp_outputs[0];
  }

  return Status::OK();
}
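For context, UpdateWeights is meant to be called once per training step, after gradients have been computed; a minimal sketch (the gradient and weight handles here are placeholders for values produced by a prior tape computation):

  // One SGD step, w <- w - lr * grad[w], using the helper above.
  vector<AbstractTensorHandle*> grads = {dW1, dW2};      // from a GradientTape
  vector<AbstractTensorHandle*> weights = {W1, W2};      // current parameters
  TF_RETURN_IF_ERROR(UpdateWeights(ctx, grads, weights, learning_rate));
  // `weights` now holds handles to the updated parameters.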
AbstractContext* BuildFunction(const char* fn_name) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
  return unwrap(graph_ctx);
}

Status CreateParamsForInputs(AbstractContext* ctx,
                             absl::Span<AbstractTensorHandle* const> inputs,
                             vector<AbstractTensorHandle*>* params) {
  tracing::TracingTensorHandle* handle = nullptr;
  for (auto input : inputs) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
        input->DataType(), &handle));
    params->emplace_back(handle);
  }
  return Status::OK();
}

Status RunModel(Model model, AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, bool use_function,
                const GradientRegistry& registry) {
  if (use_function) {
    const char* fn_name = "test_fn";
    std::unique_ptr<AbstractFunction> scoped_func;
    // Returning null tensors from a tf.function is not supported, so we keep
    // track of which indices in the model's outputs are nullptr in this set.
    // The FunctionDef only outputs the non-null tensors. We later pad the
    // function op outputs to have nullptrs at the `null_indices`.
    absl::flat_hash_set<int> null_indices;
    {
      AbstractContextPtr func_ctx(BuildFunction(fn_name));
      vector<AbstractTensorHandle*> func_inputs;
      func_inputs.reserve(inputs.size());
      TF_RETURN_IF_ERROR(
          CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
      vector<AbstractTensorHandle*> model_outputs;
      model_outputs.resize(outputs.size());
      TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
                               absl::MakeSpan(model_outputs), registry));
      for (auto func_input : func_inputs) {
        func_input->Unref();
      }
      AbstractFunction* func = nullptr;
      OutputList output_list;
      output_list.expected_num_outputs = 0;
      output_list.outputs.reserve(outputs.size());
      for (int i = 0; i < model_outputs.size(); i++) {
        if (model_outputs[i]) {
          output_list.outputs.emplace_back(model_outputs[i]);
          output_list.expected_num_outputs += 1;
        } else {
          null_indices.insert(i);
        }
      }
      TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
                             ->Finalize(&output_list, &func));
      scoped_func.reset(func);
      for (auto output : output_list.outputs) {
        output->Unref();
      }
      TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
    }

    AbstractOperationPtr fn_op(ctx->CreateOperation());
    TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
    for (auto input : inputs) {
      TF_RETURN_IF_ERROR(fn_op->AddInput(input));
    }
    int retvals = outputs.size() - null_indices.size();
    vector<AbstractTensorHandle*> fn_outputs(retvals);
    TF_RETURN_IF_ERROR(fn_op->Execute(
        absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
        &retvals));
    int skipped_indices = 0;
    for (int i = 0; i < outputs.size(); i++) {
      if (!null_indices.contains(i)) {
        outputs[i] = fn_outputs[i - skipped_indices];
      } else {
        skipped_indices += 1;
      }
    }
    TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
    return Status::OK();
  } else {
    return model(ctx, inputs, outputs, registry);
  }
}

Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetTfrt(opts, use_tfrt);
  *ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_DeleteContextOptions(opts);
  return Status::OK();
}
}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow
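Taken together, these helpers let a test run the same model both eagerly and as a traced FunctionDef; a sketch of the typical call (the model and input handles are placeholders):

  AbstractContext* ctx = nullptr;
  TF_RETURN_IF_ERROR(BuildImmediateExecutionContext(/*use_tfrt=*/false, &ctx));
  std::vector<AbstractTensorHandle*> outputs(1);
  // use_function=false runs the model eagerly; true traces it into a
  // FunctionDef via BuildFunction/Finalize and executes the registered
  // function instead.
  TF_RETURN_IF_ERROR(RunModel(MatMulModel, ctx, {X, W1},
                              absl::MakeSpan(outputs),
                              /*use_function=*/true, registry));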
@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
#define TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
#include <memory>

#include "absl/types/span.h"
@ -24,50 +26,13 @@ limitations under the License.
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/status.h"

using namespace tensorflow;
using namespace tensorflow::gradients;
using namespace tensorflow::gradients::internal;

// ========================== Tape Ops ==============================

// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
           absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs,
           const GradientRegistry& registry);

// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
Status MatMul(AbstractContext* ctx, Tape* tape,
              absl::Span<AbstractTensorHandle* const> inputs,
              absl::Span<AbstractTensorHandle*> outputs, const char* name,
              bool transpose_a, bool transpose_b,
              const GradientRegistry& registry);

// Computes `inputs[0] * inputs[1]` and records it on the tape.
Status Mul(AbstractContext* ctx, Tape* tape,
           absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name,
           const GradientRegistry& registry);

// Computes `Relu(inputs[0])` and records it on the tape.
Status Relu(AbstractContext* ctx, Tape* tape,
            absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name,
            const GradientRegistry& registry);

// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
// tape.
Status SparseSoftmaxCrossEntropyLoss(
    AbstractContext* ctx, Tape* tape,
    absl::Span<AbstractTensorHandle* const> inputs,
    absl::Span<AbstractTensorHandle*> outputs, const char* name,
    const GradientRegistry& registry);

// ====================== End Tape Ops ============================
namespace tensorflow {
namespace gradients {
namespace internal {

// Computes
// y = inputs[0] + inputs[1]
@ -121,26 +86,23 @@ Status ScalarMulModel(AbstractContext* ctx,
                      absl::Span<AbstractTensorHandle*> outputs,
                      const GradientRegistry& registry);

// Updates the weights for a neural network given incoming grads and learning
// rate
Status UpdateWeights(AbstractContext* ctx,
                     std::vector<AbstractTensorHandle*>& grads,
                     std::vector<AbstractTensorHandle*>& weights,
                     AbstractTensorHandle* learning_rate);
Status MatMulModel(AbstractContext* ctx,
                   absl::Span<AbstractTensorHandle* const> inputs,
                   absl::Span<AbstractTensorHandle*> outputs,
                   const GradientRegistry& registry);

AbstractContext* BuildFunction(const char* fn_name);

Status CreateParamsForInputs(AbstractContext* ctx,
                             absl::Span<AbstractTensorHandle* const> inputs,
                             std::vector<AbstractTensorHandle*>* params);

using Model = std::function<Status(
    AbstractContext*, absl::Span<AbstractTensorHandle* const>,
    absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;

Status RunModel(Model model, AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, bool use_function,
                const GradientRegistry& registry);

Status MulModel(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs,
                const GradientRegistry& registry);

Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);

Status SoftmaxModel(AbstractContext* ctx,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    absl::Span<AbstractTensorHandle*> outputs,
                    const GradientRegistry& registry);

}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
@ -1,3 +1,5 @@
load("//tensorflow:tensorflow.bzl", "filegroup")
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_test",
@ -103,7 +105,6 @@ cc_library(
    hdrs = ["parallel_device_testlib.h"],
    deps = [
        ":parallel_device",
        ":parallel_device_ops",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_api_experimental",
        "//tensorflow/c/eager:c_api",
@ -118,7 +119,6 @@ tf_cc_test(
    srcs = ["parallel_device_test.cc"],
    deps = [
        ":parallel_device",
        ":parallel_device_ops",
        ":parallel_device_testlib",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_api_experimental",
@ -138,7 +138,6 @@ tf_cc_test(
    args = ["--heap_check=local"],
    deps = [
        ":parallel_device",
        ":parallel_device_ops",
        ":parallel_device_testlib",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_api_experimental",
@ -150,19 +149,3 @@ tf_cc_test(
        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
    ],
)

# Note: ParallelDevice-specific ops are experimental and not currently linked in
# to TensorFlow by default, just used in a few tests.
filegroup(
    name = "parallel_device_ops_srcs",
    srcs = ["parallel_device_ops.cc"],
    visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
)

cc_library(
    name = "parallel_device_ops",
    srcs = [":parallel_device_ops_srcs"],
    visibility = ["//tensorflow:internal"],
    deps = ["//tensorflow/core:framework"],
    alwayslink = 1,
)
@ -136,13 +136,6 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
    }
    result.emplace(std::move(outputs));
    return result;
  } else if (operation_name == std::string("DeviceID")) {
    std::vector<MaybeParallelTensorOwned> result_content;
    result_content.reserve(1);
    result_content.push_back(parallel_device.DeviceIDs(context, status));
    if (TF_GetCode(status) != TF_OK) return result;
    result.emplace(std::move(result_content));
    return result;
  }
  std::vector<ParallelTensor*> parallel_inputs;
  std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
@ -255,28 +248,44 @@ TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
                           TFE_TensorHandle** inputs,
                           const char* operation_name,
                           const TFE_OpAttrs* attributes, int* num_outputs,
void ParallelDeviceExecute(const TFE_Op* original_op, int* num_outputs,
                           TFE_TensorHandle** outputs, TF_Status* status,
                           void* device_info) {
  const char* requested_placement = TFE_OpGetDevice(original_op, status);
  if (*requested_placement == '\0') {
    TF_SetStatus(
        status, TF_INTERNAL,
        "Ops must be placed on the parallel device explicitly, or their inputs "
        "first un-packed. Got an un-placed op with an input placed on the "
        "parallel device.");
    return;
  }
  TFE_Context* context = TFE_OpGetContext(original_op, status);
  if (TF_GetCode(status) != TF_OK) return;
  const char* operation_name = TFE_OpGetName(original_op, status);
  if (TF_GetCode(status) != TF_OK) return;
  const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);

  NamedParallelDevice* named_device =
      reinterpret_cast<NamedParallelDevice*>(device_info);
  std::vector<MaybeParallelTensorUnowned> typed_inputs;
  int num_inputs = TFE_OpGetFlatInputCount(original_op, status);
  if (TF_GetCode(status) != TF_OK) return;
  typed_inputs.reserve(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, i, status);
    if (TF_GetCode(status) != TF_OK) return;
    const char* tensor_handle_device =
        TFE_TensorHandleDeviceName(inputs[i], status);
        TFE_TensorHandleDeviceName(input, status);
    if (TF_GetCode(status) != TF_OK) return;
    if (named_device->name() == tensor_handle_device) {
      // We assume that any tensors already placed on this device are
      // ParallelTensors.
      typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
          TFE_TensorHandleDevicePointer(inputs[i], status)));
          TFE_TensorHandleDevicePointer(input, status)));
      if (TF_GetCode(status) != TF_OK) return;
    } else {
      typed_inputs.emplace_back(inputs[i]);
      typed_inputs.emplace_back(input);
    }
  }
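The new signature hands the callback the whole TFE_Op and recovers the context, name, attributes, and inputs through the TFE_OpGet* accessors used above, instead of receiving them as separate arguments. For orientation, a sketch of how such a callback is wired into a custom device; the copy/delete callback names are assumptions for illustration, not part of this diff:

  TFE_CustomDevice custom_device;
  custom_device.copy_tensor_to_device = &CopyToParallelDevice;    // assumed helper
  custom_device.copy_tensor_from_device = &CopyTensorFromParallelDevice;
  custom_device.delete_device = &DeleteParallelDevice;            // assumed helper
  custom_device.execute = &ParallelDeviceExecute;  // the TFE_Op-based signature
  TFE_RegisterCustomDevice(context, custom_device, device_name,
                           named_device, status);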
@ -58,7 +58,7 @@ using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
class DeviceThread {
 public:
  // Starts a background thread waiting for `StartExecute`.
  explicit DeviceThread(const std::string& device)
  explicit DeviceThread(const std::string& device, const bool is_async)
      : status_(TF_NewStatus()),
        device_(device),
        // If the context's default executor is set to async, re-using that in
@ -67,7 +67,7 @@ class DeviceThread {
        //
        // TODO(allenl): We should have an async API that works with the
        // parallel device.
        executor_(TFE_NewExecutor(/*is_async=*/false)),
        executor_(TFE_NewExecutor(is_async)),
        op_(nullptr),
        thread_(tensorflow::Env::Default()->StartThread(
            tensorflow::ThreadOptions(), "parallel_device_execute",
@ -236,12 +236,13 @@ void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
  }
}

ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
ParallelDevice::ParallelDevice(const std::vector<std::string>& devices,
                               const bool is_async)
    : underlying_devices_(devices) {
  device_threads_.reserve(devices.size());
  for (int device_index = 0; device_index < devices.size(); ++device_index) {
    device_threads_.emplace_back(
        new DeviceThread(devices[device_index].c_str()));
        new DeviceThread(devices[device_index].c_str(), is_async));
  }
}
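With per-device executors now configurable, callers opt into async execution when constructing the device; a short sketch (the device names are illustrative):

  std::vector<std::string> devices = {
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  // Each underlying device gets its own DeviceThread with an async executor.
  ParallelDevice async_device(devices, /*is_async=*/true);
  // The default stays synchronous:
  ParallelDevice sync_device(devices);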
@ -49,7 +49,10 @@ class DeviceThread;
// placed on each underlying device.
class ParallelDevice {
 public:
  explicit ParallelDevice(const std::vector<std::string>& devices);
  // Eager async execution is only supported when remote eager is not in use
  // (b/157523095).
  explicit ParallelDevice(const std::vector<std::string>& devices,
                          const bool is_async = false);

  ~ParallelDevice();
@ -279,30 +279,4 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
        TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
    ASSERT_EQ(underlying_devices[1], second_device);
  }
  // Compute the device ID twice and verify the result
  for (int i = 0; i < 2; ++i) {
    std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
        TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    TFE_OpSetDevice(op.get(), device_name, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    TFE_TensorHandle* result_handle;
    int num_retvals = 1;
    TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    std::array<TensorHandlePtr, 2> components;
    ExtractPerDeviceValues(context, result_handle, &components, status.get());
    TFE_DeleteTensorHandle(result_handle);
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    ExpectScalarEq<int32_t>(components[0].get(), 0);
    ExpectScalarEq<int32_t>(components[1].get(), 1);
    std::string first_device =
        TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
    ASSERT_EQ(underlying_devices[0], first_device);
    std::string second_device =
        TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
    ASSERT_EQ(underlying_devices[1], second_device);
  }
}
@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
@ -98,6 +99,10 @@ class VSpace {
                      gtl::ArraySlice<Gradient*> output_gradients,
                      std::vector<Gradient*>* result) const = 0;

  // Builds a tensor filled with ones with the same shape and dtype as `t`.
  virtual Status BuildOnesLike(const TapeTensor& t,
                               Gradient** result) const = 0;

  // Looks up the ID of a Gradient.
  virtual int64 TensorId(Gradient* tensor) const = 0;

@ -121,7 +126,7 @@ class GradientTape {
  // functions (and hence the tensors they keep alive). Instead, everything
  // is deleted in ~GradientTape. Persistent GradientTapes are useful when
  // users want to compute multiple gradients over the same tape.
  GradientTape(bool persistent) : persistent_(persistent) {}
  explicit GradientTape(bool persistent) : persistent_(persistent) {}
  ~GradientTape() {
    for (const auto& pair : op_tape_) {
      pair.second.backward_function_deleter(pair.second.backward_function);
@ -595,8 +600,10 @@ Status InitialGradients(
      for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
        if (op_it->second.output_tensor_info[j].GetID() == id) {
          found = true;
          (*result)[id].push_back(
              op_it->second.output_tensor_info[j].OnesLike());
          Gradient* ones_like = nullptr;
          TF_RETURN_IF_ERROR(vspace.BuildOnesLike(
              op_it->second.output_tensor_info[j], &ones_like));
          (*result)[id].push_back(ones_like);
          break;
        }
      }
@ -611,7 +618,10 @@ Status InitialGradients(
        // target is also a source.
        auto source_tensor = sources_that_are_targets.find(id);
        if (source_tensor != sources_that_are_targets.end()) {
          (*result)[id].push_back(source_tensor->second.OnesLike());
          Gradient* ones_like = nullptr;
          TF_RETURN_IF_ERROR(
              vspace.BuildOnesLike(source_tensor->second, &ones_like));
          (*result)[id].push_back(ones_like);
        }
      }
    } else {
@ -934,7 +944,7 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
      // TODO(allenl): Figure out why using zeros_like everywhere causes issues
      // for some gradient functions and if there's another way to work around
      // it (e.g. conds instead of ifs). The value shouldn't really matter.
      aid = output_tensor.OnesLike();
      TF_RETURN_IF_ERROR(vspace_.BuildOnesLike(output_tensor, &aid));
    }
    if (TF_PREDICT_FALSE(aid == nullptr)) {
      return tensorflow::errors::Internal(
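Moving ones-like construction behind VSpace::BuildOnesLike lets implementations report failure through a Status instead of silently producing a null gradient. A sketch of an override under that contract; MyVSpace, MakeOnesTensor, and the exact TapeTensor accessors are hypothetical stand-ins:

  Status MyVSpace::BuildOnesLike(const TapeTensor& t, Gradient** result) const {
    // Build a tensor of ones matching t's shape and dtype; surface any
    // failure as a Status rather than returning a null gradient.
    Gradient* ones = MakeOnesTensor(t);  // hypothetical helper
    if (ones == nullptr) {
      return tensorflow::errors::Internal("Failed to build a ones-like tensor.");
    }
    *result = ones;
    return Status::OK();
  }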
37 tensorflow/c/eager/tracing_utils.cc (new file)
@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/tracing_utils.h"

#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace tracing {

Status MaybeSetOpName(AbstractOperation* op, const char* op_name) {
  if (isa<TracingOperation>(op)) {
    TF_RETURN_IF_ERROR(dyn_cast<TracingOperation>(op)->SetOpName(op_name));
  }
  if (isa<gradients::TapeOperation>(op)) {
    TF_RETURN_IF_ERROR(MaybeSetOpName(
        dyn_cast<gradients::TapeOperation>(op)->GetBackingOperation(),
        op_name));
  }
  return Status::OK();
}
}  // namespace tracing
}  // namespace tensorflow
26 tensorflow/c/eager/tracing_utils.h (new file)
@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_UTILS_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_UTILS_H_

#include "tensorflow/c/eager/abstract_operation.h"

namespace tensorflow {
namespace tracing {
Status MaybeSetOpName(AbstractOperation*, const char* op_name);
}  // namespace tracing
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_UTILS_H_
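MaybeSetOpName recurses through TapeOperation wrappers so the requested name reaches the tracing operation underneath, and is a no-op for plain eager ops. The intended call-site shape, as a sketch (the op type and name are illustrative):

  AbstractOperationPtr op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(op->Reset("MatMul", /*raw_device_name=*/nullptr));
  // Safe to call regardless of how `op` is wrapped; only tracing ops get named.
  TF_RETURN_IF_ERROR(tensorflow::tracing::MaybeSetOpName(op.get(), "matmul_fwd"));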
@ -1,3 +1,5 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")

# Experimental filesystem C APIs for TensorFlow.
# Will be moved in proper place once all filesystems are converted to the
# modular framework.
@ -1,3 +1,5 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")

# Experimental gcs filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test")

@ -29,6 +31,7 @@ cc_library(
        ":gcs_helper",
        ":ram_file_block_cache",
        "//tensorflow/c:env",
        "//tensorflow/c:logging",
        "//tensorflow/c:tf_status",
        "//tensorflow/c/experimental/filesystem:filesystem_interface",
        "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
@ -59,6 +62,7 @@ cc_library(
    deps = [
        ":cleanup",
        "//tensorflow/c:env",
        "//tensorflow/c:logging",
        "//tensorflow/c:tf_status",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/synchronization",
@ -23,6 +23,7 @@ limitations under the License.
#include "google/cloud/storage/client.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
#include "tensorflow/c/logging.h"
#include "tensorflow/c/tf_status.h"

// Implementation of a filesystem for GCS environments.
@ -120,20 +121,20 @@ static int64_t LoadBufferFromGCS(const std::string& path, size_t offset,
    return -1;
  }
  int64_t read;
  if (!absl::SimpleAtoi(stream.headers().find("content-length")->second,
                        &read)) {
  auto content_length = stream.headers().find("content-length");
  if (content_length == stream.headers().end()) {
    // When we read a file with an offset that is bigger than the actual file
    // size, GCS returns an empty header (i.e. no `content-length` header). In
    // this case, we set read to `0` and continue.
    if (TF_GetCode(status) == TF_OUT_OF_RANGE) {
      read = 0;
    } else {
      TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header");
      return -1;
    }
    read = 0;
  } else if (!absl::SimpleAtoi(content_length->second, &read)) {
    TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header");
    return -1;
  }
  // `TF_OUT_OF_RANGE` isn't considered as an error. So we clear it here.
  TF_SetStatus(status, TF_OK, "");
  TF_VLog(1, "Successful read of %s @ %u of size: %u", path.c_str(), offset,
          read);
  stream.read(buffer, read);
  read = stream.gcount();
  if (read < buffer_size) {
@ -146,6 +147,8 @@ static int64_t LoadBufferFromGCS(const std::string& path, size_t offset,
                       path, " @ ", offset)
                       .c_str());
    }
    TF_VLog(2, "Successful integrity check for: %s @ %u", path.c_str(),
            offset);
  }
}
return read;
@ -259,7 +262,8 @@ static void SyncImpl(const std::string& bucket, const std::string& object,
  if (*offset == -1 || *offset == 0) {
    // UploadFile will automatically switch to resumable upload based on Client
    // configuration.
    auto metadata = gcs_client->UploadFile(outfile->getName(), bucket, object);
    auto metadata = gcs_client->UploadFile(outfile->getName(), bucket, object,
                                           gcs::Fields("size"));
    if (!metadata) {
      TF_SetStatusFromGCSStatus(metadata.status(), status);
      return;
@ -278,15 +282,18 @@ static void SyncImpl(const std::string& bucket, const std::string& object,
  } else {
    std::string temporary_object =
        gcs::CreateRandomPrefixName("tf_writable_file_gcs");
    auto metadata =
        gcs_client->UploadFile(outfile->getName(), bucket, temporary_object);
    auto metadata = gcs_client->UploadFile(outfile->getName(), bucket,
                                           temporary_object, gcs::Fields(""));
    if (!metadata) {
      TF_SetStatusFromGCSStatus(metadata.status(), status);
      return;
    }
    TF_VLog(3, "AppendObject: gs://%s/%s to gs://%s/%s", bucket.c_str(),
            temporary_object.c_str(), bucket.c_str(), object.c_str());
    const std::vector<gcs::ComposeSourceObject> source_objects = {
        {object, {}, {}}, {temporary_object, {}, {}}};
    metadata = gcs_client->ComposeObject(bucket, source_objects, object);
    metadata = gcs_client->ComposeObject(bucket, source_objects, object,
                                         gcs::Fields("size"));
    if (!metadata) {
      TF_SetStatusFromGCSStatus(metadata.status(), status);
      return;
@ -321,6 +328,8 @@ void Append(const TF_WritableFile* file, const char* buffer, size_t n,
                 "The internal temporary file is not writable.");
    return;
  }
  TF_VLog(3, "Append: gs://%s/%s size %u", gcs_file->bucket.c_str(),
          gcs_file->object.c_str(), n);
  gcs_file->sync_need = true;
  gcs_file->outfile.write(buffer, n);
  if (!gcs_file->outfile)
@ -346,6 +355,8 @@ int64_t Tell(const TF_WritableFile* file, TF_Status* status) {
void Flush(const TF_WritableFile* file, TF_Status* status) {
  auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
  if (gcs_file->sync_need) {
    TF_VLog(3, "Flush started: gs://%s/%s", gcs_file->bucket.c_str(),
            gcs_file->object.c_str());
    if (!gcs_file->outfile) {
      TF_SetStatus(status, TF_INTERNAL,
                   "Could not append to the internal temporary file.");
@ -353,6 +364,8 @@ void Flush(const TF_WritableFile* file, TF_Status* status) {
    }
    SyncImpl(gcs_file->bucket, gcs_file->object, &gcs_file->offset,
             &gcs_file->outfile, gcs_file->gcs_client, status);
    TF_VLog(3, "Flush finished: gs://%s/%s", gcs_file->bucket.c_str(),
            gcs_file->object.c_str());
    if (TF_GetCode(status) != TF_OK) return;
    gcs_file->sync_need = false;
  } else {
@ -361,11 +374,16 @@ void Flush(const TF_WritableFile* file, TF_Status* status) {
}

void Sync(const TF_WritableFile* file, TF_Status* status) {
  auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
  TF_VLog(3, "Sync: gs://%s/%s", gcs_file->bucket.c_str(),
          gcs_file->object.c_str());
  Flush(file, status);
}

void Close(const TF_WritableFile* file, TF_Status* status) {
  auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
  TF_VLog(3, "Close: gs://%s/%s", gcs_file->bucket.c_str(),
          gcs_file->object.c_str());
  if (gcs_file->sync_need) {
    Flush(file, status);
  }
@ -428,6 +446,8 @@ GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client)
  if (absl::SimpleAtoi(std::getenv(kMaxStaleness), &value)) {
    max_staleness = value;
  }
  TF_VLog(1, "GCS cache max size = %u ; block size = %u ; max staleness = %u",
          max_bytes, block_size, max_staleness);

  file_block_cache = std::make_unique<RamFileBlockCache>(
      block_size, max_bytes, max_staleness,
@ -504,13 +524,18 @@ void Cleanup(TF_Filesystem* filesystem) {
static void UncachedStatForObject(const std::string& bucket,
                                  const std::string& object, GcsFileStat* stat,
                                  gcs::Client* gcs_client, TF_Status* status) {
  auto metadata = gcs_client->GetObjectMetadata(bucket, object);
  auto metadata = gcs_client->GetObjectMetadata(
      bucket, object, gcs::Fields("generation,size,timeStorageClassUpdated"));
  if (!metadata) return TF_SetStatusFromGCSStatus(metadata.status(), status);
  stat->generation_number = metadata->generation();
  stat->base.length = metadata->size();
  stat->base.mtime_nsec =
      metadata->time_storage_class_updated().time_since_epoch().count();
  stat->base.is_directory = object.back() == '/';
  TF_VLog(1,
          "Stat of: gs://%s/%s -- length: %u generation: %u; mtime_nsec: %u;",
          bucket.c_str(), object.c_str(), stat->base.length,
          stat->generation_number, stat->base.mtime_nsec);
  return TF_SetStatus(status, TF_OK, "");
}

@ -545,9 +570,10 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
        if (TF_GetCode(status) != TF_OK) return -1;
        if (!gcs_file->file_block_cache->ValidateAndUpdateFileSignature(
                path, stat.generation_number)) {
          std::cout
              << "File signature has been changed. Refreshing the cache. Path: "
              << path;
          TF_VLog(
              1,
              "File signature has been changed. Refreshing the cache. Path: %s",
              path.c_str());
        }
        read = gcs_file->file_block_cache->Read(path, offset, n, buffer, status);
      } else {
@ -579,6 +605,7 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
                           (gcs_file->compose ? 0 : -1)});
  // We are responsible for freeing the pointer returned by TF_GetTempFileName
  free(temp_file_name);
  TF_VLog(3, "GcsWritableFile: %s", path);
  TF_SetStatus(status, TF_OK, "");
}

@ -608,7 +635,8 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
  } else {
    // If compose is true, we do not download anything.
    // Instead we only check if this file exists on server or not.
    auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
    auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object,
                                                           gcs::Fields("size"));
    TF_SetStatusFromGCSStatus(metadata.status(), status);
    if (TF_GetCode(status) == TF_OK) {
      file->plugin_file = new tf_writable_file::GCSFile(
@ -624,7 +652,8 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
      return;
    }
  }

  TF_VLog(3, "GcsWritableFile: %s with existing file %s", path,
          temp_file_name.c_str());
  TF_SetStatus(status, TF_OK, "");
}

@ -639,7 +668,8 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
  if (TF_GetCode(status) != TF_OK) return;

  auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
  auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
  auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object,
                                                         gcs::Fields("size"));
  if (!metadata) {
    TF_SetStatusFromGCSStatus(metadata.status(), status);
    return;
@ -670,7 +700,8 @@ static void StatForObject(GCSFile* gcs_file, const std::string& path,
  if (object.empty())
    return TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        ("'object' must be a non-empty string. (File: " + path + ")").c_str());
        absl::StrCat("'object' must be a non-empty string. (File: ", path, ")")
            .c_str());
  TF_SetStatus(status, TF_OK, "");
  gcs_file->stat_cache->LookupOrCompute(
      path, stat,
@ -698,7 +729,8 @@ static bool ObjectExists(GCSFile* gcs_file, const std::string& path,

static bool BucketExists(GCSFile* gcs_file, const std::string& bucket,
                         TF_Status* status) {
  auto metadata = gcs_file->gcs_client.GetBucketMetadata(bucket);
  auto metadata =
      gcs_file->gcs_client.GetBucketMetadata(bucket, gcs::Fields(""));
  TF_SetStatusFromGCSStatus(metadata.status(), status);
  if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_NOT_FOUND)
    return false;
@ -721,7 +753,8 @@ static std::vector<std::string> GetChildrenBounded(
  std::string delimiter = recursive ? "" : "/";

  for (auto&& item : gcs_file->gcs_client.ListObjectsAndPrefixes(
           bucket, gcs::Prefix(prefix), gcs::Delimiter(delimiter))) {
           bucket, gcs::Prefix(prefix), gcs::Delimiter(delimiter),
           gcs::Fields("items(name),prefixes"))) {
    if (count == max_results) {
      TF_SetStatus(status, TF_OK, "");
      return result;
@ -737,8 +770,8 @@ static std::vector<std::string> GetChildrenBounded(
      auto pos = children.find(prefix);
      if (pos != 0) {
        TF_SetStatus(status, TF_INTERNAL,
                     ("Unexpected response: the returned file name " + children +
                      " doesn't match the prefix " + prefix)
                     absl::StrCat("Unexpected response: the returned file name ",
                                  children, " doesn't match the prefix ", prefix)
                         .c_str());
        return result;
      }
@ -812,6 +845,10 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,
               TF_Status* status) {
  std::string dir = path;
  MaybeAppendSlash(&dir);
  TF_VLog(3,
          "CreateDir: creating directory with path: %s and "
          "path_with_slash: %s",
          path, dir.c_str());
  std::string bucket, object;
  ParseGCSPath(dir, true, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;
@ -821,19 +858,23 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,
    if (TF_GetCode(status) != TF_OK) return;
    if (!is_directory)
      TF_SetStatus(status, TF_NOT_FOUND,
                   ("The specified bucket " + dir + " was not found.").c_str());
                   absl::StrCat("The specified bucket ", dir, " was not found.")
                       .c_str());
    return;
  }

  PathExists(filesystem, dir.c_str(), status);
  if (TF_GetCode(status) == TF_OK)
  if (TF_GetCode(status) == TF_OK) {
    // Use the original name for a correct error here.
    TF_VLog(3, "CreateDir: directory already exists, not uploading %s", path);
    return TF_SetStatus(status, TF_ALREADY_EXISTS, path);
  }

  auto metadata = gcs_file->gcs_client.InsertObject(
      bucket, object, "",
      // Adding this parameter means HTTP_CODE_PRECONDITION_FAILED
      // will be returned if the object already exists, so avoid reuploading.
      gcs::IfGenerationMatch(0));
      gcs::IfGenerationMatch(0), gcs::Fields(""));
  TF_SetStatusFromGCSStatus(metadata.status(), status);
  if (TF_GetCode(status) == TF_FAILED_PRECONDITION)
    TF_SetStatus(status, TF_ALREADY_EXISTS, path);
@ -891,7 +932,8 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,

  auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
  auto metadata = gcs_file->gcs_client.RewriteObjectBlocking(
      bucket_src, object_src, bucket_dst, object_dst);
      bucket_src, object_src, bucket_dst, object_dst,
      gcs::Fields("done,rewriteToken"));
  TF_SetStatusFromGCSStatus(metadata.status(), status);
}

@ -908,7 +950,8 @@ bool IsDirectory(const TF_Filesystem* filesystem, const char* path,
    if (!result)
      TF_SetStatus(
          status, TF_NOT_FOUND,
          ("The specified bucket gs://" + bucket + " was not found.").c_str());
          absl::StrCat("The specified bucket gs://", bucket, " was not found.")
              .c_str());
    return result;
  }

@ -933,6 +976,7 @@ bool IsDirectory(const TF_Filesystem* filesystem, const char* path,
static void RenameObject(const TF_Filesystem* filesystem,
                         const std::string& src, const std::string& dst,
                         TF_Status* status) {
  TF_VLog(3, "RenameObject: started %s to %s", src.c_str(), dst.c_str());
  std::string bucket_src, object_src;
  ParseGCSPath(src, false, &bucket_src, &object_src, status);
  if (TF_GetCode(status) != TF_OK) return;
@ -943,9 +987,11 @@ static void RenameObject(const TF_Filesystem* filesystem,

  auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
  auto metadata = gcs_file->gcs_client.RewriteObjectBlocking(
      bucket_src, object_src, bucket_dst, object_dst);
      bucket_src, object_src, bucket_dst, object_dst,
      gcs::Fields("done,rewriteToken"));
  TF_SetStatusFromGCSStatus(metadata.status(), status);
  if (TF_GetCode(status) != TF_OK) return;
  TF_VLog(3, "RenameObject: finished %s to %s", src.c_str(), dst.c_str());

  ClearFileCaches(gcs_file, dst);
  DeleteFile(filesystem, src.c_str(), status);
@ -954,8 +1000,10 @@ static void RenameObject(const TF_Filesystem* filesystem,
void RenameFile(const TF_Filesystem* filesystem, const char* src,
                const char* dst, TF_Status* status) {
  if (!IsDirectory(filesystem, src, status)) {
    if (TF_GetCode(status) == TF_FAILED_PRECONDITION)
    if (TF_GetCode(status) == TF_FAILED_PRECONDITION) {
      TF_SetStatus(status, TF_OK, "");
      RenameObject(filesystem, src, dst, status);
    }
    return;
  }

@ -1032,7 +1080,8 @@ void Stat(const TF_Filesystem* filesystem, const char* path,

  auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
  if (object.empty()) {
    auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket);
    auto bucket_metadata =
        gcs_file->gcs_client.GetBucketMetadata(bucket, gcs::Fields(""));
    TF_SetStatusFromGCSStatus(bucket_metadata.status(), status);
    if (TF_GetCode(status) == TF_OK) {
      stats->is_directory = true;
@ -1047,8 +1096,9 @@ void Stat(const TF_Filesystem* filesystem, const char* path,
    stats->mtime_nsec = 0;
    return TF_SetStatus(status, TF_OK, "");
  }
  if (TF_GetCode(status) == TF_OK) {
    auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
  if (TF_GetCode(status) == TF_FAILED_PRECONDITION) {
    auto metadata = gcs_file->gcs_client.GetObjectMetadata(
        bucket, object, gcs::Fields("size,timeStorageClassUpdated"));
    if (metadata) {
      stats->is_directory = false;
      stats->length = metadata.value().size();
@ -1061,6 +1111,18 @@ void Stat(const TF_Filesystem* filesystem, const char* path,
  }
}

int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
                    TF_Status* status) {
  // Only validate the name.
  std::string bucket, object;
  ParseGCSPath(path, false, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return -1;

  TF_FileStatistics stat;
  Stat(filesystem, path, &stat, status);
  return stat.length;
}

static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) {
  return strdup(uri);
}
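The recurring change in this file is threading gcs::Fields(...) through the google-cloud-cpp calls so that GCS serializes only the metadata fields the plugin actually reads (gcs::Fields("") requests none at all). The shape of the optimization in isolation, assuming a gcs::Client obtained elsewhere and placeholder bucket/object names:

  namespace gcs = google::cloud::storage;
  // Full response: every field of the object resource is fetched and parsed.
  auto full = client.GetObjectMetadata("my-bucket", "my-object");
  // Partial response: only `size` crosses the wire, trimming the payload.
  auto partial =
      client.GetObjectMetadata("my-bucket", "my-object", gcs::Fields("size"));
  if (partial) {
    std::uint64_t object_size = partial->size();
  }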
@ -87,6 +87,24 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
                                     const char* path,
                                     TF_ReadOnlyMemoryRegion* region,
                                     TF_Status* status);
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
                    TF_Status* status);
void PathExists(const TF_Filesystem* filesystem, const char* path,
                TF_Status* status);
void CreateDir(const TF_Filesystem* filesystem, const char* path,
               TF_Status* status);
int GetChildren(const TF_Filesystem* filesystem, const char* path,
                char*** entries, TF_Status* status);
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
                TF_Status* status);
void Stat(const TF_Filesystem* filesystem, const char* path,
          TF_FileStatistics* stats, TF_Status* status);
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
               TF_Status* status);
void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
              TF_Status* status);
void RenameFile(const TF_Filesystem* filesystem, const char* src,
                const char* dst, TF_Status* status);
}  // namespace tf_gcs_filesystem

#endif  // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"

#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)
#define EXPECT_TF_OK(x) EXPECT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)

static const char* content = "abcdefghijklmnopqrstuvwxyz1234567890";
// We will work with content_view instead of content.
@ -94,6 +95,70 @@ class GCSFilesystemTest : public ::testing::Test {
    return translated_name;
  }

  std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile* file)>
  GetWriter() {
    std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile * file)> writer(
        new TF_WritableFile, [](TF_WritableFile* file) {
          if (file != nullptr) {
            if (file->plugin_file != nullptr) tf_writable_file::Cleanup(file);
            delete file;
          }
        });
    writer->plugin_file = nullptr;
    return writer;
  }

  std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile* file)>
  GetReader() {
    std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile * file)>
        reader(new TF_RandomAccessFile, [](TF_RandomAccessFile* file) {
          if (file != nullptr) {
            if (file->plugin_file != nullptr)
              tf_random_access_file::Cleanup(file);
            delete file;
          }
        });
    reader->plugin_file = nullptr;
    return reader;
  }

  void WriteString(const std::string& path, const std::string& content) {
    auto writer = GetWriter();
    tf_gcs_filesystem::NewWritableFile(filesystem_, path.c_str(), writer.get(),
                                       status_);
    if (TF_GetCode(status_) != TF_OK) return;
    tf_writable_file::Append(writer.get(), content.c_str(), content.length(),
                             status_);
    if (TF_GetCode(status_) != TF_OK) return;
    tf_writable_file::Close(writer.get(), status_);
    if (TF_GetCode(status_) != TF_OK) return;
  }

  std::string ReadAll(const std::string& path) {
    auto reader = GetReader();
    tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
                                           reader.get(), status_);
    if (TF_GetCode(status_) != TF_OK) return "";

    auto file_size =
        tf_gcs_filesystem::GetFileSize(filesystem_, path.c_str(), status_);
    if (TF_GetCode(status_) != TF_OK) return "";

    std::string content;
    content.resize(file_size);
    auto read = tf_random_access_file::Read(reader.get(), 0, file_size,
                                            &content[0], status_);
    if (TF_GetCode(status_) != TF_OK) return "";
    if (read >= 0) content.resize(read);
    if (file_size != content.size())
      TF_SetStatus(
          status_, TF_DATA_LOSS,
          std::string("expected " + std::to_string(file_size) + " got " +
                      std::to_string(content.size()) + " bytes")
              .c_str());
    return content;
  }

 protected:
  TF_Filesystem* filesystem_;
  TF_Status* status_;
@ -326,6 +391,145 @@ TEST_F(GCSFilesystemTest, ReadOnlyMemoryRegion) {
  delete region;
}

TEST_F(GCSFilesystemTest, PathExists) {
  tf_gcs_filesystem::Init(filesystem_, status_);
  ASSERT_TF_OK(status_);
  const std::string path = GetURIForPath("PathExists");
  tf_gcs_filesystem::PathExists(filesystem_, path.c_str(), status_);
  EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status_)) << TF_Message(status_);
  TF_SetStatus(status_, TF_OK, "");
  WriteString(path, "test");
  ASSERT_TF_OK(status_);
  tf_gcs_filesystem::PathExists(filesystem_, path.c_str(), status_);
  EXPECT_TF_OK(status_);
}

TEST_F(GCSFilesystemTest, GetChildren) {
  tf_gcs_filesystem::Init(filesystem_, status_);
  ASSERT_TF_OK(status_);
  const std::string base = GetURIForPath("GetChildren");
  tf_gcs_filesystem::CreateDir(filesystem_, base.c_str(), status_);
  EXPECT_TF_OK(status_);

  const std::string file = io::JoinPath(base, "TestFile.csv");
  WriteString(file, "test");
  EXPECT_TF_OK(status_);

  const std::string subdir = io::JoinPath(base, "SubDir");
  tf_gcs_filesystem::CreateDir(filesystem_, subdir.c_str(), status_);
  EXPECT_TF_OK(status_);
  const std::string subfile = io::JoinPath(subdir, "TestSubFile.csv");
  WriteString(subfile, "test");
  EXPECT_TF_OK(status_);

  char** entries;
  auto num_entries = tf_gcs_filesystem::GetChildren(filesystem_, base.c_str(),
                                                    &entries, status_);
  EXPECT_TF_OK(status_);

  std::vector<std::string> childrens;
  for (int i = 0; i < num_entries; ++i) {
    childrens.push_back(entries[i]);
  }
  std::sort(childrens.begin(), childrens.end());
  EXPECT_EQ(std::vector<string>({"SubDir/", "TestFile.csv"}), childrens);
}

TEST_F(GCSFilesystemTest, DeleteFile) {
  tf_gcs_filesystem::Init(filesystem_, status_);
  ASSERT_TF_OK(status_);
  const std::string path = GetURIForPath("DeleteFile");
  WriteString(path, "test");
  ASSERT_TF_OK(status_);
  tf_gcs_filesystem::DeleteFile(filesystem_, path.c_str(), status_);
  EXPECT_TF_OK(status_);
  tf_gcs_filesystem::PathExists(filesystem_, path.c_str(), status_);
  EXPECT_EQ(TF_GetCode(status_), TF_NOT_FOUND);
}

TEST_F(GCSFilesystemTest, CreateDir) {
  tf_gcs_filesystem::Init(filesystem_, status_);
  ASSERT_TF_OK(status_);
  const std::string dir = GetURIForPath("CreateDir");
  tf_gcs_filesystem::CreateDir(filesystem_, dir.c_str(), status_);
  EXPECT_TF_OK(status_);

  TF_FileStatistics stat;
  tf_gcs_filesystem::Stat(filesystem_, dir.c_str(), &stat, status_);
  EXPECT_TF_OK(status_);
  EXPECT_TRUE(stat.is_directory);
}

TEST_F(GCSFilesystemTest, DeleteDir) {
  tf_gcs_filesystem::Init(filesystem_, status_);
  ASSERT_TF_OK(status_);
  const std::string dir = GetURIForPath("DeleteDir");
  const std::string file = io::JoinPath(dir, "DeleteDirFile.csv");
  WriteString(file, "test");
  ASSERT_TF_OK(status_);
  tf_gcs_filesystem::DeleteDir(filesystem_, dir.c_str(), status_);
  EXPECT_EQ(TF_GetCode(status_), TF_FAILED_PRECONDITION);

  TF_SetStatus(status_, TF_OK, "");
  tf_gcs_filesystem::DeleteFile(filesystem_, file.c_str(), status_);
  EXPECT_TF_OK(status_);
  tf_gcs_filesystem::DeleteDir(filesystem_, dir.c_str(), status_);
  EXPECT_TF_OK(status_);
  TF_FileStatistics stat;
  tf_gcs_filesystem::Stat(filesystem_, dir.c_str(), &stat, status_);
  EXPECT_EQ(TF_GetCode(status_), TF_NOT_FOUND) << TF_Message(status_);
}

TEST_F(GCSFilesystemTest, StatFile) {
  tf_gcs_filesystem::Init(filesystem_, status_);
  ASSERT_TF_OK(status_);
  const std::string path = GetURIForPath("StatFile");
  WriteString(path, "test");
  ASSERT_TF_OK(status_);

  TF_FileStatistics stat;
  tf_gcs_filesystem::Stat(filesystem_, path.c_str(), &stat, status_);
  EXPECT_TF_OK(status_);
  EXPECT_EQ(4, stat.length);
  EXPECT_FALSE(stat.is_directory);
}

TEST_F(GCSFilesystemTest, RenameFile) {
  tf_gcs_filesystem::Init(filesystem_, status_);
  ASSERT_TF_OK(status_);
  const std::string src = GetURIForPath("RenameFileSrc");
  const std::string dst = GetURIForPath("RenameFileDst");
  WriteString(src, "test");
  ASSERT_TF_OK(status_);

  tf_gcs_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), status_);
  EXPECT_TF_OK(status_);
  auto result = ReadAll(dst);
  EXPECT_TF_OK(status_);
  EXPECT_EQ("test", result);
}

TEST_F(GCSFilesystemTest, RenameFileOverwrite) {
  tf_gcs_filesystem::Init(filesystem_, status_);
  ASSERT_TF_OK(status_);
  const std::string src = GetURIForPath("RenameFileOverwriteSrc");
  const std::string dst = GetURIForPath("RenameFileOverwriteDst");

  WriteString(src, "test_old");
  ASSERT_TF_OK(status_);
  WriteString(dst, "test_new");
  ASSERT_TF_OK(status_);

  tf_gcs_filesystem::PathExists(filesystem_, dst.c_str(), status_);
  EXPECT_TF_OK(status_);
  tf_gcs_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), status_);
  EXPECT_TF_OK(status_);

  auto result = ReadAll(dst);
  EXPECT_TF_OK(status_);
  EXPECT_EQ("test_old", result);
}

// These tests below are ported from
// `//tensorflow/core/platform/cloud:gcs_file_system_test`
TEST_F(GCSFilesystemTest, NewRandomAccessFile_NoBlockCache) {
@ -28,6 +28,7 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/logging.h"
#include "tensorflow/c/tf_status.h"

namespace tf_gcs_filesystem {
@ -65,8 +66,8 @@ class RamFileBlockCache {
pruning_thread_.reset(
TF_StartThread(&thread_options, "TF_prune_FBC", PruneThread, this));
}
std::cout << "GCS file block cache is "
<< (IsCacheEnabled() ? "enabled" : "disabled") << ".\n";
TF_VLog(1, "GCS file block cache is %s.\n",
(IsCacheEnabled() ? "enabled" : "disabled"));
}

~RamFileBlockCache() {
@ -1,5 +1,7 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")

# Experimental hadoop filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test")

package(
licenses = ["notice"], # Apache 2.0
@ -20,12 +22,14 @@ cc_library(
name = "hadoop_filesystem_impl",
srcs = ["hadoop_filesystem.cc"],
hdrs = ["hadoop_filesystem.h"],
compatible_with = [],
copts = select({
"//conditions:default": [],
"//tensorflow:windows": get_win_copts(),
}),
deps = [
"//tensorflow/c:env",
"//tensorflow/c:logging",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"//third_party/hadoop:hdfs",
@ -33,3 +37,38 @@ cc_library(
"@com_google_absl//absl/synchronization",
],
)

# This test is set to manual because it requires downloading the Hadoop
# distribution to run. To run this test:
# 1. Ensure $JAVA_HOME is set to the location of a JDK 8 installation.
# 2. Download the binary Hadoop distribution from:
# http://hadoop.apache.org/releases.html
# 3. Extract the Hadoop distribution and run:
# source libexec/hadoop-config.sh
# 4. Optionally set up HDFS cluster configurations (with Kerberos if desired)
# within $HADOOP_HDFS_HOME/etc/hadoop if you want to test against a real
# distributed HDFS cluster.
# 5. bazel test \
# --test_env=LD_LIBRARY_PATH=$JAVA_HOME/jre/lib/amd64/server \
# --test_env=HADOOP_HDFS_HOME=$HADOOP_HDFS_HOME \
# --test_env=CLASSPATH=$($HADOOP_HDFS_HOME/bin/hadoop classpath --glob) \
# :hadoop_file_system_test
# To test against the real distributed cluster, add the following option for
# bazel test:
# --test_env=HADOOP_TEST_TMPDIR=hdfs://cluster/test/tmp/dir
tf_cc_test(
name = "hadoop_filesystem_test",
srcs = [
"hadoop_filesystem_test.cc",
],
tags = [
"manual",
"notap",
],
deps = [
":hadoop_filesystem_impl",
"//tensorflow/core/platform:path",
"//tensorflow/core/platform:stacktrace_handler",
"//tensorflow/core/platform:test",
],
)
@ -22,11 +22,10 @@ limitations under the License.
#include <sstream>
#include <string>

#include "absl/synchronization/mutex.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/logging.h"
#include "tensorflow/c/tf_status.h"
#include "third_party/hadoop/hdfs.h"

// Implementation of a filesystem for HADOOP environments.
// This filesystem will support `hdfs://`, `viewfs://` and `har://` URI schemes.
@ -37,11 +36,17 @@ static void plugin_memory_free(void* ptr) { free(ptr); }
void ParseHadoopPath(const std::string& fname, std::string* scheme,
std::string* namenode, std::string* path) {
size_t scheme_end = fname.find("://") + 2;
*scheme = fname.substr(0, scheme_end + 1);
// We don't want `://` in scheme.
*scheme = fname.substr(0, scheme_end - 2);
size_t nn_end = fname.find("/", scheme_end + 1);
if (nn_end == std::string::npos) return;
if (nn_end == std::string::npos) {
*namenode = fname.substr(scheme_end + 1);
*path = "";
return;
}
*namenode = fname.substr(scheme_end + 1, nn_end - scheme_end - 1);
*path = fname.substr(nn_end + 1);
// We keep `/` in path.
*path = fname.substr(nn_end);
}
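
A worked example may make the new parsing contract concrete. This sketch is illustrative only and not part of the change; the sample URI is an assumption:

// Illustrative sketch (sample URI assumed, not from the patch):
std::string scheme, namenode, path;
ParseHadoopPath("hdfs://localhost:8020/dir/file.txt", &scheme, &namenode, &path);
// scheme   == "hdfs"            (the trailing "://" is now stripped)
// namenode == "localhost:8020"
// path     == "/dir/file.txt"   (the leading "/" is now kept)
ParseHadoopPath("hdfs://localhost:8020", &scheme, &namenode, &path);
// namenode == "localhost:8020", path == ""  (no-slash case no longer leaves
// namenode and path unset)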

void SplitArchiveNameAndPath(std::string* path, std::string* nn,
@ -54,7 +59,7 @@ void SplitArchiveNameAndPath(std::string* path, std::string* nn,
}
// Case of hadoop archive. Namenode is the path to the archive.
std::ostringstream namenodestream;
namenodestream << "har://" << nn
namenodestream << "har://" << *nn
<< path->substr(0, index_end_archive_name + 4);
*nn = namenodestream.str();
path->erase(0, index_end_archive_name + 4);
@ -143,15 +148,20 @@ class LibHDFS {
char* hdfs_home = getenv("HADOOP_HDFS_HOME");
if (hdfs_home != nullptr) {
auto JoinPath = [](std::string home, std::string lib) {
#if defined(_WIN32)
if (home.back() != '\\') home.push_back('\\');
return home + "lib\\native\\" + lib;
#else
if (home.back() != '/') home.push_back('/');
return home + "lib/native/" + lib;
#endif
};
std::string path = JoinPath(hdfs_home, kLibHdfsDso);
TryLoadAndBind(path.c_str(), &handle_, status);
if (TF_GetCode(status) == TF_OK) {
return;
} else {
std::cerr << "HadoopFileSystem load error: " << TF_Message(status);
TF_Log(TF_FATAL, "HadoopFileSystem load error: %s", TF_Message(status));
}
}

@ -163,13 +173,15 @@ class LibHDFS {
void* handle_;
};

// We rely on HDFS connection caching here. The HDFS client calls
// org.apache.hadoop.fs.FileSystem.get(), which caches the connection
// internally.
hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) {
// We implement connection caching in TensorFlow, which can significantly
// improve performance. Fixes #43187.
hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file,
const std::string& path, TF_Status* status) {
auto libhdfs = hadoop_file->libhdfs;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);

std::string cacheKey(scheme);
hdfsBuilder* builder = libhdfs->hdfsNewBuilder();
if (scheme == "file") {
libhdfs->hdfsBuilderSetNameNode(builder, nullptr);
@ -194,15 +206,24 @@ hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) {
SplitArchiveNameAndPath(&path_har, &namenode, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str());
cacheKey += namenode;
} else {
libhdfs->hdfsBuilderSetNameNode(
builder, namenode.empty() ? "default" : namenode.c_str());
cacheKey += namenode;
}
auto fs = libhdfs->hdfsBuilderConnect(builder);
if (fs == nullptr)
TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno));
else
TF_SetStatus(status, TF_OK, "");
absl::MutexLock l(&hadoop_file->connection_cache_lock);
if (hadoop_file->connection_cache.find(cacheKey) ==
hadoop_file->connection_cache.end()) {
auto cacheFs = libhdfs->hdfsBuilderConnect(builder);
if (cacheFs == nullptr) {
TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno));
return cacheFs;
}
hadoop_file->connection_cache[cacheKey] = cacheFs;
}
auto fs = hadoop_file->connection_cache[cacheKey];
TF_SetStatus(status, TF_OK, "");
return fs;
}
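
The cache key is the scheme plus the namenode, so one HDFS connection is shared per (scheme, namenode) pair across all files, and the builder connect runs at most once per key. A stripped-down sketch of that pattern, with all names assumed for illustration only:

#include <map>
#include <string>
#include "absl/synchronization/mutex.h"
#include "third_party/hadoop/hdfs.h"  // for hdfsFS

// Hedged sketch; GetOrConnect and connect_fn are illustrative names, not the
// plugin's API.
absl::Mutex cache_lock;
std::map<std::string, hdfsFS> cache;  // guarded by cache_lock

template <typename ConnectFn>
hdfsFS GetOrConnect(const std::string& key, ConnectFn connect_fn) {
  absl::MutexLock l(&cache_lock);
  auto it = cache.find(key);
  if (it != cache.end()) return it->second;  // reuse the cached connection
  hdfsFS fs = connect_fn();                  // connect at most once per key
  if (fs != nullptr) cache[key] = fs;
  return fs;
}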

@ -216,6 +237,7 @@ typedef struct HDFSFile {
LibHDFS* libhdfs;
absl::Mutex mu;
hdfsFile handle ABSL_GUARDED_BY(mu);
bool disable_eof_retried;
HDFSFile(std::string path, std::string hdfs_path, hdfsFS fs, LibHDFS* libhdfs,
hdfsFile handle)
: path(std::move(path)),
@ -223,7 +245,15 @@ typedef struct HDFSFile {
fs(fs),
libhdfs(libhdfs),
mu(),
handle(handle) {}
handle(handle) {
const char* disable_eof_retried_str =
getenv("HDFS_DISABLE_READ_EOF_RETRIED");
if (disable_eof_retried_str && disable_eof_retried_str[0] == '1') {
disable_eof_retried = true;
} else {
disable_eof_retried = false;
}
}
} HDFSFile;
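
The toggle is read once in the constructor above, so it must be set before the file is opened. A minimal usage sketch (POSIX setenv assumed, for illustration only):

// Hedged sketch: opt out of the reopen-at-EOF retry before opening any files.
setenv("HDFS_DISABLE_READ_EOF_RETRIED", "1", /*overwrite=*/1);
// Files opened from here on report end-of-file directly instead of reopening
// the handle to look for newly appended data.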

void Cleanup(TF_RandomAccessFile* file) {
@ -247,8 +277,12 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,

char* dst = buffer;
bool eof_retried = false;
int64_t r = 0;
while (TF_GetCode(status) == TF_OK && !eof_retried) {
if (hdfs_file->disable_eof_retried) {
// Set eof_retried = true to avoid reopening the file in Read. Fixes #42597.
eof_retried = true;
}
int64_t read = 0;
while (TF_GetCode(status) == TF_OK && n > 0) {
// We lock inside the loop rather than outside so we don't block other
// concurrent readers.
absl::MutexLock l(&hdfs_file->mu);
@ -257,12 +291,13 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
// of int32. -2 offset can avoid JVM OutOfMemoryError.
size_t read_n =
(std::min)(n, static_cast<size_t>(std::numeric_limits<int>::max() - 2));
r = libhdfs->hdfsPread(fs, handle, static_cast<tOffset>(offset), dst,
static_cast<tSize>(read_n));
int64_t r = libhdfs->hdfsPread(fs, handle, static_cast<tOffset>(offset),
dst, static_cast<tSize>(read_n));
if (r > 0) {
dst += r;
n -= r;
offset += r;
read += r;
} else if (!eof_retried && r == 0) {
// Always reopen the file upon reaching EOF to see if there's more data.
// If writers are streaming contents while others are concurrently
@ -274,11 +309,13 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
TF_SetStatusFromIOError(status, errno, path);
return -1;
}
handle = libhdfs->hdfsOpenFile(fs, hdfs_path, O_RDONLY, 0, 0, 0);
if (handle == nullptr) {
hdfs_file->handle =
libhdfs->hdfsOpenFile(fs, hdfs_path, O_RDONLY, 0, 0, 0);
if (hdfs_file->handle == nullptr) {
TF_SetStatusFromIOError(status, errno, path);
return -1;
}
handle = hdfs_file->handle;
eof_retried = true;
} else if (eof_retried && r == 0) {
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
@ -288,7 +325,7 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
TF_SetStatusFromIOError(status, errno, path);
}
}
return r;
return read;
}
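
Note the return-value change: the function now returns the total bytes accumulated in read across hdfsPread calls rather than the last chunk size r. A caller-side sketch, with the buffer size and names assumed for illustration:

// Hedged caller sketch: a short read surfaces as TF_OUT_OF_RANGE plus the
// partial byte count, which callers typically treat as EOF, not a hard error.
char buf[4096];
int64_t n =
    tf_random_access_file::Read(file, /*offset=*/0, sizeof(buf), buf, status);
if (TF_GetCode(status) == TF_OUT_OF_RANGE && n >= 0) {
  // Reached end of file after n valid bytes in buf.
}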

} // namespace tf_random_access_file
@ -308,7 +345,7 @@ typedef struct HDFSFile {
handle(handle) {}
} HDFSFile;

static void Cleanup(TF_WritableFile* file) {
void Cleanup(TF_WritableFile* file) {
auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
hdfs_file->libhdfs->hdfsCloseFile(hdfs_file->fs, hdfs_file->handle);
hdfs_file->fs = nullptr;
@ -387,30 +424,36 @@ void Close(const TF_WritableFile* file, TF_Status* status) {
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
// ----------------------------------------------------------------------------
namespace tf_read_only_memory_region {

// TODO(vnvo2409): Implement later

// Hadoop doesn't support read-only memory regions.
} // namespace tf_read_only_memory_region

// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_hadoop_filesystem {

HadoopFile::HadoopFile(TF_Status* status)
: libhdfs(new LibHDFS(status)),
connection_cache_lock(),
connection_cache() {}

void Init(TF_Filesystem* filesystem, TF_Status* status) {
filesystem->plugin_filesystem = new LibHDFS(status);
filesystem->plugin_filesystem = new HadoopFile(status);
if (TF_GetCode(status) != TF_OK) return;
TF_SetStatus(status, TF_OK, "");
}

void Cleanup(TF_Filesystem* filesystem) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
delete libhdfs;
delete hadoop_file;
}

void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
TF_RandomAccessFile* file, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;

std::string scheme, namenode, hdfs_path;
@ -426,8 +469,27 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,

void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;

std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);

auto handle = libhdfs->hdfsOpenFile(fs, hdfs_path.c_str(), O_WRONLY, 0, 0, 0);
if (handle == nullptr) return TF_SetStatusFromIOError(status, errno, path);

file->plugin_file =
new tf_writable_file::HDFSFile(hdfs_path, fs, libhdfs, handle);
TF_SetStatus(status, TF_OK, "");
}

void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;

std::string scheme, namenode, hdfs_path;
@ -458,8 +520,9 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,

void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;

std::string scheme, namenode, hdfs_path;
@ -474,8 +537,9 @@ void PathExists(const TF_Filesystem* filesystem, const char* path,

void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;

std::string scheme, namenode, hdfs_path;
@ -493,8 +557,9 @@ void Stat(const TF_Filesystem* filesystem, const char* path,

int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return -1;

std::string scheme, namenode, hdfs_path;
@ -514,8 +579,9 @@ int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,

void DeleteFile(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;

std::string scheme, namenode, hdfs_path;
@ -529,8 +595,9 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path,

void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;

std::string scheme, namenode, hdfs_path;
@ -544,8 +611,9 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,

void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;

std::string scheme, namenode, hdfs_path;
@ -580,8 +648,9 @@ void DeleteDir(const TF_Filesystem* filesystem, const char* path,

void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, src, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, src, status);
if (TF_GetCode(status) != TF_OK) return;

std::string scheme, namenode, hdfs_path_src, hdfs_path_dst;
@ -601,8 +670,9 @@ void RenameFile(const TF_Filesystem* filesystem, const char* src,

int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return -1;

std::string scheme, namenode, hdfs_path;
@ -638,7 +708,9 @@ int GetChildren(const TF_Filesystem* filesystem, const char* path,
return num_entries;
}

// TODO(vnvo2409): Implement later
static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) {
return strdup(uri);
}

} // namespace tf_hadoop_filesystem

@ -646,6 +718,42 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);

ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
ops->random_access_file_ops->read = tf_random_access_file::Read;

ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->writable_file_ops->append = tf_writable_file::Append;
ops->writable_file_ops->tell = tf_writable_file::Tell;
ops->writable_file_ops->flush = tf_writable_file::Flush;
ops->writable_file_ops->sync = tf_writable_file::Sync;
ops->writable_file_ops->close = tf_writable_file::Close;

ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_hadoop_filesystem::Init;
ops->filesystem_ops->cleanup = tf_hadoop_filesystem::Cleanup;
ops->filesystem_ops->new_random_access_file =
tf_hadoop_filesystem::NewRandomAccessFile;
ops->filesystem_ops->new_writable_file =
tf_hadoop_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file =
tf_hadoop_filesystem::NewAppendableFile;
ops->filesystem_ops->new_read_only_memory_region_from_file =
tf_hadoop_filesystem::NewReadOnlyMemoryRegionFromFile;
ops->filesystem_ops->path_exists = tf_hadoop_filesystem::PathExists;
ops->filesystem_ops->stat = tf_hadoop_filesystem::Stat;
ops->filesystem_ops->get_file_size = tf_hadoop_filesystem::GetFileSize;
ops->filesystem_ops->delete_file = tf_hadoop_filesystem::DeleteFile;
ops->filesystem_ops->create_dir = tf_hadoop_filesystem::CreateDir;
ops->filesystem_ops->delete_dir = tf_hadoop_filesystem::DeleteDir;
ops->filesystem_ops->rename_file = tf_hadoop_filesystem::RenameFile;
ops->filesystem_ops->get_children = tf_hadoop_filesystem::GetChildren;
ops->filesystem_ops->translate_name = tf_hadoop_filesystem::TranslateName;
}

void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
@ -15,7 +15,73 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_

#include <map>
#include <string>

#include "absl/synchronization/mutex.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
#include "third_party/hadoop/hdfs.h"

void ParseHadoopPath(const std::string& fname, std::string* scheme,
std::string* namenode, std::string* path);
void SplitArchiveNameAndPath(std::string* path, std::string* nn,
TF_Status* status);
class LibHDFS;

namespace tf_random_access_file {
void Cleanup(TF_RandomAccessFile* file);
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status);
} // namespace tf_random_access_file

namespace tf_writable_file {
void Cleanup(TF_WritableFile* file);
void Append(const TF_WritableFile* file, const char* buffer, size_t n,
TF_Status* status);
int64_t Tell(const TF_WritableFile* file, TF_Status* status);
void Sync(const TF_WritableFile* file, TF_Status* status);
void Flush(const TF_WritableFile* file, TF_Status* status);
void Close(const TF_WritableFile* file, TF_Status* status);
} // namespace tf_writable_file

namespace tf_hadoop_filesystem {
typedef struct HadoopFile {
LibHDFS* libhdfs;
absl::Mutex connection_cache_lock;
std::map<std::string, hdfsFS> connection_cache
ABSL_GUARDED_BY(connection_cache_lock);
HadoopFile(TF_Status* status);
} HadoopFile;

void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem);
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
TF_RandomAccessFile* file, TF_Status* status);
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status);
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status);
void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
const char* path,
TF_ReadOnlyMemoryRegion* region,
TF_Status* status);
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status);
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status);
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status);
} // namespace tf_hadoop_filesystem

#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
@ -0,0 +1,460 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h"

#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/stacktrace_handler.h"
#include "tensorflow/core/platform/test.h"
#include "third_party/hadoop/hdfs.h"

#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)
#define EXPECT_TF_OK(x) EXPECT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)

namespace tensorflow {
namespace {

class HadoopFileSystemTest : public ::testing::Test {
public:
void SetUp() override {
status_ = TF_NewStatus();
filesystem_ = new TF_Filesystem;
tf_hadoop_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
}
void TearDown() override {
TF_DeleteStatus(status_);
tf_hadoop_filesystem::Cleanup(filesystem_);
delete filesystem_;
}

std::string TmpDir(const std::string& path) {
char* test_dir = getenv("HADOOP_TEST_TMPDIR");
if (test_dir != nullptr) {
return io::JoinPath(std::string(test_dir), path);
} else {
return "file://" + io::JoinPath(testing::TmpDir(), path);
}
}

std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile* file)>
GetWriter() {
std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile * file)> writer(
new TF_WritableFile, [](TF_WritableFile* file) {
if (file != nullptr) {
if (file->plugin_file != nullptr) tf_writable_file::Cleanup(file);
delete file;
}
});
writer->plugin_file = nullptr;
return writer;
}

std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile* file)>
GetReader() {
std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile * file)>
reader(new TF_RandomAccessFile, [](TF_RandomAccessFile* file) {
if (file != nullptr) {
if (file->plugin_file != nullptr)
tf_random_access_file::Cleanup(file);
delete file;
}
});
reader->plugin_file = nullptr;
return reader;
}

void WriteString(const std::string& path, const std::string& content) {
auto writer = GetWriter();
tf_hadoop_filesystem::NewWritableFile(filesystem_, path.c_str(),
writer.get(), status_);
if (TF_GetCode(status_) != TF_OK) return;
tf_writable_file::Append(writer.get(), content.c_str(), content.length(),
status_);
if (TF_GetCode(status_) != TF_OK) return;
tf_writable_file::Close(writer.get(), status_);
if (TF_GetCode(status_) != TF_OK) return;
}

std::string ReadAll(const std::string& path) {
auto reader = GetReader();
tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
reader.get(), status_);
if (TF_GetCode(status_) != TF_OK) return "";

auto file_size =
tf_hadoop_filesystem::GetFileSize(filesystem_, path.c_str(), status_);
if (TF_GetCode(status_) != TF_OK) return "";

std::string content;
content.resize(file_size);
auto read = tf_random_access_file::Read(reader.get(), 0, file_size,
&content[0], status_);
if (TF_GetCode(status_) != TF_OK) return "";
if (read >= 0) content.resize(read);
if (file_size != content.size())
TF_SetStatus(
status_, TF_DATA_LOSS,
std::string("expected " + std::to_string(file_size) + " got " +
std::to_string(content.size()) + " bytes")
.c_str());
return content;
}

protected:
TF_Filesystem* filesystem_;
TF_Status* status_;
};

TEST_F(HadoopFileSystemTest, RandomAccessFile) {
const std::string path = TmpDir("RandomAccessFile");
const std::string content = "abcdefghijklmn";

WriteString(path, content);
ASSERT_TF_OK(status_);

auto reader = GetReader();
tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
reader.get(), status_);
EXPECT_TF_OK(status_);

std::string result;
result.resize(content.size());
auto read = tf_random_access_file::Read(reader.get(), 0, content.size(),
&result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(content.size(), result.size());
EXPECT_EQ(content, result);

result.clear();
result.resize(4);
read = tf_random_access_file::Read(reader.get(), 2, 4, &result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(4, result.size());
EXPECT_EQ(content.substr(2, 4), result);
}

TEST_F(HadoopFileSystemTest, WritableFile) {
auto writer = GetWriter();
const std::string path = TmpDir("WritableFile");
tf_hadoop_filesystem::NewWritableFile(filesystem_, path.c_str(), writer.get(),
status_);
EXPECT_TF_OK(status_);
tf_writable_file::Append(writer.get(), "content1,", strlen("content1,"),
status_);
EXPECT_TF_OK(status_);
auto pos = tf_writable_file::Tell(writer.get(), status_);
EXPECT_TF_OK(status_);
EXPECT_EQ(pos, 9);

tf_writable_file::Append(writer.get(), "content2", strlen("content2"),
status_);
EXPECT_TF_OK(status_);
tf_writable_file::Flush(writer.get(), status_);
EXPECT_TF_OK(status_);
tf_writable_file::Sync(writer.get(), status_);
EXPECT_TF_OK(status_);
tf_writable_file::Close(writer.get(), status_);
EXPECT_TF_OK(status_);

auto content = ReadAll(path);
EXPECT_TF_OK(status_);
EXPECT_EQ("content1,content2", content);
}

TEST_F(HadoopFileSystemTest, PathExists) {
const std::string path = TmpDir("PathExists");
tf_hadoop_filesystem::PathExists(filesystem_, path.c_str(), status_);
EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status_)) << TF_Message(status_);
TF_SetStatus(status_, TF_OK, "");
WriteString(path, "test");
ASSERT_TF_OK(status_);
tf_hadoop_filesystem::PathExists(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
}

TEST_F(HadoopFileSystemTest, GetChildren) {
const std::string base = TmpDir("GetChildren");
tf_hadoop_filesystem::CreateDir(filesystem_, base.c_str(), status_);
EXPECT_TF_OK(status_);

const std::string file = io::JoinPath(base, "TestFile.csv");
WriteString(file, "test");
EXPECT_TF_OK(status_);

const std::string subdir = io::JoinPath(base, "SubDir");
tf_hadoop_filesystem::CreateDir(filesystem_, subdir.c_str(), status_);
EXPECT_TF_OK(status_);
const std::string subfile = io::JoinPath(subdir, "TestSubFile.csv");
WriteString(subfile, "test");
EXPECT_TF_OK(status_);

char** entries;
auto num_entries = tf_hadoop_filesystem::GetChildren(
filesystem_, base.c_str(), &entries, status_);
EXPECT_TF_OK(status_);

std::vector<std::string> children;
for (int i = 0; i < num_entries; ++i) {
children.push_back(entries[i]);
}
std::sort(children.begin(), children.end());
EXPECT_EQ(std::vector<string>({"SubDir", "TestFile.csv"}), children);
}

TEST_F(HadoopFileSystemTest, DeleteFile) {
const std::string path = TmpDir("DeleteFile");
WriteString(path, "test");
ASSERT_TF_OK(status_);
tf_hadoop_filesystem::DeleteFile(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
}

TEST_F(HadoopFileSystemTest, GetFileSize) {
const std::string path = TmpDir("GetFileSize");
WriteString(path, "test");
ASSERT_TF_OK(status_);
auto file_size =
tf_hadoop_filesystem::GetFileSize(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
EXPECT_EQ(4, file_size);
}

TEST_F(HadoopFileSystemTest, CreateDirStat) {
const std::string path = TmpDir("CreateDirStat");
tf_hadoop_filesystem::CreateDir(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
TF_FileStatistics stat;
tf_hadoop_filesystem::Stat(filesystem_, path.c_str(), &stat, status_);
EXPECT_TF_OK(status_);
EXPECT_TRUE(stat.is_directory);
}

TEST_F(HadoopFileSystemTest, DeleteDir) {
const std::string path = TmpDir("DeleteDir");
tf_hadoop_filesystem::DeleteDir(filesystem_, path.c_str(), status_);
EXPECT_NE(TF_GetCode(status_), TF_OK);
tf_hadoop_filesystem::CreateDir(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
tf_hadoop_filesystem::DeleteDir(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
TF_FileStatistics stat;
tf_hadoop_filesystem::Stat(filesystem_, path.c_str(), &stat, status_);
EXPECT_NE(TF_GetCode(status_), TF_OK);
}

TEST_F(HadoopFileSystemTest, RenameFile) {
const std::string src = TmpDir("RenameFileSrc");
const std::string dst = TmpDir("RenameFileDst");
WriteString(src, "test");
ASSERT_TF_OK(status_);

tf_hadoop_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(),
status_);
EXPECT_TF_OK(status_);
auto result = ReadAll(dst);
EXPECT_TF_OK(status_);
EXPECT_EQ("test", result);
}

TEST_F(HadoopFileSystemTest, RenameFileOverwrite) {
const std::string src = TmpDir("RenameFileOverwriteSrc");
const std::string dst = TmpDir("RenameFileOverwriteDst");

WriteString(src, "test_old");
ASSERT_TF_OK(status_);
WriteString(dst, "test_new");
ASSERT_TF_OK(status_);

tf_hadoop_filesystem::PathExists(filesystem_, dst.c_str(), status_);
EXPECT_TF_OK(status_);
tf_hadoop_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(),
status_);
EXPECT_TF_OK(status_);

auto result = ReadAll(dst);
EXPECT_TF_OK(status_);
EXPECT_EQ("test_old", result);
}

TEST_F(HadoopFileSystemTest, StatFile) {
const std::string path = TmpDir("StatFile");
WriteString(path, "test");
ASSERT_TF_OK(status_);
TF_FileStatistics stat;
tf_hadoop_filesystem::Stat(filesystem_, path.c_str(), &stat, status_);
EXPECT_TF_OK(status_);
EXPECT_EQ(4, stat.length);
EXPECT_FALSE(stat.is_directory);
}

TEST_F(HadoopFileSystemTest, WriteWhileReading) {
const std::string path = TmpDir("WriteWhileReading");
// Skip the test if we're not testing on HDFS. Hadoop's local filesystem
// implementation makes no guarantees that writable files are readable while
// being written.
// Note: rfind(prefix, 0) is a true prefix check; find_first_of would only
// match individual characters of the set.
if (path.rfind("hdfs://", 0) != 0) GTEST_SKIP();

auto writer = GetWriter();
tf_hadoop_filesystem::NewWritableFile(filesystem_, path.c_str(), writer.get(),
status_);
EXPECT_TF_OK(status_);

const std::string content1 = "content1";
tf_writable_file::Append(writer.get(), content1.c_str(), content1.size(),
status_);
EXPECT_TF_OK(status_);
tf_writable_file::Flush(writer.get(), status_);
EXPECT_TF_OK(status_);

auto reader = GetReader();
tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
reader.get(), status_);
EXPECT_TF_OK(status_);

std::string result;
result.resize(content1.size());
auto read = tf_random_access_file::Read(reader.get(), 0, content1.size(),
&result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(content1, result);

const std::string content2 = "content2";
tf_writable_file::Append(writer.get(), content2.c_str(), content2.size(),
status_);
EXPECT_TF_OK(status_);
tf_writable_file::Flush(writer.get(), status_);
EXPECT_TF_OK(status_);

result.resize(content2.size());
read = tf_random_access_file::Read(reader.get(), content1.size(),
content2.size(), &result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(content2, result);

tf_writable_file::Close(writer.get(), status_);
EXPECT_TF_OK(status_);
}

TEST_F(HadoopFileSystemTest, ReadWhileOverwriting) {
static char set_disable_var[] = "HDFS_DISABLE_READ_EOF_RETRIED=1";
putenv(set_disable_var);

const std::string path = TmpDir("ReadWhileOverwriting");
if (path.rfind("hdfs://", 0) != 0) GTEST_SKIP();

const string content1 = "content1";
WriteString(path, content1);
ASSERT_TF_OK(status_);

auto reader = GetReader();
tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
reader.get(), status_);
EXPECT_TF_OK(status_);

std::string result;
result.resize(content1.size());
auto read = tf_random_access_file::Read(reader.get(), 0, content1.size(),
&result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(content1, result);

tf_hadoop_filesystem::DeleteFile(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);

string content2 = "overwrite";
WriteString(path, content1 + content2);
ASSERT_TF_OK(status_);

result.resize(content2.size());
read = tf_random_access_file::Read(reader.get(), content1.size(),
content2.size(), &result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(0, result.size());

static char set_enable_var[] = "HDFS_DISABLE_READ_EOF_RETRIED=0";
putenv(set_enable_var);
}

TEST_F(HadoopFileSystemTest, HarSplit) {
const std::string har_path =
"har://hdfs-root/user/j.doe/my_archive.har/dir0/dir1/file.txt";
std::string scheme, namenode, path;
ParseHadoopPath(har_path, &scheme, &namenode, &path);
EXPECT_EQ("har", scheme);
EXPECT_EQ("hdfs-root", namenode);
EXPECT_EQ("/user/j.doe/my_archive.har/dir0/dir1/file.txt", path);
SplitArchiveNameAndPath(&path, &namenode, status_);
EXPECT_TF_OK(status_);
EXPECT_EQ("har://hdfs-root/user/j.doe/my_archive.har", namenode);
EXPECT_EQ("/dir0/dir1/file.txt", path);
}

TEST_F(HadoopFileSystemTest, NoHarExtension) {
const std::string har_path =
"har://hdfs-root/user/j.doe/my_archive/dir0/dir1/file.txt";
std::string scheme, namenode, path;
ParseHadoopPath(har_path, &scheme, &namenode, &path);
EXPECT_EQ("har", scheme);
EXPECT_EQ("hdfs-root", namenode);
EXPECT_EQ("/user/j.doe/my_archive/dir0/dir1/file.txt", path);
SplitArchiveNameAndPath(&path, &namenode, status_);
EXPECT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT) << TF_Message(status_);
}

TEST_F(HadoopFileSystemTest, HarRootPath) {
const std::string har_path = "har://hdfs-root/user/j.doe/my_archive.har";
std::string scheme, namenode, path;
ParseHadoopPath(har_path, &scheme, &namenode, &path);
EXPECT_EQ("har", scheme);
EXPECT_EQ("hdfs-root", namenode);
EXPECT_EQ("/user/j.doe/my_archive.har", path);
SplitArchiveNameAndPath(&path, &namenode, status_);
EXPECT_TF_OK(status_);
EXPECT_EQ("har://hdfs-root/user/j.doe/my_archive.har", namenode);
EXPECT_EQ("/", path);
}

TEST_F(HadoopFileSystemTest, WriteLargeFile) {
// getenv returns a C string (or nullptr), so compare the value, not the
// pointer.
const char* test_large_file = std::getenv("HADOOP_TEST_LARGE_FILE");
if (test_large_file == nullptr || std::string(test_large_file) != "1")
GTEST_SKIP();
const std::string path = TmpDir("WriteLargeFile");
const size_t file_size =
static_cast<size_t>(std::numeric_limits<tSize>::max()) + 1024;
// Fake a test string.
std::string source(file_size, {});
for (size_t i = 0; i < file_size; ++i) source[i] = (i % 128);
WriteString(path, source);
ASSERT_TF_OK(status_);
auto result = ReadAll(path);
EXPECT_TF_OK(status_);
EXPECT_EQ(source, result);
}
// NewAppendableFile() is not testable. Local filesystem maps to
// ChecksumFileSystem in Hadoop, where appending is an unsupported operation.

} // namespace
} // namespace tensorflow

GTEST_API_ int main(int argc, char** argv) {
tensorflow::testing::InstallStacktraceHandler();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
@ -1,3 +1,5 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")

# Experimental posix filesystem plugin.
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")

@ -1,3 +1,5 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")

# Experimental windows filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")

@ -1,3 +1,6 @@
load("//tensorflow:tensorflow.bzl", "filegroup")
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")

# Library of gradient functions.
package(
licenses = ["notice"], # Apache 2.0
@ -16,7 +19,7 @@ cc_library(
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients",
"//tensorflow/c/eager:gradients_internal",
"//tensorflow/core/lib/llvm_rtti",
],
)
@ -31,14 +34,11 @@ cc_library(
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients",
"//tensorflow/c/eager:gradients_internal",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/core/lib/llvm_rtti",
],
)

@ -52,13 +52,46 @@ cc_library(
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients",
"//tensorflow/c/eager:gradients_internal",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
"@com_google_absl//absl/types:span",
],
)

cc_library(
name = "gradients",
hdrs = [
"array_grad.h",
"math_grad.h",
"nn_grad.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":array_grad",
":math_grad",
":nn_grad",
"//tensorflow/c/eager:gradients_internal",
],
)

filegroup(
name = "pywrap_required_hdrs",
srcs = [
"array_grad.h",
"math_grad.h",
"nn_grad.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
@ -22,10 +22,10 @@ limitations under the License.

using std::vector;
using tensorflow::ops::Conj;
using tensorflow::ops::Identity;
using tensorflow::ops::MatMul;
using tensorflow::ops::Mul;
using tensorflow::ops::ZerosLike;
using tensorflow::ops::Neg;
using tensorflow::ops::SqrtGrad;

namespace tensorflow {
namespace gradients {
@ -36,21 +36,14 @@ class AddGradientFunction : public GradientFunction {
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2);
vector<AbstractTensorHandle*> identity_outputs(1);
// TODO(b/145674566): Handle name unification in tracing code.
// TODO(b/161805092): Support broadcasting.

std::string name = "Identity_A";
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(identity_outputs),
name.c_str()));
(*grad_outputs)[0] = identity_outputs[0];
DCHECK(grad_inputs[0]);
(*grad_outputs)[0] = grad_inputs[0];
(*grad_outputs)[1] = grad_inputs[0];

name = "Identity_B";
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(identity_outputs),
name.c_str()));
(*grad_outputs)[1] = identity_outputs[0];
(*grad_outputs)[0]->Ref();
(*grad_outputs)[1]->Ref();
return Status::OK();
}
~AddGradientFunction() override {}
@ -81,6 +74,25 @@ class ExpGradientFunction : public GradientFunction {
AbstractTensorHandlePtr exp_;
};

class SqrtGradientFunction : public GradientFunction {
public:
explicit SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) {
sqrt->Ref();
}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
std::string name = "Sqrt_Grad";
grad_outputs->resize(1);
TF_RETURN_IF_ERROR(SqrtGrad(ctx->ctx, {sqrt_.get(), grad_inputs[0]},
absl::MakeSpan(*grad_outputs), name.c_str()));
return Status::OK();
}
~SqrtGradientFunction() override {}

private:
AbstractTensorHandlePtr sqrt_;
};
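
For reference, the math behind capturing the forward output rather than the input:

// For y = sqrt(x), dy/dx = 1 / (2 * sqrt(x)) = 0.5 / y, so with upstream
// gradient U the backward value is U * 0.5 / y. Expressing the rule in terms
// of the forward output y is why sqrt_ (the output handle) is held above.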

class MatMulGradientFunction : public GradientFunction {
public:
explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
@ -190,6 +202,56 @@ class MatMulGradientFunction : public GradientFunction {
AttrBuilder forward_attrs;
};

class NegGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
/* Given upstream grad U and a Neg op Y = -X, the gradients are:
*
* dX = -U
*
*/

grad_outputs->resize(1);
std::string name = "Neg_Grad";
TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(*grad_outputs), name.c_str()));
return Status::OK();
}
~NegGradientFunction() override {}
};

class SubGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
/* Given upstream grad U and a Sub op A-B, the gradients are:
*
* dA = U
* dB = -U
*
*/

grad_outputs->resize(2);

// Grad for A
DCHECK(grad_inputs[0]);
(*grad_outputs)[0] = grad_inputs[0];
(*grad_outputs)[0]->Ref();

// Grad for B
// negate the upstream grad
std::vector<AbstractTensorHandle*> neg_outputs(1);
std::string name = "Neg_Sub_Grad_B";
TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(neg_outputs), name.c_str()));
(*grad_outputs)[1] = neg_outputs[0];

return Status::OK();
}
~SubGradientFunction() override {}
};

} // namespace

BackwardFunction* AddRegisterer(const ForwardOperation& op) {
@ -219,5 +281,32 @@ BackwardFunction* MatMulRegisterer(const ForwardOperation& op) {
return new BackwardFunction(gradient_function, default_gradients);
}

BackwardFunction* SqrtRegisterer(const ForwardOperation& op) {
auto gradient_function = new SqrtGradientFunction(op.outputs[0]);
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}

BackwardFunction* NegRegisterer(const ForwardOperation& op) {
auto gradient_function = new NegGradientFunction;
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}

BackwardFunction* SubRegisterer(const ForwardOperation& op) {
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto gradient_function = new SubGradientFunction;
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}

} // namespace gradients
} // namespace tensorflow
@ -19,9 +19,14 @@ limitations under the License.

namespace tensorflow {
namespace gradients {

BackwardFunction* AddRegisterer(const ForwardOperation& op);
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
BackwardFunction* NegRegisterer(const ForwardOperation& op);
BackwardFunction* SubRegisterer(const ForwardOperation& op);

} // namespace gradients
} // namespace tensorflow

@ -14,17 +14,19 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/nn_grad.h"

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"

using std::vector;
using tensorflow::ops::Conj;
using tensorflow::ops::Identity;
using tensorflow::ops::Mul;
using tensorflow::ops::ReluGrad;
using tensorflow::ops::SparseSoftmaxCrossEntropyLoss;
using tensorflow::ops::ZerosLike;

namespace tensorflow {
namespace gradients {
@ -58,9 +60,31 @@ class ReluGradientFunction : public GradientFunction {
vector<AbstractTensorHandle*> forward_outputs;
};

class SparseSoftmaxCrossEntropyLossGradientFunction : public GradientFunction {
Status BroadcastMul(AbstractContext* ctx, AbstractTensorHandle* vec,
AbstractTensorHandle* mat,
absl::Span<AbstractTensorHandle*> outputs) {
if (!isa<ImmediateExecutionContext>(ctx)) {
// TODO(b/168850692): Fix this.
return errors::Unimplemented(
"BroadcastMul is not supported in tracing mode yet.");
}
auto imm_ctx = dyn_cast<ImmediateExecutionContext>(ctx);
AbstractTensorPtr minus_1(imm_ctx->CreateInt32Scalar(-1));
ImmediateTensorHandlePtr dim(imm_ctx->CreateLocalHandle(minus_1.get()));
vector<AbstractTensorHandle*> expand_dims_outputs(1);
TF_RETURN_IF_ERROR(ops::ExpandDims(ctx, {vec, dim.get()},
absl::MakeSpan(expand_dims_outputs),
"ExpandDims"));
TF_RETURN_IF_ERROR(
ops::Mul(ctx, {expand_dims_outputs[0], mat}, outputs, "Mul"));
expand_dims_outputs[0]->Unref();
return Status::OK();
}
|
||||
|
||||
class SparseSoftmaxCrossEntropyWithLogitsGradientFunction
|
||||
: public GradientFunction {
|
||||
public:
|
||||
explicit SparseSoftmaxCrossEntropyLossGradientFunction(
|
||||
explicit SparseSoftmaxCrossEntropyWithLogitsGradientFunction(
|
||||
vector<AbstractTensorHandle*> f_outputs)
|
||||
: forward_outputs(f_outputs) {}
|
||||
|
||||
@ -69,12 +93,10 @@ class SparseSoftmaxCrossEntropyLossGradientFunction : public GradientFunction {
|
||||
grad_outputs->resize(2);
|
||||
|
||||
// Grad for Softmax Input
|
||||
std::string name = "Mul_Softmax_Grad";
|
||||
vector<AbstractTensorHandle*> mul_outputs(1);
|
||||
TF_RETURN_IF_ERROR(
|
||||
ops::Mul(ctx->ctx, {grad_inputs[0], forward_outputs[1]},
|
||||
absl::MakeSpan(mul_outputs),
|
||||
name.c_str())); // upstream_grad * local softmax grad
|
||||
TF_RETURN_IF_ERROR(BroadcastMul(
|
||||
ctx->ctx, grad_inputs[0], forward_outputs[1],
|
||||
absl::MakeSpan(mul_outputs))); // upstream_grad * local softmax grad
|
||||
(*grad_outputs)[0] = mul_outputs[0];
|
||||
|
||||
// Grad for labels is null
|
||||
@ -82,7 +104,7 @@ class SparseSoftmaxCrossEntropyLossGradientFunction : public GradientFunction {
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
~SparseSoftmaxCrossEntropyLossGradientFunction() override {}
|
||||
~SparseSoftmaxCrossEntropyWithLogitsGradientFunction() override {}
|
||||
|
||||
private:
|
||||
vector<AbstractTensorHandle*> forward_outputs;
|
||||
@ -99,10 +121,10 @@ BackwardFunction* ReluRegisterer(const ForwardOperation& op) {
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
BackwardFunction* SparseSoftmaxCrossEntropyLossRegisterer(
|
||||
BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
|
||||
const ForwardOperation& op) {
|
||||
auto gradient_function =
|
||||
new SparseSoftmaxCrossEntropyLossGradientFunction(op.outputs);
|
||||
new SparseSoftmaxCrossEntropyWithLogitsGradientFunction(op.outputs);
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
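Reviewer note: the new BroadcastMul helper exists so the incoming loss gradient (shape [batch_size]) can be multiplied against the local softmax gradient (shape [batch_size, num_classes]). A minimal shape walk-through, with concrete sizes chosen purely for illustration:

// Illustrative shapes only; the sizes are assumptions, not part of this diff.
// vec : [2]      incoming gradient, one scalar per example
// dim : -1       axis handle created via CreateInt32Scalar(-1)
// ExpandDims(vec, dim) -> [2, 1]
// mat : [2, 3]   local softmax gradient
// Mul([2, 1], [2, 3]) broadcasts to [2, 3], the gradient w.r.t. the logits
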
@ -20,7 +20,7 @@ limitations under the License.
namespace tensorflow {
namespace gradients {
BackwardFunction* ReluRegisterer(const ForwardOperation& op);
BackwardFunction* SparseSoftmaxCrossEntropyLossRegisterer(
BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
    const ForwardOperation& op);
}  // namespace gradients
}  // namespace tensorflow

66
tensorflow/c/experimental/gradients/tape/BUILD
Normal file
@ -0,0 +1,66 @@
# A tape built on top of unified execution APIs.
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")

package(
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "tape_context",
    srcs = ["tape_context.cc"],
    hdrs = [
        "tape_context.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        ":tape_operation",
        "//tensorflow/c/eager:abstract_context",
        "//tensorflow/c/eager:abstract_function",
        "//tensorflow/c/eager:abstract_operation",
    ],
)

cc_library(
    name = "tape_operation",
    srcs = ["tape_operation.cc"],
    hdrs = [
        "tape_operation.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        "//tensorflow/c/eager:abstract_context",
        "//tensorflow/c/eager:abstract_function",
        "//tensorflow/c/eager:abstract_operation",
        "//tensorflow/c/eager:gradients_internal",
    ],
)

cc_library(
    name = "tape",
    hdrs = [
        "tape_context.h",
        "tape_operation.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        ":tape_context",
        ":tape_operation",
    ],
)

filegroup(
    name = "pywrap_required_hdrs",
    srcs = [
        "tape_context.h",
        "tape_operation.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
)

47
tensorflow/c/experimental/gradients/tape/tape_context.cc
Normal file
@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"

#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"

namespace tensorflow {
namespace gradients {
TapeContext::TapeContext(AbstractContext* c, Tape* tape,
                         const GradientRegistry& registry)
    : AbstractContext(kTape), parent_ctx_(c), tape_(tape), registry_(registry) {
  // TODO(srbs): Make AbstractContext ref counted.
  // parent_ctx_->Ref();
}
void TapeContext::Release() {
  // TODO(srbs): Change to Unref()
  delete this;
}
TapeContext::~TapeContext() {
  // TODO(srbs): Make AbstractContext ref counted.
  // parent_ctx_->Unref();
}
TapeOperation* TapeContext::CreateOperation() {
  return new TapeOperation(parent_ctx_->CreateOperation(), tape_, registry_);
}
Status TapeContext::RegisterFunction(AbstractFunction* f) {
  return parent_ctx_->RegisterFunction(f);
}
Status TapeContext::RemoveFunction(const string& func) {
  return parent_ctx_->RemoveFunction(func);
}

}  // namespace gradients
}  // namespace tensorflow

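Reviewer note: a minimal usage sketch for the new wrapper, assuming `ctx`, `tape`, and `registry` already exist and that AbstractContextPtr from abstract_context.h is the owning smart pointer (setup elided); the point is that ops created through the TapeContext are recorded transparently:

// Hedged sketch: `ctx`, `tape`, and `registry` are assumed to be valid.
// Every operation created via tape_ctx is a TapeOperation, so executing it
// records the forward op on the tape as a side effect.
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
AbstractOperationPtr op(tape_ctx->CreateOperation());
// From here on, op->Reset(...), op->AddInput(...), op->Execute(...) behave
// like the wrapped operation, plus tape recording.
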
44
tensorflow/c/experimental/gradients/tape/tape_context.h
Normal file
@ -0,0 +1,44 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_

#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"

namespace tensorflow {
namespace gradients {
class TapeContext : public AbstractContext {
 public:
  explicit TapeContext(AbstractContext*, Tape*, const GradientRegistry&);
  void Release() override;
  TapeOperation* CreateOperation() override;
  Status RegisterFunction(AbstractFunction*) override;
  Status RemoveFunction(const string& func) override;
  // For LLVM style RTTI.
  static bool classof(const AbstractContext* ptr) {
    return ptr->getKind() == kTape;
  }
  ~TapeContext() override;

 private:
  AbstractContext* parent_ctx_;  // Not owned.
  Tape* tape_;
  const GradientRegistry& registry_;
};
}  // namespace gradients
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_
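Reviewer note: the classof hook is what makes LLVM-style RTTI work on this wrapper without C++ RTTI; a small hypothetical sketch of the pattern (the helper function is an illustration, not part of the diff):

// Hypothetical check: isa<>/dyn_cast<> come from
// tensorflow/core/lib/llvm_rtti/llvm_rtti.h and dispatch on getKind()
// via the classof() hook declared above.
void MaybeUseTapeContext(AbstractContext* ctx) {
  if (TapeContext* tape_ctx = dyn_cast<TapeContext>(ctx)) {
    // tape_ctx is safe to use as a TapeContext here; nullptr otherwise.
    (void)tape_ctx;
  }
}
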
238
tensorflow/c/experimental/gradients/tape/tape_operation.cc
Normal file
@ -0,0 +1,238 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"

#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/gradients.h"

namespace tensorflow {
namespace gradients {
TapeOperation::TapeOperation(AbstractOperation* parent_op, Tape* tape,
                             const GradientRegistry& registry)
    : AbstractOperation(kTape),
      parent_op_(parent_op),
      tape_(tape),
      registry_(registry) {
  // TODO(srbs): Make AbstractOperation RefCounted.
  // parent_op_->Ref();
}
void TapeOperation::Release() {
  // TODO(srbs): Change to Unref().
  delete this;
}
TapeOperation::~TapeOperation() {
  // TODO(srbs): Make AbstractOperation RefCounted.
  // parent_op->Unref();
}
Status TapeOperation::Reset(const char* op, const char* raw_device_name) {
  forward_op_.op_name = op;
  forward_op_.attrs.Reset(op);
  forward_op_.inputs.clear();
  forward_op_.outputs.clear();
  return parent_op_->Reset(op, raw_device_name);
}
const string& TapeOperation::Name() const { return parent_op_->Name(); }
const string& TapeOperation::DeviceName() const {
  return parent_op_->DeviceName();
}
Status TapeOperation::SetDeviceName(const char* name) {
  return parent_op_->SetDeviceName(name);
}
Status TapeOperation::AddInput(AbstractTensorHandle* input) {
  TF_RETURN_IF_ERROR(parent_op_->AddInput(input));
  forward_op_.inputs.push_back(input);
  return Status::OK();
}
Status TapeOperation::AddInputList(
    absl::Span<AbstractTensorHandle* const> inputs) {
  TF_RETURN_IF_ERROR(parent_op_->AddInputList(inputs));
  for (auto input : inputs) {
    forward_op_.inputs.push_back(input);
  }
  return Status::OK();
}
Status TapeOperation::SetAttrString(const char* attr_name, const char* data,
                                    size_t length) {
  forward_op_.attrs.Set(attr_name, StringPiece(data, length));
  return parent_op_->SetAttrString(attr_name, data, length);
}
Status TapeOperation::SetAttrInt(const char* attr_name, int64_t value) {
  forward_op_.attrs.Set(attr_name, static_cast<int64>(value));
  return parent_op_->SetAttrInt(attr_name, value);
}
Status TapeOperation::SetAttrFloat(const char* attr_name, float value) {
  forward_op_.attrs.Set(attr_name, value);
  return parent_op_->SetAttrFloat(attr_name, value);
}
Status TapeOperation::SetAttrBool(const char* attr_name, bool value) {
  forward_op_.attrs.Set(attr_name, value);
  return parent_op_->SetAttrBool(attr_name, value);
}
Status TapeOperation::SetAttrType(const char* attr_name, DataType value) {
  forward_op_.attrs.Set(attr_name, value);
  return parent_op_->SetAttrType(attr_name, value);
}
Status TapeOperation::SetAttrShape(const char* attr_name, const int64_t* dims,
                                   const int num_dims) {
  if (num_dims > TensorShape::MaxDimensions()) {
    return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
                                   num_dims,
                                   " dimensions which is over the limit of ",
                                   TensorShape::MaxDimensions(), ".");
  }
  TensorShapeProto proto;
  if (num_dims < 0) {
    proto.set_unknown_rank(true);
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }
  }

  forward_op_.attrs.Set(attr_name, proto);
  return parent_op_->SetAttrShape(attr_name, dims, num_dims);
}
Status TapeOperation::SetAttrFunction(const char* attr_name,
                                      const AbstractOperation* value) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunction has not been implemented yet.");
}
Status TapeOperation::SetAttrFunctionName(const char* attr_name,
                                          const char* value, size_t length) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunctionName has not been implemented "
      "yet.");
}
Status TapeOperation::SetAttrTensor(const char* attr_name,
                                    AbstractTensorInterface* tensor) {
  return tensorflow::errors::Unimplemented(
      "SetAttrTensor has not been implemented yet.");
}
Status TapeOperation::SetAttrStringList(const char* attr_name,
                                        const void* const* values,
                                        const size_t* lengths, int num_values) {
  std::vector<StringPiece> v(num_values);
  for (int i = 0; i < num_values; ++i) {
    v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
  }
  forward_op_.attrs.Set(attr_name, v);
  return parent_op_->SetAttrStringList(attr_name, values, lengths, num_values);
}
Status TapeOperation::SetAttrFloatList(const char* attr_name,
                                       const float* values, int num_values) {
  forward_op_.attrs.Set(attr_name,
                        gtl::ArraySlice<const float>(values, num_values));
  return parent_op_->SetAttrFloatList(attr_name, values, num_values);
}
Status TapeOperation::SetAttrIntList(const char* attr_name,
                                     const int64_t* values, int num_values) {
  forward_op_.attrs.Set(
      attr_name, gtl::ArraySlice<const int64>(
                     reinterpret_cast<const int64*>(values), num_values));
  return parent_op_->SetAttrIntList(attr_name, values, num_values);
}
Status TapeOperation::SetAttrTypeList(const char* attr_name,
                                      const DataType* values, int num_values) {
  forward_op_.attrs.Set(attr_name,
                        gtl::ArraySlice<const DataType>(values, num_values));
  return parent_op_->SetAttrTypeList(attr_name, values, num_values);
}
Status TapeOperation::SetAttrBoolList(const char* attr_name,
                                      const unsigned char* values,
                                      int num_values) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  forward_op_.attrs.Set(attr_name,
                        gtl::ArraySlice<const bool>(b.get(), num_values));
  return parent_op_->SetAttrBoolList(attr_name, values, num_values);
}
Status TapeOperation::SetAttrShapeList(const char* attr_name,
                                       const int64_t** dims,
                                       const int* num_dims, int num_values) {
  std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > TensorShape::MaxDimensions()) {
      return errors::InvalidArgument(
          strings::StrCat("Value specified for `", attr_name, "` has ",
                          num_dims_i, " dimensions which is over the limit of ",
                          TensorShape::MaxDimensions(), "."));
    }
    if (num_dims_i < 0) {
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  forward_op_.attrs.Set(
      attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
  return parent_op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
Status TapeOperation::SetAttrFunctionList(
    const char* attr_name, absl::Span<const AbstractOperation*> values) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunctionList has not been "
      "implemented yet.");
}
AbstractOperation* TapeOperation::GetBackingOperation() { return parent_op_; }
Status TapeOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
                              int* num_retvals) {
  TF_RETURN_IF_ERROR(parent_op_->Execute(retvals, num_retvals));
  std::vector<int64> input_ids(forward_op_.inputs.size());
  std::vector<tensorflow::DataType> input_dtypes(forward_op_.inputs.size());
  for (int i = 0; i < forward_op_.inputs.size(); i++) {
    input_ids[i] = ToId(forward_op_.inputs[i]);
    input_dtypes[i] = forward_op_.inputs[i]->DataType();
  }
  for (int i = 0; i < *num_retvals; i++) {
    // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
    forward_op_.outputs.push_back(retvals[i]);
  }
  // TODO(b/166669239): This is needed to support AttrBuilder::Get for string
  // attributes. Number type attrs and DataType attrs work fine without this.
  // Consider getting rid of this and making the behavior between number types
  // and string consistent.
  forward_op_.attrs.BuildNodeDef();
  std::vector<TapeTensor> tape_tensors;
  for (auto t : retvals) {
    tape_tensors.push_back(TapeTensor(t));
  }
  tape_->RecordOperation(
      parent_op_->Name(), tape_tensors, input_ids, input_dtypes,
      [this]() -> BackwardFunction* {
        std::unique_ptr<BackwardFunction> backward_fn;
        Status s = registry_.Lookup(forward_op_, &backward_fn);
        if (!s.ok()) {
          return nullptr;
        }
        return backward_fn.release();
      },
      [](BackwardFunction* ptr) {
        if (ptr) {
          delete ptr;
        }
      });
  return Status::OK();
}

}  // namespace gradients
}  // namespace tensorflow
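Reviewer note: Execute() is where recording happens; after running the wrapped op it hands the tape everything needed to replay the backward pass lazily. A condensed view of the flow, in comment form (informal summary, not new API):

// Condensed flow of TapeOperation::Execute:
//   1. parent_op_->Execute(...)          run the real forward op
//   2. collect input ids/dtypes          keys for gradient bookkeeping
//   3. forward_op_.attrs.BuildNodeDef()  workaround for string attr reads
//   4. tape_->RecordOperation(...)       defers registry_.Lookup() until a
//                                        gradient is actually requested, so
//                                        ops that never need gradients never
//                                        build a BackwardFunction
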
80
tensorflow/c/experimental/gradients/tape/tape_operation.h
Normal file
@ -0,0 +1,80 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_

#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/gradients.h"

namespace tensorflow {
namespace gradients {
class TapeOperation : public AbstractOperation {
 public:
  explicit TapeOperation(AbstractOperation*, Tape*, const GradientRegistry&);
  void Release() override;
  Status Reset(const char* op, const char* raw_device_name) override;
  const string& Name() const override;
  const string& DeviceName() const override;
  Status SetDeviceName(const char* name) override;
  Status AddInput(AbstractTensorHandle* input) override;
  Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
  Status Execute(absl::Span<AbstractTensorHandle*> retvals,
                 int* num_retvals) override;
  Status SetAttrString(const char* attr_name, const char* data,
                       size_t length) override;
  Status SetAttrInt(const char* attr_name, int64_t value) override;
  Status SetAttrFloat(const char* attr_name, float value) override;
  Status SetAttrBool(const char* attr_name, bool value) override;
  Status SetAttrType(const char* attr_name, DataType value) override;
  Status SetAttrShape(const char* attr_name, const int64_t* dims,
                      const int num_dims) override;
  Status SetAttrFunction(const char* attr_name,
                         const AbstractOperation* value) override;
  Status SetAttrFunctionName(const char* attr_name, const char* value,
                             size_t length) override;
  Status SetAttrTensor(const char* attr_name,
                       AbstractTensorInterface* tensor) override;
  Status SetAttrStringList(const char* attr_name, const void* const* values,
                           const size_t* lengths, int num_values) override;
  Status SetAttrFloatList(const char* attr_name, const float* values,
                          int num_values) override;
  Status SetAttrIntList(const char* attr_name, const int64_t* values,
                        int num_values) override;
  Status SetAttrTypeList(const char* attr_name, const DataType* values,
                         int num_values) override;
  Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
                         int num_values) override;
  Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
                          const int* num_dims, int num_values) override;
  Status SetAttrFunctionList(
      const char* attr_name,
      absl::Span<const AbstractOperation*> values) override;
  AbstractOperation* GetBackingOperation();
  // For LLVM style RTTI.
  static bool classof(const AbstractOperation* ptr) {
    return ptr->getKind() == kTape;
  }
  ~TapeOperation() override;

 private:
  AbstractOperation* parent_op_;
  ForwardOperation forward_op_;
  Tape* tape_;
  const GradientRegistry& registry_;
};

}  // namespace gradients
}  // namespace tensorflow
#endif  // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_
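Reviewer note: GetBackingOperation() gives callers an escape hatch to the wrapped op; a small sketch of how it composes with the RTTI hook above (the helper name is hypothetical, not part of the diff):

// Hypothetical helper: peel off the tape wrapper to reach the real op,
// e.g. when device-specific introspection is needed and recording is not.
AbstractOperation* Unwrap(AbstractOperation* op) {
  if (isa<TapeOperation>(op)) {
    return dyn_cast<TapeOperation>(op)->GetBackingOperation();
  }
  return op;
}
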
@ -1,3 +1,6 @@
load("//tensorflow:tensorflow.bzl", "filegroup")
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")

# Experimental ops. These will eventually be replaced by machine-generated versions.
package(
    licenses = ["notice"],  # Apache 2.0
@ -19,7 +22,7 @@ cc_library(
        "//tensorflow/c/eager:abstract_operation",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:c_api_unified_internal",
        "//tensorflow/core/lib/llvm_rtti",
        "//tensorflow/c/eager:tracing_utils",
        "//tensorflow/core/platform:errors",
    ],
)
@ -40,8 +43,8 @@ cc_library(
        "//tensorflow/c/eager:abstract_context",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:c_api_unified_internal",
        "//tensorflow/c/eager:tracing_utils",
        "//tensorflow/core:framework",
        "//tensorflow/core/lib/llvm_rtti",
        "//tensorflow/core/platform:errors",
    ],
)
@ -61,7 +64,41 @@ cc_library(
        "//tensorflow/c/eager:abstract_operation",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:c_api_unified_internal",
        "//tensorflow/core/lib/llvm_rtti",
        "//tensorflow/c/eager:tracing_utils",
        "//tensorflow/core/platform:errors",
    ],
)

cc_library(
    name = "ops",
    hdrs = [
        "array_ops.h",
        "math_ops.h",
        "nn_ops.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        ":array_ops",
        ":math_ops",
        ":nn_ops",
        "//tensorflow/c/eager:abstract_context",
        "//tensorflow/c/eager:abstract_operation",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:c_api_unified_internal",
    ],
)

filegroup(
    name = "pywrap_required_hdrs",
    srcs = [
        "array_ops.h",
        "math_ops.h",
        "nn_ops.h",
    ],
    visibility = [
        "//tensorflow/core:__pkg__",
        "//tensorflow/python:__pkg__",
    ],
)

@ -14,9 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/ops/array_ops.h"

#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/tracing_utils.h"
#include "tensorflow/core/platform/errors.h"

using tensorflow::tracing::MaybeSetOpName;

namespace tensorflow {
namespace ops {

@ -26,28 +28,58 @@ Status Identity(AbstractContext* ctx,
  AbstractOperationPtr identity_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(
      identity_op->Reset("Identity", /*raw_device_name=*/nullptr));
  if (isa<tensorflow::tracing::TracingOperation>(identity_op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(identity_op.get())
                           ->SetOpName(name));
  }
  TF_RETURN_IF_ERROR(MaybeSetOpName(identity_op.get(), name));
  TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0]));
  int num_retvals = 1;
  return identity_op->Execute(outputs, &num_retvals);
}

Status IdentityN(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> inputs,
                 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr identity_n_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(
      identity_n_op->Reset("IdentityN", /*raw_device_name=*/nullptr));
  TF_RETURN_IF_ERROR(MaybeSetOpName(identity_n_op.get(), name));
  TF_RETURN_IF_ERROR(identity_n_op->AddInputList(inputs));
  int num_retvals = inputs.size();
  return identity_n_op->Execute(outputs, &num_retvals);
}

Status ZerosLike(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> inputs,
                 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr z_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(z_op->Reset("ZerosLike", /*raw_device_name=*/nullptr));
  if (isa<tensorflow::tracing::TracingOperation>(z_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<tracing::TracingOperation>(z_op.get())->SetOpName(name));
  }
  TF_RETURN_IF_ERROR(MaybeSetOpName(z_op.get(), name));
  TF_RETURN_IF_ERROR(z_op->AddInput(inputs[0]));
  int num_retvals = 1;
  return z_op->Execute(outputs, &num_retvals);
}

Status Shape(AbstractContext* ctx,
             absl::Span<AbstractTensorHandle* const> inputs,
             absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr shape_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(shape_op->Reset("Shape", /*raw_device_name=*/nullptr));
  TF_RETURN_IF_ERROR(MaybeSetOpName(shape_op.get(), name));
  TF_RETURN_IF_ERROR(shape_op->AddInput(inputs[0]));  // input
  int num_retvals = 1;
  TF_RETURN_IF_ERROR(shape_op->Execute(outputs, &num_retvals));
  return Status::OK();
}

Status ExpandDims(AbstractContext* ctx,
                  absl::Span<AbstractTensorHandle* const> inputs,
                  absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(op->Reset("ExpandDims", /*raw_device_name=*/nullptr));
  TF_RETURN_IF_ERROR(MaybeSetOpName(op.get(), name));
  TF_RETURN_IF_ERROR(op->AddInput(inputs[0]));
  TF_RETURN_IF_ERROR(op->AddInput(inputs[1]));
  int num_retvals = 1;
  return op->Execute(outputs, &num_retvals);
}

}  // namespace ops
}  // namespace tensorflow

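Reviewer note: the new op wrappers all follow the same call shape; a minimal caller fragment for ExpandDims, assuming an AbstractContext* ctx and handles x and axis already exist, and that the fragment sits inside a function returning Status (this mirrors the BroadcastMul call site in nn_grad.cc):

// Hedged sketch: `ctx`, `x`, and `axis` are assumed to be valid handles.
std::vector<AbstractTensorHandle*> out(1);
TF_RETURN_IF_ERROR(
    ops::ExpandDims(ctx, {x, axis}, absl::MakeSpan(out), "ExpandDims"));
// out[0] now holds x with a length-1 dimension inserted at `axis`.
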
@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"

namespace tensorflow {
namespace ops {
@ -27,10 +26,22 @@ Status Identity(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status IdentityN(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> inputs,
                 absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status ZerosLike(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> inputs,
                 absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status Shape(AbstractContext* ctx,
             absl::Span<AbstractTensorHandle* const> inputs,
             absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status ExpandDims(AbstractContext* ctx,
                  absl::Span<AbstractTensorHandle* const> inputs,
                  absl::Span<AbstractTensorHandle*> outputs, const char* name);

}  // namespace ops
}  // namespace tensorflow

@ -16,22 +16,21 @@ limitations under the License.

#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/tracing_utils.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/errors.h"

using tensorflow::tracing::MaybeSetOpName;

namespace tensorflow {
namespace ops {
using tensorflow::tracing::TracingOperation;

Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr mul_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(mul_op->Reset("Mul", /*raw_device_name=*/nullptr));
  if (isa<TracingOperation>(mul_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
  }
  TF_RETURN_IF_ERROR(MaybeSetOpName(mul_op.get(), name));
  TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[0]));
  TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[1]));
  int num_retvals = 1;
@ -55,12 +54,7 @@ Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr add_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(add_op->Reset("AddV2", /*raw_device_name=*/nullptr));

  if (isa<tracing::TracingOperation>(add_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<tracing::TracingOperation>(add_op.get())->SetOpName(name));
  }

  TF_RETURN_IF_ERROR(MaybeSetOpName(add_op.get(), name));
  TF_RETURN_IF_ERROR(add_op->AddInput(inputs[0]));
  TF_RETURN_IF_ERROR(add_op->AddInput(inputs[1]));

@ -69,18 +63,26 @@ Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
  return Status::OK();
}

Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr sub_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(sub_op->Reset("Sub", /*raw_device_name=*/nullptr));
  TF_RETURN_IF_ERROR(MaybeSetOpName(sub_op.get(), name));
  TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[0]));
  TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[1]));

  int num_retvals = 1;
  TF_RETURN_IF_ERROR(sub_op->Execute(outputs, &num_retvals));
  return Status::OK();
}

Status MatMul(AbstractContext* ctx,
              absl::Span<AbstractTensorHandle* const> inputs,
              absl::Span<AbstractTensorHandle*> outputs, const char* name,
              bool transpose_a = false, bool transpose_b = false) {
  AbstractOperationPtr matmul_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(matmul_op->Reset("MatMul", /*raw_device_name=*/nullptr));

  if (isa<tracing::TracingOperation>(matmul_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<tracing::TracingOperation>(matmul_op.get())->SetOpName(name));
  }

  TF_RETURN_IF_ERROR(MaybeSetOpName(matmul_op.get(), name));
  TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[0]));
  TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[1]));

@ -96,15 +98,79 @@ Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr neg_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(neg_op->Reset("Neg", /*raw_device_name=*/nullptr));
  if (isa<TracingOperation>(neg_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<TracingOperation>(neg_op.get())->SetOpName(name));
  }
  TF_RETURN_IF_ERROR(MaybeSetOpName(neg_op.get(), name));
  TF_RETURN_IF_ERROR(neg_op->AddInput(inputs[0]));

  int num_retvals = 1;
  return neg_op->Execute(outputs, &num_retvals);
}

Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr sum_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(sum_op->Reset("Sum", /*raw_device_name=*/nullptr));
  TF_RETURN_IF_ERROR(MaybeSetOpName(sum_op.get(), name));
  TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[0]));  // input_vals
  TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[1]));  // reduction_indices

  int num_retvals = 1;
  TF_RETURN_IF_ERROR(sum_op->Execute(outputs, &num_retvals));
  return Status::OK();
}

Status DivNoNan(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr div_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(div_op->Reset("DivNoNan", /*raw_device_name=*/nullptr));
  TF_RETURN_IF_ERROR(MaybeSetOpName(div_op.get(), name));
  TF_RETURN_IF_ERROR(div_op->AddInput(inputs[0]));  // x
  TF_RETURN_IF_ERROR(div_op->AddInput(inputs[1]));  // y

  int num_retvals = 1;
  TF_RETURN_IF_ERROR(div_op->Execute(
      outputs, &num_retvals));  // z = x / y, (z_i = 0 if y_i = 0)
  return Status::OK();
}

Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr exp_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(exp_op->Reset("Exp", /*raw_device_name=*/nullptr));
  TF_RETURN_IF_ERROR(MaybeSetOpName(exp_op.get(), name));
  TF_RETURN_IF_ERROR(exp_op->AddInput(inputs[0]));

  int num_retvals = 1;
  return exp_op->Execute(outputs, &num_retvals);
}

Status Sqrt(AbstractContext* ctx,
            absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr sqrt_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(sqrt_op->Reset("Sqrt", /*raw_device_name=*/nullptr));
  TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_op.get(), name));
  TF_RETURN_IF_ERROR(sqrt_op->AddInput(inputs[0]));

  int num_retvals = 1;
  Status s = sqrt_op->Execute(outputs, &num_retvals);
  return s;
}

Status SqrtGrad(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr sqrt_grad_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(
      sqrt_grad_op->Reset("SqrtGrad", /*raw_device_name=*/nullptr));
  TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_grad_op.get(), name));
  TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[0]));
  TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[1]));

  int num_retvals = 1;
  Status s = sqrt_grad_op->Execute(outputs, &num_retvals);
  return s;
}

}  // namespace ops
}  // namespace tensorflow

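Reviewer note: SqrtGrad takes the forward output y = sqrt(x) plus the upstream gradient dy, which is why SqrtRegisterer above only needs to capture op.outputs[0]. A comment-form derivation for reference (standard calculus, not taken from this diff):

// d/dx sqrt(x) = 1 / (2 * sqrt(x)) = 0.5 / y, so
// SqrtGrad(y, dy) computes dy * 0.5 / y.
// The forward output y alone is therefore sufficient to evaluate the
// backward pass; the original input x never needs to be kept alive.
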
@ -22,18 +22,43 @@ namespace tensorflow {
namespace ops {
Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status Conj(AbstractContext* ctx,
            absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status MatMul(AbstractContext* ctx,
              absl::Span<AbstractTensorHandle* const> inputs,
              absl::Span<AbstractTensorHandle*> outputs, const char* name,
              bool transpose_a, bool transpose_b);

Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status DivNoNan(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status Sqrt(AbstractContext* ctx,
            absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status SqrtGrad(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, const char* name);

}  // namespace ops
}  // namespace tensorflow

@ -15,24 +15,22 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/ops/nn_ops.h"

#include "tensorflow/c/eager/tracing_utils.h"
#include "tensorflow/core/platform/errors.h"

using tensorflow::tracing::MaybeSetOpName;

namespace tensorflow {
namespace ops {

// Softmax Loss given scores and labels, used by the SoftMaxLossGradient
Status SparseSoftmaxCrossEntropyLoss(
Status SparseSoftmaxCrossEntropyWithLogits(
    AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
    absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr sm_loss_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(sm_loss_op->Reset("SparseSoftmaxCrossEntropyWithLogits",
                                       /*raw_device_name=*/nullptr));

  if (isa<tracing::TracingOperation>(sm_loss_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<tracing::TracingOperation>(sm_loss_op.get())->SetOpName(name));
  }

  TF_RETURN_IF_ERROR(MaybeSetOpName(sm_loss_op.get(), name));
  TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[0]));  // input scores
  TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[1]));  // labels

@ -49,12 +47,7 @@ Status ReluGrad(AbstractContext* ctx,
  AbstractOperationPtr relugrad_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(
      relugrad_op->Reset("ReluGrad", /*raw_device_name=*/nullptr));

  if (isa<tracing::TracingOperation>(relugrad_op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(relugrad_op.get())
                           ->SetOpName(name));
  }

  TF_RETURN_IF_ERROR(MaybeSetOpName(relugrad_op.get(), name));
  TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[0]));  // upstream grads
  TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[1]));  // relu inputs

@ -63,5 +56,18 @@ Status ReluGrad(AbstractContext* ctx,
  return Status::OK();
}

Status Relu(AbstractContext* ctx,
            absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr relu_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(relu_op->Reset("Relu", /*raw_device_name=*/nullptr));
  TF_RETURN_IF_ERROR(MaybeSetOpName(relu_op.get(), name));
  TF_RETURN_IF_ERROR(relu_op->AddInput(inputs[0]));

  int num_retvals = 1;
  TF_RETURN_IF_ERROR(relu_op->Execute(outputs, &num_retvals));
  return Status::OK();
}

}  // namespace ops
}  // namespace tensorflow

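Reviewer note: the Relu/ReluGrad pair composes the same way as the other wrappers; a comment-form example of the backward contract, with values chosen purely for illustration:

// ReluGrad(upstream grads, relu inputs): each upstream gradient passes
// through where the corresponding forward input was positive, zero elsewhere.
//   relu inputs = [-1, 2, 0, 3]
//   upstream    = [10, 10, 10, 10]
//   result      = [ 0, 10, 0, 10]
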
@ -18,12 +18,11 @@ limitations under the License.
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"

namespace tensorflow {
namespace ops {

Status SparseSoftmaxCrossEntropyLoss(
Status SparseSoftmaxCrossEntropyWithLogits(
    AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
    absl::Span<AbstractTensorHandle*> outputs, const char* name);

@ -31,6 +30,10 @@ Status ReluGrad(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status Relu(AbstractContext* ctx,
            absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name);

}  // namespace ops
}  // namespace tensorflow

@ -1,3 +1,5 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")

# Experimental SavedModel C APIs for TensorFlow. See RFC
# https://github.com/tensorflow/community/pull/207
# Targets in this directory are pure C++ "Classes" underlying the C API types
@ -62,13 +64,21 @@ cc_library(
        ":function_metadata",
        "//tensorflow/c:tf_tensor_internal",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/experimental/saved_model/core/revived_types:asset",
        "//tensorflow/c/experimental/saved_model/core/revived_types:constant",
        "//tensorflow/c/experimental/saved_model/core/revived_types:partially_revived_objects",
        "//tensorflow/c/experimental/saved_model/core/revived_types:restored_resource_revival_state",
        "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
        "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function_revival_state",
        "//tensorflow/c/experimental/saved_model/core/revived_types:tf_signature_def_function_revival_state",
        "//tensorflow/c/experimental/saved_model/core/revived_types:variable",
        "//tensorflow/cc/saved_model:loader_util",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
    ],
)

@ -81,15 +91,24 @@ cc_library(
        ":signature_def_function_metadata",
        "//tensorflow/c/eager:immediate_execution_operation",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "signature_def_function_metadata",
    srcs = [
        "signature_def_function_metadata.cc",
    ],
    hdrs = [
        "signature_def_function_metadata.h",
    ],
    deps = [
        ":tensor_spec",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)

cc_library(
@ -138,11 +157,13 @@ cc_library(
        ":saved_model_api",
        ":saved_model_utils",
        ":signature_def_function",
        "//tensorflow/c:tensor_interface",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/c/experimental/saved_model/core/ops:restore_ops",
        "//tensorflow/c/experimental/saved_model/core/revived_types:constant",
        "//tensorflow/c/experimental/saved_model/core/revived_types:flat_tensor_function",
        "//tensorflow/c/experimental/saved_model/core/revived_types:partially_revived_objects",
        "//tensorflow/c/experimental/saved_model/core/revived_types:revived_objects",
        "//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible",
        "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
        "//tensorflow/c/experimental/saved_model/core/revived_types:variable",
@ -151,7 +172,6 @@ cc_library(
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/common_runtime/eager:tensor_handle",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
@ -213,6 +233,7 @@ tf_cc_test(
        "//tensorflow/core/common_runtime/eager:context",
        "//tensorflow/core/common_runtime/eager:core",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
)

@ -256,6 +277,20 @@ tf_cc_test(
    ],
)

cc_library(
    name = "tensor_spec",
    srcs = [
        "tensor_spec.cc",
    ],
    hdrs = [
        "tensor_spec.h",
    ],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:protos_all_cc",
    ],
)

tf_cc_test(
    name = "tf_concrete_function_loading_test",
    srcs = [

@ -43,8 +43,8 @@ class ConcreteFunction {
  virtual ~ConcreteFunction() = default;

  // This method returns the "Call" Op used to execute the function.
  virtual Status GetCallOp(absl::Span<AbstractTensorHandle* const> inputs,
                           ImmediateOpPtr* out) = 0;
  virtual Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
                            ImmediateOpPtr* out) const = 0;

  virtual const FunctionMetadata& GetFunctionMetadata() const = 0;
};

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/

#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stringpiece.h"
@ -300,80 +301,70 @@ nodes {

TEST(ObjectGraphTraversalTest, Success) {
  SavedObjectGraph object_graph = ParseSavedObjectGraph(kSingleChildFoo);
  const SavedObject* obj = internal::FindNodeAtPath("foo", object_graph);
  ASSERT_NE(nullptr, obj);
  EXPECT_EQ(obj->kind_case(), SavedObject::kUserObject);
  EXPECT_EQ(obj->user_object().identifier(), "_generic_user_object");
  absl::optional<int> node = internal::FindNodeAtPath("foo", object_graph);
  ASSERT_TRUE(node.has_value());
  EXPECT_EQ(*node, 1);
}

TEST(ObjectGraphTraversalTest, ObjectNotFound) {
  SavedObjectGraph object_graph = ParseSavedObjectGraph(kSingleChildFoo);
  const SavedObject* obj = internal::FindNodeAtPath("bar", object_graph);
  EXPECT_EQ(nullptr, obj);
  absl::optional<int> node = internal::FindNodeAtPath("bar", object_graph);
  EXPECT_FALSE(node.has_value());
}

TEST(ObjectGraphTraversalTest, CaseSensitiveMismatch) {
  SavedObjectGraph object_graph = ParseSavedObjectGraph(kSingleChildFoo);
  const SavedObject* obj = internal::FindNodeAtPath("FOO", object_graph);
  EXPECT_EQ(nullptr, obj);
  absl::optional<int> node = internal::FindNodeAtPath("FOO", object_graph);
  EXPECT_FALSE(node.has_value());
}

TEST(ObjectGraphTraversalTest, NestedObjectFound) {
  SavedObjectGraph object_graph =
      ParseSavedObjectGraph(kSingleChildFooWithFuncBar);
  const SavedObject* obj = internal::FindNodeAtPath("foo.bar", object_graph);
  ASSERT_NE(nullptr, obj);
  EXPECT_EQ(obj->kind_case(), SavedObject::kFunction);
  EXPECT_EQ(obj->function().concrete_functions_size(), 1);
  EXPECT_EQ(obj->function().concrete_functions(0), "__inference_my_func_5");
  absl::optional<int> node = internal::FindNodeAtPath("foo.bar", object_graph);
  ASSERT_TRUE(node.has_value());
  EXPECT_EQ(*node, 2);
}

TEST(ObjectGraphTraversalTest, MultiplePathsAliasSameObject) {
  SavedObjectGraph object_graph = ParseSavedObjectGraph(kMultiplePathsToChild);
  const SavedObject* foo_baz =
  absl::optional<int> foo_baz_node =
      internal::FindNodeAtPath("foo.baz", object_graph);
  ASSERT_NE(nullptr, foo_baz);
  EXPECT_EQ(foo_baz->kind_case(), SavedObject::kUserObject);
  EXPECT_EQ(foo_baz->user_object().identifier(), "_generic_user_object");
  ASSERT_TRUE(foo_baz_node.has_value());
  EXPECT_EQ(*foo_baz_node, 4);

  const SavedObject* bar_wombat =
  absl::optional<int> bar_wombat_node =
      internal::FindNodeAtPath("bar.wombat", object_graph);
  ASSERT_NE(nullptr, bar_wombat);
  EXPECT_EQ(bar_wombat->kind_case(), SavedObject::kUserObject);
  EXPECT_EQ(bar_wombat->user_object().identifier(), "_generic_user_object");
  ASSERT_TRUE(bar_wombat_node.has_value());
  EXPECT_EQ(*bar_wombat_node, 4);

  EXPECT_EQ(foo_baz, bar_wombat);
  EXPECT_EQ(*foo_baz_node, *bar_wombat_node);
}

TEST(ObjectGraphTraversalTest, CyclesAreOK) {
  SavedObjectGraph object_graph =
      ParseSavedObjectGraph(kCycleBetweenParentAndChild);
  const SavedObject* foo = internal::FindNodeAtPath("foo", object_graph);
  ASSERT_NE(nullptr, foo);
  EXPECT_EQ(foo->kind_case(), SavedObject::kUserObject);
  EXPECT_EQ(foo->user_object().identifier(), "_generic_user_object");
  absl::optional<int> foo = internal::FindNodeAtPath("foo", object_graph);
  ASSERT_TRUE(foo.has_value());
  EXPECT_EQ(*foo, 1);

  const SavedObject* foo_bar =
  absl::optional<int> foo_bar =
      internal::FindNodeAtPath("foo.bar", object_graph);
  ASSERT_NE(nullptr, foo_bar);
  EXPECT_EQ(foo_bar->kind_case(), SavedObject::kUserObject);
  EXPECT_EQ(foo_bar->user_object().identifier(), "_generic_user_object");
  ASSERT_TRUE(foo_bar.has_value());
  EXPECT_EQ(*foo_bar, 3);

  const SavedObject* foo_bar_parent =
  absl::optional<int> foo_bar_parent =
      internal::FindNodeAtPath("foo.bar.parent", object_graph);
  ASSERT_NE(nullptr, foo_bar_parent);
  EXPECT_EQ(foo_bar_parent->kind_case(), SavedObject::kUserObject);
  EXPECT_EQ(foo_bar_parent->user_object().identifier(), "_generic_user_object");
  ASSERT_TRUE(foo_bar_parent.has_value());
  EXPECT_EQ(*foo_bar_parent, 1);

  const SavedObject* foo_bar_parent_bar =
  absl::optional<int> foo_bar_parent_bar =
      internal::FindNodeAtPath("foo.bar.parent.bar", object_graph);
  ASSERT_NE(nullptr, foo_bar_parent_bar);
  EXPECT_EQ(foo_bar_parent_bar->kind_case(), SavedObject::kUserObject);
  EXPECT_EQ(foo_bar_parent_bar->user_object().identifier(),
            "_generic_user_object");
  ASSERT_TRUE(foo_bar_parent_bar.has_value());
  EXPECT_EQ(*foo_bar_parent_bar, 3);

  EXPECT_EQ(foo, foo_bar_parent);
  EXPECT_EQ(foo_bar, foo_bar_parent_bar);
  EXPECT_EQ(*foo, *foo_bar_parent);
  EXPECT_EQ(*foo_bar, *foo_bar_parent_bar);
}

}  // namespace

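Reviewer note: the updated tests make the new contract easy to read off: FindNodeAtPath now resolves a dotted path to a node id in the SavedObjectGraph instead of a SavedObject pointer. Summarizing the fixtures, using exactly the ids asserted above:

// Path -> node id, as asserted in the tests above:
//   kSingleChildFoo:             "foo"                     -> 1
//   kSingleChildFooWithFuncBar:  "foo.bar"                 -> 2
//   kMultiplePathsToChild:       "foo.baz" == "bar.wombat" -> 4 (same node)
//   kCycleBetweenParentAndChild: "foo.bar" -> 3, "foo.bar.parent" -> 1
//                                (the cycle leads back to foo's node id)
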
@ -1,3 +1,5 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")

# This package contains written convenience helpers for Eager Operations
# used by SavedModel. Once we autogenerate C++ Eager Op wrappers, we can remove these.
load(

@ -1,3 +1,5 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")

# This package contains classes corresponding to Revived SavedObjectGraph types
# used by SavedModel. See https://cs.opensource.google/tensorflow/tensorflow/+/c575e2ba93c442121d98d3f125d83fed1339924d:tensorflow/core/protobuf/saved_object_graph.proto;l=56-62
package(
@ -8,6 +10,25 @@ package(
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "asset",
    srcs = [
        "asset.cc",
    ],
    hdrs = [
        "asset.h",
    ],
    deps = [
        ":tensorhandle_convertible",
        "//tensorflow/c:tensor_interface",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/cc/saved_model:constants",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)

cc_library(
    name = "constant",
    srcs = [
@ -28,6 +49,106 @@ cc_library(
    ],
)

cc_library(
    name = "flat_tensor_function",
    srcs = [
        "flat_tensor_function.cc",
    ],
    hdrs = [
        "flat_tensor_function.h",
    ],
    deps = [
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/eager:immediate_execution_operation",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/common_runtime/eager:context",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "partially_revived_objects",
    srcs = [
        "partially_revived_objects.cc",
    ],
    hdrs = [
        "partially_revived_objects.h",
    ],
    deps = [
        ":asset",
        ":constant",
        ":restored_resource",
        ":restored_resource_revival_state",
        ":revived_objects",
        ":tf_concrete_function",
        ":tf_concrete_function_revival_state",
        ":tf_signature_def_function",
        ":tf_signature_def_function_revival_state",
        ":variable",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/eager:immediate_execution_operation",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
        "//tensorflow/c/experimental/saved_model/core:tensor_spec",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/lib/llvm_rtti",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "restored_resource",
    srcs = [
        "restored_resource.cc",
    ],
    hdrs = [
        "restored_resource.h",
    ],
    deps = [
        ":tensorhandle_convertible",
        ":tf_concrete_function",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/eager:immediate_execution_operation",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "restored_resource_revival_state",
    hdrs = [
        "restored_resource_revival_state.h",
    ],
    deps = [
        ":tf_concrete_function_revival_state",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
    ],
)

cc_library(
    name = "revived_objects",
    hdrs = [
        "revived_objects.h",
    ],
    deps = [
        ":asset",
        ":constant",
        ":restored_resource",
        ":tf_concrete_function",
        ":tf_signature_def_function",
        ":variable",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "variable",
    srcs = [
@ -45,6 +166,8 @@ cc_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/common_runtime/eager:context",
        "//tensorflow/core/common_runtime/eager:tensor_handle",
        "//tensorflow/core/lib/llvm_rtti",
        "@com_google_absl//absl/types:optional",
    ],
)
@ -68,7 +191,7 @@ cc_library(
        "tf_concrete_function.h",
    ],
    deps = [
        ":tensorhandle_convertible",
        ":flat_tensor_function",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/eager:immediate_execution_operation",
@ -81,3 +204,55 @@ cc_library(
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "tf_concrete_function_revival_state",
    hdrs = [
        "tf_concrete_function_revival_state.h",
    ],
    deps = [
        ":tf_concrete_function",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_signature_def_function",
|
||||
srcs = [
|
||||
"tf_signature_def_function.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"tf_signature_def_function.h",
|
||||
],
|
||||
deps = [
|
||||
":flat_tensor_function",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/experimental/saved_model/core:signature_def_function",
|
||||
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_signature_def_function_revival_state",
|
||||
hdrs = [
|
||||
"tf_signature_def_function_revival_state.h",
|
||||
],
|
||||
deps = [
|
||||
":tf_signature_def_function",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
@ -0,0 +1,49 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Asset::Asset(ImmediateTensorHandlePtr handle)
|
||||
: TensorHandleConvertible(std::move(handle)) {}
|
||||
|
||||
Status Asset::Create(ImmediateExecutionContext* ctx,
|
||||
const std::string& saved_model_dir,
|
||||
const std::string& asset_filename,
|
||||
std::unique_ptr<Asset>* output) {
|
||||
std::string abs_path =
|
||||
io::JoinPath(saved_model_dir, kSavedModelAssetsDirectory, asset_filename);
|
||||
AbstractTensorPtr tensor(ctx->CreateStringScalar(abs_path));
|
||||
if (tensor.get() == nullptr) {
|
||||
return errors::Internal(
|
||||
"Failed to create scalar string tensor for Asset at path ", abs_path);
|
||||
}
|
||||
|
||||
ImmediateTensorHandlePtr handle(ctx->CreateLocalHandle(tensor.get()));
|
||||
output->reset(new Asset(std::move(handle)));
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
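
A rough usage sketch for the factory above, not part of the diff: the directory and file names are hypothetical, and a handle() accessor on TensorHandleConvertible is assumed (that base class is a dependency of this hunk but not shown here):

// Sketch: revive an asset while loading a SavedModel. On success, the Asset
// wraps a string scalar tensor handle holding the absolute path
// "<saved_model_dir>/assets/vocab.txt" (via kSavedModelAssetsDirectory).
std::unique_ptr<Asset> asset;
Status s = Asset::Create(ctx, "/tmp/my_saved_model", "vocab.txt", &asset);
if (!s.ok()) {
  return s;  // CreateStringScalar failed inside Create.
}
ImmediateExecutionTensorHandle* path_tensor = asset->handle();  // assumed accessor

Since the constructor is private and the type is move-only, an Asset can only be obtained through Create, which keeps ownership of the underlying handle unambiguous.
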
@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_

#include <string>

#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/tensor.pb.h"

namespace tensorflow {

class Asset : public TensorHandleConvertible {
 public:
  static Status Create(ImmediateExecutionContext* ctx,
                       const std::string& saved_model_dir,
                       const std::string& asset_filename,
                       std::unique_ptr<Asset>* output);

  // Asset is movable, but not copyable.
  Asset(Asset&& other) = default;
  Asset& operator=(Asset&& other) = default;

  ~Asset() override = default;

 private:
  explicit Asset(ImmediateTensorHandlePtr handle);
  Asset(const Asset&) = delete;
  Asset& operator=(const Asset&) = delete;
};

} // namespace tensorflow

#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_
@ -0,0 +1,91 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"

#include <memory>
#include <string>

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"

namespace tensorflow {

FlatTensorFunction::FlatTensorFunction(
    const std::string& name, std::vector<ImmediateTensorHandlePtr> captures,
    ImmediateExecutionContext* ctx)
    : name_(name), captures_(std::move(captures)), ctx_(ctx) {}

FlatTensorFunction::~FlatTensorFunction() {
  Status status = ctx_->RemoveFunction(name_);
  if (!status.ok()) {
    LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
               << status.error_message();
  }
}

Status FlatTensorFunction::Create(
    const FunctionDef* function_def,
    std::vector<ImmediateExecutionTensorHandle*> captures,
    ImmediateExecutionContext* ctx, std::unique_ptr<FlatTensorFunction>* out) {
  TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def));
  std::vector<ImmediateTensorHandlePtr> owned_captures;
  owned_captures.reserve(captures.size());
  for (ImmediateExecutionTensorHandle* capture : captures) {
    capture->Ref();
    owned_captures.push_back(ImmediateTensorHandlePtr(capture));
  }

  out->reset(new FlatTensorFunction(function_def->signature().name(),
                                    std::move(owned_captures), ctx));
  return Status();
}

Status FlatTensorFunction::MakeCallOp(
    absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
  out->reset(ctx_->CreateOperation());
  // In eager mode, TF2 python executes functions by constructing an op with
  // the name of the functiondef:
  // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
  // In graph mode, we create a PartitionedCallOp instead:
  // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573

  // TODO(bmzhao): After discussing with Allen, we should execute this via a
  // PartitionedCallOp for compatibility with "tooling that assumes functions in
  // graphs are PartitionedCallOps".
  TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));

  // Adding the user-provided inputs to the function.
  TF_RETURN_IF_ERROR((*out)->AddInputList(inputs));

  absl::Span<AbstractTensorHandle* const> captures(
      reinterpret_cast<AbstractTensorHandle* const*>(captures_.data()),
      captures_.size());

  // Adding the captures of the function.
  TF_RETURN_IF_ERROR((*out)->AddInputList(captures));
  return Status();
}

} // namespace tensorflow
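
To make the call path above concrete, here is a sketch of how a loader might drive FlatTensorFunction. It is not part of the diff: fdef, captures, ctx, inputs, and num_expected_outputs are assumed to come from the surrounding revival pass, and Execute is assumed to be the standard AbstractOperation entry point taking a retvals span and a retval count:

// Sketch: register the FunctionDef, build a call op, and execute it.
std::unique_ptr<FlatTensorFunction> fn;
TF_RETURN_IF_ERROR(FlatTensorFunction::Create(fdef, captures, ctx, &fn));

// MakeCallOp names the op after the FunctionDef (the eager-style call
// described in the comments above) and feeds user inputs first, then the
// function's captures.
ImmediateOpPtr call_op;
TF_RETURN_IF_ERROR(fn->MakeCallOp(inputs, &call_op));

std::vector<AbstractTensorHandle*> outputs(num_expected_outputs);
int num_retvals = static_cast<int>(outputs.size());
TF_RETURN_IF_ERROR(call_op->Execute(absl::MakeSpan(outputs), &num_retvals));

Note that the destructor shown above removes the FunctionDef from the context, so the FlatTensorFunction should outlive any use of call ops created from it.
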