summary_op_bm and added BUILD for summary_op_benchmark_test
This commit is contained in:
Daniel Nguyen 2020-08-11 20:00:42 +00:00
commit d41aa7a5b9
3946 changed files with 177968 additions and 81322 deletions

View File

@ -18,8 +18,10 @@
# #
# Compiler options: # Compiler options:
# cuda_clang: Use clang when building CUDA code. # cuda_clang: Use clang when building CUDA code.
# c++17: Build with C++17 options # c++17: Build with C++17 options (links with libc++)
# c++1z: Build with C++17 options # c++1z: Build with C++17 options (links with libc++)
# c++17_gcc: Build with C++17 options (links with stdlibc++)
# c++1z_gcc: Build with C++17 options (links with stdlibc++)
# avx_linux: Build with avx instruction set on linux. # avx_linux: Build with avx instruction set on linux.
# avx2_linux: Build with avx2 instruction set on linux. # avx2_linux: Build with avx2 instruction set on linux.
# native_arch_linux: Build with instruction sets available to the host machine on linux # native_arch_linux: Build with instruction sets available to the host machine on linux
@ -28,6 +30,7 @@
# #
# Other build options: # Other build options:
# short_logs: Only log errors during build, skip warnings. # short_logs: Only log errors during build, skip warnings.
# verbose_logs: Show all compiler warnings during build.
# monolithic: Build all TF C++ code into a single shared object. # monolithic: Build all TF C++ code into a single shared object.
# dynamic_kernels: Try to link all kernels dynamically (experimental). # dynamic_kernels: Try to link all kernels dynamically (experimental).
# libc++: Link against libc++ instead of stdlibc++ # libc++: Link against libc++ instead of stdlibc++
@ -78,7 +81,16 @@
# elinux: General Embedded Linux options shared by all flavors. # elinux: General Embedded Linux options shared by all flavors.
# elinux_aarch64: Embedded Linux options for aarch64 (ARM64) CPU support. # elinux_aarch64: Embedded Linux options for aarch64 (ARM64) CPU support.
# elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support. # elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support.
#
# Release build options (for all operating systems)
# release_common: Common options for all builds on all operating systems.
# release_windows_common: Common options for all builds on Windows.
# release_gpu_common: Common options for GPU builds on Linux and Windows.
# release_cpu_linux: Toolchain and CUDA options for Linux CPU builds.
# release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds.
# release_gpu_linux: Toolchain and CUDA options for Linux GPU builds.
# release_cpu_windows: Toolchain and CUDA options for Windows CPU builds.
# release_gpu_windows: Toolchain and CUDA options for Windows GPU builds.
# Allow builds using libc++ as a linker library # Allow builds using libc++ as a linker library
# This is mostly for OSSFuzz, so we also pass in the flags from environment to clean build file # This is mostly for OSSFuzz, so we also pass in the flags from environment to clean build file
@ -155,14 +167,29 @@ build:mkl -c opt
# config to build OneDNN backend with a user specified threadpool. # config to build OneDNN backend with a user specified threadpool.
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0 build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_threadpool --define=build_with_mkl_dnn_v1_only=true
build:mkl_threadpool --define=build_with_mkl_opensource=true
build:mkl_threadpool --define=build_with_mkldnn_threadpool=true build:mkl_threadpool --define=build_with_mkldnn_threadpool=true
build:mkl_threadpool -c opt build:mkl_threadpool -c opt
# Config setting to build with oneDNN and without the binary blob
build:mkl_opensource_only --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_opensource_only --define=build_with_mkl_dnn_v1_only=true
build:mkl_opensource_only --define=build_with_mkl_opensource=true
build:mkl_opensource_only -c opt
# This config refers to building with CUDA available. It does not necessarily # This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels. # mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true build:using_cuda --define=using_cuda=true
build:using_cuda --action_env TF_NEED_CUDA=1 build:using_cuda --action_env TF_NEED_CUDA=1
build:using_cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:using_cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
# Enable the mlir generated GPU kernels only for cuda builds.
build --define=tensorflow_enable_mlir_generated_gpu_kernels=0
# This is a more specific option, so it takes precedence over the line above for cuda builds.
build:using_cuda --define=tensorflow_enable_mlir_generated_gpu_kernels=1
# This config refers to building CUDA op kernels with nvcc. # This config refers to building CUDA op kernels with nvcc.
build:cuda --config=using_cuda build:cuda --config=using_cuda
build:cuda --define=using_cuda_nvcc=true build:cuda --define=using_cuda_nvcc=true
@ -253,6 +280,8 @@ build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
build:c++17 --cxxopt=-std=c++1z build:c++17 --cxxopt=-std=c++1z
build:c++17 --cxxopt=-stdlib=libc++ build:c++17 --cxxopt=-stdlib=libc++
build:c++1z --config=c++17 build:c++1z --config=c++17
build:c++17_gcc --cxxopt=-std=c++1z
build:c++1z_gcc --config=c++17_gcc
# Enable using platform specific build settings, except when cross-compiling for # Enable using platform specific build settings, except when cross-compiling for
# mobile platforms. # mobile platforms.
@ -322,6 +351,8 @@ build:windows --distinct_host_configuration=false
# Suppress all warning messages. # Suppress all warning messages.
build:short_logs --output_filter=DONT_MATCH_ANYTHING build:short_logs --output_filter=DONT_MATCH_ANYTHING
build:verbose_logs --output_filter=
build --config=short_logs
# Instruction set optimizations # Instruction set optimizations
# TODO(gunan): Create a feature in toolchains for avx/avx2 to # TODO(gunan): Create a feature in toolchains for avx/avx2 to
@ -341,7 +372,6 @@ build --config=v2
test --config=v2 test --config=v2
# Enable XLA # Enable XLA
build:xla --action_env=TF_ENABLE_XLA=1
build:xla --define=with_xla_support=true build:xla --define=with_xla_support=true
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS # BEGIN TF REMOTE BUILD EXECUTION OPTIONS
@ -534,3 +564,43 @@ try-import %workspace%/.tf_configure.bazelrc
# Put user-specific options in .bazelrc.user # Put user-specific options in .bazelrc.user
try-import %workspace%/.bazelrc.user try-import %workspace%/.bazelrc.user
# Here are bazelrc configs for release builds
build:release_common --config=opt
build:release_common --config=v2
build:release_common --distinct_host_configuration=false
build:release_common --action_env TF_CONFIGURE_IOS="0"
build:release_cpu_linux --config=release_common
build:release_cpu_linux --config=avx_linux
# We use the same toolchain for CPU/GPU packages.
# Did not add this to the defaults in case this changes.
build:release_cpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain
build:release_cpu_macos --config=release_common
build:release_cpu_macos --config=avx_linux
build:release_gpu_common --config=release_common
build:release_gpu_common --config=cuda
build:release_gpu_common --config=tensorrt
build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1"
build:release_gpu_common --action_env=TF_CUDA_VERSION="10"
build:release_gpu_common --action_env=TF_CUDNN_VERSION="7"
build:release_gpu_common --action_env=TF_NEED_TENSORRT="1"
build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_37,sm_52,sm_60,sm_61,compute_70"
build:release_gpu_common --action_env=TENSORRT_INSTALL_PATH="/usr/local/tensorrt"
build:release_gpu_common --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5"
build:release_gpu_linux --config=release_gpu_common
build:release_gpu_linux --config=avx_linux
build:release_gpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain
build:release_windows_common --config=release_common
build:release_windows_common --define=no_tensorflow_py_deps=true
build:release_windows_common --announce_rc
build:release_cpu_windows --config=release_windows_common
build:release_gpu_windows --config=release_windows_common

View File

@ -123,20 +123,21 @@ Build Type | Status
### Community Supported Builds ### Community Supported Builds
Build Type | Status | Artifacts Build Type | Status | Artifacts
----------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- ----------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**Linux AMD ROCm GPU** Nightly | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/) **Linux AMD ROCm GPU** Nightly | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/)
**Linux AMD ROCm GPU** Stable Release | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | Release [1.15](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/) / [2.x](http://ml-ci.amd.com:21096/job/tensorflow-rocm-v2-release/lastSuccessfulBuild/) **Linux AMD ROCm GPU** Stable Release | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | Release [1.15](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/) / [2.x](http://ml-ci.amd.com:21096/job/tensorflow-rocm-v2-release/lastSuccessfulBuild/)
**Linux s390x** Nightly | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) **Linux s390x** Nightly | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/)
**Linux s390x CPU** Stable Release | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/badge/icon)](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) **Linux s390x CPU** Stable Release | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/badge/icon)](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/)
**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/) **Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/) **Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/)
**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) **Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/) **Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/)
**Linux aarch64 CPU** Nightly <br> Python 3.6 | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow)](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow) | [Nightly](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) **Linux aarch64 CPU** Nightly <br> Python 3.6 | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow)](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) | [Nightly](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master)
**Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) **Linux aarch64 CPU** Stable Release | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) | Release [1.15](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) / [2.x](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)
**Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Stable Release | ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon) | Release [1.15](https://pypi.org/project/intel-tensorflow/1.15.0/) / [2.x](https://pypi.org/project/intel-tensorflow/) **Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br> Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/) **Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Stable Release | ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon) | Release [1.15](https://pypi.org/project/intel-tensorflow/1.15.0/) / [2.x](https://pypi.org/project/intel-tensorflow/)
**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br> Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/)
## Resources ## Resources

View File

@ -11,10 +11,28 @@
* C-API functions `TF_StringDecode`, `TF_StringEncode`, and * C-API functions `TF_StringDecode`, `TF_StringEncode`, and
`TF_StringEncodedSize` are no longer relevant and have been removed; see `TF_StringEncodedSize` are no longer relevant and have been removed; see
core/platform/ctstring.h for string access/modification in C. core/platform/ctstring.h for string access/modification in C.
* In batching library, rename parameter * Removed `tf.distribute.Strategy.experimental_run_v2` method, which was deprecated in TF 2.2.
SharedBatchScheduler::QueueOptions::max_batch_size to a more accurate name * `tensorflow.python`, `tensorflow.core` and `tensorflow.compiler` modules are
(input_batch_size_limit) for a recent feature to enable split of large batch now hidden. These modules are not part of TensorFlow public API.
sizes. * A major refactoring of the internals of the Keras Functional API may affect code that is relying on certain internal details:
* Code that uses `isinstance(x, tf.Tensor)` instead of `tf.is_tensor` when checking Keras symbolic inputs/outputs should switch to using `tf.is_tensor`.
* Code that is overly dependent on the exact names attached to symbolic tensors (e.g. assumes there will be ":0" at the end of the inputs, treats names as unique identifiers instead of using `tensor.ref()`, etc.)
* Code that uses `get_concrete_function` to trace Keras symbolic inputs directly should switch to building matching `tf.TensorSpec`s directly and tracing the `TensorSpec` objects.
* Code that relies on the exact number and names of the op layers that TensorFlow operations were converted into. These may have changed.
* Code that uses `tf.map_fn`/`tf.cond`/`tf.while_loop`/control flow as op layers and happens to work before TF 2.4. These will explicitly be unsupported now. Converting these ops to Functional API op layers was unreliable before TF 2.4, and prone to erroring incomprehensibly or being silently buggy.
* Code that directly asserts on a Keras symbolic value in cases where ops like `tf.rank` used to return a static or symbolic value depending on if the input had a fully static shape or not. Now these ops always return symbolic values.
* Code already susceptible to leaking tensors outside of graphs becomes slightly more likely to do so now.
* Code that requires very tricky shape manipulation via converted op layers in order to work, where the Keras symbolic shape inference proves insufficient.
* Code that tries manually walking a `tf.keras.Model` layer by layer and assumes layers only ever have one positional argument. This assumption doesn't hold true before TF 2.4 either, but is more likely to cause issues know.
* Code that manually enters `keras.backend.get_graph()` before building a functional model. This is no longer needed.
* Start enforcing input shape assumptions when calling Functional API Keras
models. This may potentially break some users, in case there is a mismatch
between the shape used when creating `Input` objects in a Functional model,
and the shape of the data passed to that model. You can fix this mismatch by
either calling the model with correctly-shaped data, or by relaxing `Input`
shape assumptions (note that you can pass shapes with `None` entries for axes
that are meant to be dynamic). You can also disable the input checking
entirely by setting `model.input_spec = None`.
## Known Caveats ## Known Caveats
@ -24,6 +42,8 @@
* <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX> * <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER> * <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
* A new module named `tf.experimental.numpy` is added, which is a NumPy-compatible API for writing TF programs. This module provides class `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are provided. Their inter-operation with TF facilities is seamless in most cases. See tensorflow/python/ops/numpy_ops/README.md for details of what are supported and what are the differences with NumPy.
* A major refactoring of the internals of the Keras Functional API has been completed, that should improve the reliability, stability, and performance of constructing Functional models.
## Bug Fixes and Other Changes ## Bug Fixes and Other Changes
@ -31,36 +51,106 @@
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE> * <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA> * <NOTES SHOULD BE GROUPED PER AREA>
* TF Core: * TF Core:
* <ADD RELEASE NOTES HERE> * `tf.types.experimental.TensorLike` is a new `Union` type that can be used as
* `tf.Tensor` is now a subclass of `typing.Generic`, allowing type annotations type annotation for variables representing a Tensor or a value that can be
to be parameterized by dtype: `tf.Tensor[tf.Int32]`. This requires Python 3, converted to Tensor by `tf.convert_to_tensor`.
and will become fully compatible with static type checkers in the future. * Calling ops with a python constants or numpy values is now consistent with
tf.convert_to_tensor behavior. This avoids operations like tf.reshape
truncating inputs such as from int64 to int32.
* Added `tf.sparse.map_values` to apply a function to the `.value`s of `SparseTensror` arguments.
* The Python bitwise operators for `Tensor` (`__and__`, `__or__`, `__xor__`
and `__invert__` now support non-`bool` arguments and apply the
corresponding bitwise ops. `bool` arguments continue to be supported and
dispatch to logical ops. This brings them more in line with Python and NumPy
benavior.
* Added `tf.SparseTensor.with_values`. This returns a new SparseTensor with
the same sparsity pattern, but with new provided values. It is similar to
the `with_values` function of `RaggedTensor`.
* Added `StatelessCase` op, and uses it if none of case branches has stateful ops.
* `tf.data`: * `tf.data`:
* Added new `tf.data.experimental.service.register_dataset` and
`tf.data.experimental.service.from_dataset_id` APIs to enable one process
to register a dataset with the tf.data service, and another process to
consume data from the dataset.
* Added support for tf.data service dispatcher fault tolerance. To enable
fault tolerance, configure a `work_dir` when running your dispatcher
server and set `dispatcher_fault_tolerance=True`. The dispatcher will
store its state to `work_dir`, so that on restart it can continue from its
previous state after restart.
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is * Added optional `exclude_cols` parameter to CsvDataset. This parameter is
the complement of `select_cols`; at most one of these should be specified. the complement of `select_cols`; at most one of these should be specified.
* We have implemented an optimization which reorders data-discarding
transformations such as `take` and `shard` to happen earlier in the
dataset when it is safe to do so. The optimization can be disabled via
the `experimental_optimization.reorder_data_discarding_ops` dataset
option.
* `tf.image`:
* Added deterministic `tf.image.stateless_random_*` functions for each
`tf.image.random_*` function. Added a new op
`stateless_sample_distorted_bounding_box` which is a determinstic
version of `sample_distorted_bounding_box` op. Given the same seed, these
stateless functions/ops produce the same results independent of how many
times the function is called, and independent of global seed settings.
* `tf.distribute`: * `tf.distribute`:
* <ADD RELEASE NOTES HERE> * <ADD RELEASE NOTES HERE>
* `tf.keras`: * `tf.keras`:
* <ADD RELEASE NOTES HERE> * Improvements from the functional API refactoring:
* `tf.function`/AutoGraph: * Functional model construction does not need to maintain a global workspace graph, removing memory leaks especially when building many models or very large models.
* <ADD RELEASE NOTES HERE> * Functional model construction should be ~8-10% faster on average.
* Functional models can now contain non-symbolic values in their call inputs inside of the first positional argument.
* Several classes of TF ops that were not reliably converted to Keras layers during functional API construction should now work, e.g. `tf.image.ssim_multiscale`
* Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
as an alternative to accepting a `callable` loss.
* Added `beta` parameter to FTRL optimizer to match paper.
* Added `mobilenet_v3` to keras application model.
* `tf.function` / AutoGraph:
* Added `experimental_follow_type_hints` argument for `tf.function`. When
True, the function may use type annotations to optimize the tracing
performance.
* Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
* AutoGraph now allows creating new symbols inside a TensorFLow loop, if
the values of these symbols at an iteration does not depend on the previous
iteration. These types of loops must run at least one iteration, and will
raise a runtime error otherwise.
Example:
```
for batch in data:
outputs = train_step(batch)
tf.print('final outputs', outputs)
```
See tensorflow/python/autograph/g3doc/reference/limitations.md for more
info.
* `tf.lite`: * `tf.lite`:
* `DynamicBuffer::AddJoinedString()` will now add a separator if the first
string to be joined is empty.
* `TFLiteConverter`:
* Support optional flags `inference_input_type` and `inference_output_type` for full integer quantized models. This allows users to modify the model input and output type to integer types (`tf.int8`, `tf.uint8`) instead of defaulting to float type (`tf.float32`).
* Deprecate `Interpreter::UseNNAPI(bool)` C++ API
* Prefer using `NnApiDelegate()` and related delegate configuration methods directly.
* Add NNAPI Delegation support for requantization use cases by converting the operation into a dequantize-quantize pair.
* <ADD RELEASE NOTES HERE> * <ADD RELEASE NOTES HERE>
* `tf.random`: * `tf.random`:
* <ADD RELEASE NOTES HERE> * <ADD RELEASE NOTES HERE>
* Math and Linear Algebra: * Math and Linear Algebra:
* <ADD RELEASE NOTES HERE> * <ADD RELEASE NOTES HERE>
* TPU Enhancements: * TPU Enhancements:
* Added support for the `beta` parameter of the FTRL optimizer for TPU
embeddings. Users of other TensorFlow platforms can implement equivalent
behavior by adjusting the `l2` parameter.
* <ADD RELEASE NOTES HERE> * <ADD RELEASE NOTES HERE>
* XLA Support: * XLA Support:
* xla.experimental.compile is deprecated, use
`tf.function(experimental_compile=True)` instead
* <ADD RELEASE NOTES HERE> * <ADD RELEASE NOTES HERE>
* Tracing and Debugging: * Tracing and Debugging:
* <ADD RELEASE NOTES HERE> * <ADD RELEASE NOTES HERE>
* Other: * Other:
* We have replaced uses of "whitelist" with "allowlist" where possible. * We have replaced uses of "whitelist" and "blacklist" with "allowlist"
Please see https://developers.google.com/style/word-list#blacklist for more and "denylist" where possible. Please see
context. https://developers.google.com/style/word-list#blacklist for more context.
* <ADD RELEASE NOTES HERE> * <ADD RELEASE NOTES HERE>
## Thanks to our Contributors ## Thanks to our Contributors
@ -71,19 +161,206 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
# Release 2.3.0 # Release 2.3.0
## Breaking Changes ## Major Features and Improvements
* `tf.data` adds two new mechanisms to solve input pipeline bottlenecks and save resources:
* [snapshot](https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot)
* [tf.data service](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service).
* `tf.image.extract_glimpse` has been updated to correctly process the case In addition checkout the detailed [guide](https://www.tensorflow.org/guide/data_performance_analysis) for analyzing input pipeline performance with TF Profiler.
where `centered=False` and `normalized=False`. This is a breaking change as
the output is different from (incorrect) previous versions. Note this * [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy) is now a stable API and no longer considered experimental for TensorFlow. (earlier `tf.distribute.experimental.TPUStrategy`).
breaking change only impacts `tf.image.extract_glimpse` and
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of * [TF Profiler](https://www.tensorflow.org/guide/profiler) introduces two new tools: a memory profiler to visualize your models memory usage over time and a [python tracer](https://www.tensorflow.org/guide/profiler#events) which allows you to trace python function calls in your model. Usability improvements include better diagnostic messages and [profile options](https://tensorflow.org/guide/profiler#collect_performance_data) to customize the host and device trace verbosity level.
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
exsiting C++ kernel `ExtractGlimpse` does not change as well, so saved * Introduces experimental support for Keras Preprocessing Layers API ([`tf.keras.layers.experimental.preprocessing.*`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing?version=nightly)) to handle data preprocessing operations, with support for composite tensor inputs. Please see below for additional details on these layers.
models will not be impacted.
* TFLite now properly supports dynamic shapes during conversion and inference. Weve also added opt-in support on Android and iOS for [XNNPACK](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack), a highly optimized set of CPU kernels, as well as opt-in support for [executing quantized models on the GPU](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu_advanced.md#running-quantized-models-experimental).
* Libtensorflow packages are available in GCS starting this release. We have also started to [release a nightly version of these packages](https://github.com/tensorflow/tensorflow#official-builds).
* The experimental Python API [`tf.debugging.experimental.enable_dump_debug_info()`](https://www.tensorflow.org/api_docs/python/tf/debugging/experimental/enable_dump_debug_info) now allows you to instrument a TensorFlow program and dump debugging information to a directory on the file system. The directory can be read and visualized by a new interactive dashboard in TensorBoard 2.3 called [Debugger V2](https://www.tensorflow.org/tensorboard/debugger_v2), which reveals the details of the TensorFlow program including graph structures, history of op executions at the Python (eager) and intra-graph levels, the runtime dtype, shape, and numerical composistion of tensors, as well as their code locations.
## Breaking Changes
* Increases the **minimum bazel version** required to build TF to **3.1.0**.
* `tf.data`
* Makes the following (breaking) changes to the `tf.data`.
* C++ API: - `IteratorBase::RestoreInternal`, `IteratorBase::SaveInternal`, and `DatasetBase::CheckExternalState` become pure-virtual and subclasses are now expected to provide an implementation.
* The deprecated `DatasetBase::IsStateful` method is removed in favor of `DatasetBase::CheckExternalState`.
* Deprecated overrides of `DatasetBase::MakeIterator` and `MakeIteratorFromInputElement` are removed.
* The signature of `tensorflow::data::IteratorBase::SaveInternal` and `tensorflow::data::IteratorBase::SaveInput` has been extended with `SerializationContext` argument to enable overriding the default policy for the handling external state during iterator checkpointing. This is not a backwards compatible change and all subclasses of `IteratorBase` *need to be updated* accordingly.
* `tf.keras`
* Add a new `BackupAndRestore` callback for handling distributed training failures & restarts. Please take a look at this [tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for details on how to use the callback.
* `tf.image.extract_glimpse` has been updated to correctly process the case
where `centered=False` and `normalized=False`. This is a breaking change as
the output is different from (incorrect) previous versions. Note this
breaking change only impacts `tf.image.extract_glimpse` and
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
exsiting C++ kernel `ExtractGlimpse` does not change either, so saved
models using `tf.raw_ops.ExtractGlimpse` will not be impacted.
## Known Caveats
* `tf.lite`
* Keras-based LSTM models must be converted with an explicit batch size in the input layer.
## Bug Fixes and Other Changes ## Bug Fixes and Other Changes
* Mutable tables now restore checkpointed values when loaded from SavedModel.
### TF Core:
* Set `tf2_behavior` to 1 to enable V2 for early loading cases.
* Add `execute_fn_for_device function` to dynamically choose the implementation based on underlying device placement.
* Eager:
* Add `reduce_logsumexp` benchmark with experiment compile.
* Give `EagerTensor`s a meaningful `__array__` implementation.
* Add another version of defun matmul for performance analysis.
* `tf.function`/AutoGraph:
* `AutoGraph` now includes into TensorFlow loops any variables that are closed over by local functions. Previously, such variables were sometimes incorrectly ignored.
* functions returned by the `get_concrete_function` method of `tf.function` objects can now be called with arguments consistent with the original arguments or type specs passed to `get_concrete_function`. This calling convention is now the preferred way to use concrete functions with nested values and composite tensors. Please check the [guide](https://www.tensorflow.org/guide/concrete_function) for more details on `concrete_ function`.
* Update `tf.function`'s `experimental_relax_shapes` to handle composite tensors appropriately.
* Optimize `tf.function` invocation, by removing redundant list converter.
* `tf.function` will retrace when called with a different variable instead of simply using the `dtype` & `shape`.
* [Improve support](https://github.com/tensorflow/tensorflow/issues/33862) for dynamically-sized TensorArray inside `tf.function`.
* `tf.math`:
* Narrow down `argmin`/`argmax` contract to always return the smallest index for ties.
* `tf.math.reduce_variance` and `tf.math.reduce_std` return correct computation for complex types and no longer support integer types.
* Add Bessel functions of order 0,1 to `tf.math.special`.
* `tf.divide` now always returns a tensor to be consistent with documentation and other APIs.
* `tf.image`:
* Replaced [`tf.image.non_max_suppression_padded`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/image/non_max_suppression_padded?hl=en) with a new implementation that supports batched inputs, which is considerably faster on TPUs and GPUs. Boxes with area=0 will be ignored. Existing usage with single inputs should still work as before.
* `tf.linalg`
* Add `tf.linalg.banded_triangular_solve`.
* `tf.random`:
* Add `tf.random.stateless_parameterized_truncated_normal`.
* `tf.ragged`:
* Add `tf.ragged.cross` and `tf.ragged.cross_hashed` operations.
* `tf.RaggedTensor`:
* `RaggedTensor.to_tensor()` now preserves static shape.
* Add `tf.strings.format()` and `tf.print()` to support RaggedTensors.
* `tf.saved_model`:
* `@tf.function` from SavedModel no longer ignores args after a `RaggedTensor` when selecting the concrete function to run.
* Fix save model issue for ops with a list of functions.
* Add `tf.saved_model.LoadOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/LoadOptions?hl=en) as arg with default value `None` to choose the I/O device for loading models and weights.
* Update `tf.saved_model.SaveOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/SaveOptions?hl=en) as arg with default value `None` to choose the I/O device for saving models and weights.
* Mutable tables now restore checkpointed values when loaded from SavedModel.
* GPU
* TF 2.3 includes PTX kernels only for [compute capability](https://developer.nvidia.com/cuda-gpus) 7.0 to reduce the TF pip binary size. Earlier releases included PTX for a variety of older compute capabilities.
* Others
* Retain parent namescope for ops added inside `tf.while_loop`/`tf.cond`/`tf.switch_case`.
* Update `tf.vectorized_map` to support vectorizing `tf.while_loop` and TensorList operations.
* `tf.custom_gradient` can now be applied to functions that accept nested structures of `tensors` as inputs (instead of just a list of tensors). Note that Python structures such as tuples and lists now won't be treated as tensors, so if you still want them to be treated that way, you need to wrap them with `tf.convert_to_tensor`.
* No lowering on gradient case op when input is `DeviceIndex` op.
* Extend the ragged version of `tf.gather` to support `batch_dims` and `axis` args.
* Update `tf.map_fn` to support RaggedTensors and SparseTensors.
* Deprecate `tf.group`. It is not useful in eager mode.
* Add CPU and GPU implementation of modified variation of [`FTRL`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/raw_ops/ApplyFtrl)/[`FTRLV2`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/raw_ops/ApplyFtrlV2) that can triggerred by `multiply_linear_by_lr` allowing a learning rate of zero.
### `tf.data`:
* `tf.data.experimental.dense_to_ragged_batch` works correctly with tuples.
* `tf.data.experimental.dense_to_ragged_batch` to output variable ragged rank.
* `tf.data.experimental.cardinality` is now a method on `tf.data.Dataset`.
* `tf.data.Dataset` now supports `len(Dataset)` when the cardinality is finite.
### `tf.distribute`:
* Expose experimental [`tf.distribute.DistributedDataset`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/distribute/DistributedDataset?hl=en) and [`tf.distribute.DistributedIterator`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/distribute/DistributedIterator) to distribute input data when using `tf.distribute` to scale training on multiple devices.
* Added a [`get_next_as_optional`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/distribute/DistributedIterator?hl=en#get_next_as_optional) method for [`tf.distribute.DistributedIterator`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/distribute/DistributedIterator?hl=en) class to return a `tf.experimental.Optional` instance that contains the next value for all replicas or none instead of raising an out of range error. Also see *new* [guide on input distribution](https://www.tensorflow.org/tutorials/distribute/input).
* Allow var.assign on MirroredVariables with aggregation=NONE in replica context. Previously this would raise an error. We now allow this because many users and library writers find using `.assign` in replica context to be more convenient, instead of having to use `Strategy.extended.update` which was the previous way of updating variables in this situation.
* `tf.distribute.experimental.MultiWorkerMirroredStrategy` adds support for partial batches. Workers running out of data now continue to participate in the training with empty inputs, instead of raising an error. Learn more about [partial batches here](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).
* Improve the performance of reading metrics eagerly under `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
* Fix the issue that `strategy.reduce()` inside `tf.function` may raise exceptions when the values to reduce are from loops or if-clauses.
* Fix the issue that `tf.distribute.MirroredStrategy` cannot be used together with `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
* Add a `tf.distribute.cluster_resolver.TPUClusterResolver.connect` API to simplify TPU initialization.
### `tf.keras`:
* Introduces experimental preprocessing layers API (`tf.keras.layers.experimental.preprocessing`) to handle data preprocessing operations such as categorical feature encoding, text vectorization, data normalization, and data discretization (binning). The newly added layers provide a replacement for the legacy feature column API, and support composite tensor inputs.
* Added **categorical data** processing layers:
* `IntegerLookup` & `StringLookup`: build an index of categorical feature values
* `CategoryEncoding`: turn integer-encoded categories into one-hot, multi-hot, or tf-idf encoded representations
* `CategoryCrossing`: create new categorical features representing co-occurrences of previous categorical feature values
* `Hashing`: the hashing trick, for large-vocabulary categorical features
* `Discretization`: turn continuous numerical features into categorical features by binning their values
* Improved **image preprocessing** layers: `CenterCrop`, `Rescaling`
* Improved **image augmentation** layers: `RandomCrop`, `RandomFlip`, `RandomTranslation`, `RandomRotation`, `RandomHeight`, `RandomWidth`, `RandomZoom`, `RandomContrast`
* Improved **`TextVectorization`** layer, which handles string tokenization, n-gram generation, and token encoding
* The `TextVectorization` layer now accounts for the mask_token as part of the vocabulary size when output_mode='int'. This means that, if you have a max_tokens value of 5000, your output will have 5000 unique values (not 5001 as before).
* Change the return value of `TextVectorization.get_vocabulary()` from `byte` to `string`. Users who previously were calling 'decode' on the output of this method should no longer need to do so.
* Introduce new Keras dataset generation utilities :
* **[`image_dataset_from_directory`](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory)** is a utility based on `tf.data.Dataset`, meant to replace the legacy `ImageDataGenerator`. It takes you from a structured directory of images to a labeled dataset, in one function call. Note that it doesn't perform image data augmentation (which is meant to be done using preprocessing layers).
* **[`text_dataset_from_directory`](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/text_dataset_from_directory)** takes you from a structured directory of text files to a labeled dataset, in one function call.
* **[`timeseries_dataset_from_array`](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/timeseries_dataset_from_array)** is a `tf.data.Dataset`-based replacement of the legacy `TimeseriesGenerator`. It takes you from an array of timeseries data to a dataset of shifting windows with their targets.
* Added [`experimental_steps_per_execution`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/Model?hl=en#compile)
arg to `model.compile` to indicate the number of batches to run per `tf.function` call. This can speed up Keras Models on TPUs up to 3x.
* Extends `tf.keras.layers.Lambda` layers to support multi-argument lambdas, and keyword arguments when calling the layer.
* Functional models now get constructed if *any* tensor in a layer call's arguments/keyword arguments comes from a keras input. Previously the functional api would only work if all of the elements in the first argument to the layer came from a keras input.
* Clean up `BatchNormalization` layer's `trainable` property to act like standard python state when it's used inside `tf.functions` (frozen at tracing time), instead of acting like a pseudo-variable whose updates *kind of sometimes* get reflected in already-traced `tf.function` traces.
* Add the `Conv1DTranspose` layer.
* Refine the semantics of `SensitivitySpecificityBase` derived metrics. See the updated API docstrings for [`tf.keras.metrics.SensitivityAtSpecificity`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/metrics/SensitivityAtSpecificity) and [`tf.keras.metrics.SpecificityAtSensitivty`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/metrics/SpecificityAtSensitivity).
### `tf.lite`:
* Converter
* Restored `inference_input_type` and `inference_output_type` flags in TF 2.x TFLiteConverter (backward compatible with TF 1.x) to support integer (tf.int8, tf.uint8) input and output types in post training full integer quantized models.
* Added support for converting and resizing models with dynamic (placeholder) dimensions. Previously, there was only limited support for dynamic batch size, and even that did not guarantee that the model could be properly resized at runtime.
* Enabled experimental support for a new quantization mode with 16-bit activations and 8-bit weights. See `lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8`.
* CPU
* Fix an issue w/ dynamic weights and `Conv2D` on x86.
* Add a runtime Android flag for enabling `XNNPACK` for optimized CPU performance.
* Add a runtime iOS flag for enabling `XNNPACK` for optimized CPU performance.
* Add a compiler flag to enable building a TFLite library that applies `XNNPACK` delegate automatically when the model has a `fp32` operation.
* GPU
* Allow GPU acceleration starting with internal graph nodes
* Experimental support for quantized models with the Android GPU delegate
* Add GPU delegate whitelist.
* Rename GPU whitelist -> compatibility (list).
* Improve GPU compatibility list entries from crash reports.
* NNAPI
* Set default value for `StatefulNnApiDelegate::Options::max_number_delegated_partitions` to 3.
* Add capability to disable `NNAPI` CPU and check `NNAPI` Errno.
* Fix crashes when using `NNAPI` with target accelerator specified with model containing Conv2d or FullyConnected or LSTM nodes with quantized weights.
* Fix `ANEURALNETWORKS_BAD_DATA` execution failures with `sum`/`max`/`min`/`reduce` operations with `scalar` inputs.
* Hexagon
* TFLite Hexagon Delegate out of experimental.
* Experimental `int8` support for most hexagon ops.
* Experimental per-channel quant support for `conv` in Hexagon delegate.
* Support dynamic batch size in C++ API.
* CoreML
* Opensource CoreML delegate
* Misc
* Enable building Android TFLite targets on Windows
* Add support for `BatchMatMul`.
* Add support for `half_pixel_centers` with `ResizeNearestNeighbor`.
* Add 3D support for `BatchToSpaceND`.
* Add 5D support for `BroadcastSub`, `Maximum`, `Minimum`, `Transpose` and `BroadcastDiv`.
* Rename `kTfLiteActRelu1` to `kTfLiteActReluN1To1`.
* Enable flex delegate on tensorflow.lite.Interpreter Python package.
* Add `Buckettize`, `SparseCross` and `BoostedTreesBucketize` to the flex whitelist.
* Add support for selective registration of flex ops.
* Add missing kernels for flex delegate whitelisted ops.
* Fix issue when using direct `ByteBuffer` inputs with graphs that have dynamic shapes.
* Fix error checking supported operations in a model containing `HardSwish`.
### Packaging Support
* Added `tf.sysconfig.get_build_info()`. Returns a dict that describes the build environment of the currently installed TensorFlow package, e.g. the NVIDIA CUDA and NVIDIA CuDNN versions used when TensorFlow was built.
### Profiler
* Fix a subtle use-after-free issue in `XStatVisitor::RefValue()`.
### TPU Enhancements
* Adds 3D mesh support in TPU configurations ops.
* Added TPU code for `FTRL` with `multiply_linear_by_lr`.
* Silently adds a new file system registry at `gstpu`.
* Support `restartType` in cloud tpu client.
* Depend on a specific version of google-api-python-client.
* Fixes apiclient import.
### Tracing and Debugging
* Add a `TFE_Py_Execute` traceme.
### XLA Support
* Implement stable `argmin` and `argmax`
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
902449@58880@bigcat_chen@ASIC, Abdul Baseer Khan, Abhineet Choudhary, Abolfazl Shahbazi, Adam Hillier, ag.ramesh, Agoniii, Ajay P, Alex Hoffman, Alexander Bayandin, Alexander Grund, Alexandre Abadie, Alexey Rogachevskiy, amoitra, Andrew Stevens, Angus-Luo, Anshuman Tripathy, Anush Elangovan, Artem Mavrin, Ashutosh Hathidara, autoih, Ayushman Kumar, ayushmankumar7, Bairen Yi, Bas Aarts, Bastian Eichenberger, Ben Barsdell, bhack, Bharat Raghunathan, Biagio Montaruli, Bigcat-Himax, blueyi, Bryan Cutler, Byambaa, Carlos Hernandez-Vaquero, Chen Lei, Chris Knorowski, Christian Clauss, chuanqiw, CuiYifeng, Daniel Situnayake, Daria Zhuravleva, Dayananda-V, Deven Desai, Devi Sandeep Endluri, Dmitry Zakharov, Dominic Jack, Duncan Riach, Edgar Liberis, Ehsan Toosi, ekuznetsov139, Elena Zhelezina, Eugene Kuznetsov, Eugene Mikhantiev, Evgenii Zheltonozhskii, Fabio Di Domenico, Fausto Morales, Fei Sun, feihugis, Felix E. Klee, flyingcat, Frederic Bastien, Fredrik Knutsson, frreiss, fsx950223, ganler, Gaurav Singh, Georgios Pinitas, Gian Marco Iodice, Giorgio Arena, Giuseppe Rossini, Gregory Keith, Guozhong Zhuang, gurushantj, Hahn Anselm, Harald Husum, Harjyot Bagga, Hristo Vrigazov, Ilya Persky, Ir1d, Itamar Turner-Trauring, jacco, Jake Tae, Janosh Riebesell, Jason Zaman, jayanth, Jeff Daily, Jens Elofsson, Jinzhe Zeng, JLZ, Jonas Skog, Jonathan Dekhtiar, Josh Meyer, Joshua Chia, Judd, justkw, Kaixi Hou, Kam D Kasravi, Kamil Rakoczy, Karol Gugala, Kayou, Kazuaki Ishizaki, Keith Smiley, Khaled Besrour, Kilaru Yasaswi Sri Chandra Gandhi, Kim, Young Soo, Kristian Hartikainen, Kwabena W. Agyeman, Leslie-Fang, Leslie-Fang-Intel, Li, Guizi, Lukas Geiger, Lutz Roeder, M\U00E5Ns Nilsson, Mahmoud Abuzaina, Manish, Marcel Koester, Marcin Sielski, marload, Martin Jul, Matt Conley, mdfaijul, Meng, Peng, Meteorix, Michael Käufl, Michael137, Milan Straka, Mitchell Vitez, Ml-0, Mokke Meguru, Mshr-H, nammbash, Nathan Luehr, naumkin, Neeraj Bhadani, ngc92, Nick Morgan, nihui, Niranjan Hasabnis, Niranjan Yadla, Nishidha Panpaliya, Oceania2018, oclyke, Ouyang Jin, OverLordGoldDragon, Owen Lyke, Patrick Hemmer, Paul Andrey, Peng Sun, periannath, Phil Pearl, Prashant Dandriyal, Prashant Kumar, Rahul Huilgol, Rajan Singh, Rajeshwar Reddy T, rangjiaheng, Rishit Dagli, Rohan Reddy, rpalakkal, rposts, Ruan Kunliang, Rushabh Vasani, Ryohei Ikegami, Semun Lee, Seo-Inyoung, Sergey Mironov, Sharada Shiddibhavi, ShengYang1, Shraiysh Vaishay, Shunya Ueta, shwetaoj, Siyavash Najafzade, Srinivasan Narayanamoorthy, Stephan Uphoff, storypku, sunchenggen, sunway513, Sven-Hendrik Haase, Swapnil Parekh, Tamas Bela Feher, Teng Lu, tigertang, tomas, Tomohiro Ubukata, tongxuan.ltx, Tony Tonev, Tzu-Wei Huang, Téo Bouvard, Uday Bondhugula, Vaibhav Jade, Vijay Tadikamalla, Vikram Dattu, Vincent Abriou, Vishnuvardhan Janapati, Vo Van Nghia, VoVAllen, Will Battel, William D. Irons, wyzhao, Xiaoming (Jason) Cui, Xiaoquan Kong, Xinan Jiang, xutianming, Yair Ehrenwald, Yasir Modak, Yasuhiro Matsumoto, Yixing Fu, Yong Tang, Yuan Tang, zhaozheng09, Zilin Zhu, zilinzhu, 张志豪
# Release 2.1.1 # Release 2.1.1
@ -210,7 +487,7 @@ Coinciding with this change, new releases of [TensorFlow's Docker images](https:
`Strategy.extended.update` and `Strategy.extended.update_non_slot`. `Strategy.extended.update` and `Strategy.extended.update_non_slot`.
* Experimental support for shape invariants has been enabled in * Experimental support for shape invariants has been enabled in
`tf.function`. See the API docs for `tf.function`. See the API docs for
`tf.autograph.experimental.set_loop_options` for additonal info. `tf.autograph.experimental.set_loop_options` for additional info.
* AutoGraph error messages now exclude frames corresponding to APIs * AutoGraph error messages now exclude frames corresponding to APIs
internal to AutoGraph. internal to AutoGraph.
* Improve shape inference for `tf.function` input arguments to unlock more * Improve shape inference for `tf.function` input arguments to unlock more
@ -293,7 +570,7 @@ Coinciding with this change, new releases of [TensorFlow's Docker images](https:
also deterministic back-prop of bias-addition in Keras layers) to also deterministic back-prop of bias-addition in Keras layers) to
include when XLA JIT compilation is enabled. include when XLA JIT compilation is enabled.
* Fix problem, when running on a CUDA GPU and when either environment * Fix problem, when running on a CUDA GPU and when either environment
variable `TF_DETERMINSTIC_OPS` or environment variable variable `TF_DETERMINISTIC_OPS` or environment variable
`TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer `TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer
configurations led to an exception with the message "No algorithm configurations led to an exception with the message "No algorithm
worked!" worked!"
@ -336,32 +613,86 @@ This release contains contributions from many people at Google, as well as:
TensorFlow 2.1 will be the last TF release supporting Python 2. Python 2 support [officially ends an January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update). [As announced earlier](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ), TensorFlow will also stop supporting Python 2 starting January 1, 2020, and no more releases are expected in 2019. TensorFlow 2.1 will be the last TF release supporting Python 2. Python 2 support [officially ends an January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update). [As announced earlier](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ), TensorFlow will also stop supporting Python 2 starting January 1, 2020, and no more releases are expected in 2019.
## Major Features and Improvements ## Major Features and Improvements
* The `tensorflow` pip package now includes GPU support by default (same as `tensorflow-gpu`) for both Linux and Windows. This runs on machines with and without NVIDIA GPUs. `tensorflow-gpu` is still available, and CPU-only packages can be downloaded at `tensorflow-cpu` for users who are concerned about package size.
* **Windows users:** Officially-released `tensorflow` Pip packages are now built with Visual Studio 2019 version 16.4 in order to take advantage of the new `/d2ReducedOptimizeHugeFunctions` compiler flag. To use these new packages, you must install "Microsoft Visual C++ Redistributable for Visual Studio 2015, 2017 and 2019", available from Microsoft's website [here](https://support.microsoft.com/help/2977003/the-latest-supported-visual-c-downloads). * The `tensorflow` pip package now includes GPU support by default (same as
* This does not change the minimum required version for building TensorFlow from source on Windows, but builds enabling `EIGEN_STRONG_INLINE` can take over 48 hours to compile without this flag. Refer to `configure.py` for more information about `EIGEN_STRONG_INLINE` and `/d2ReducedOptimizeHugeFunctions`. `tensorflow-gpu`) for both Linux and Windows. This runs on machines with and
* If either of the required DLLs, `msvcp140.dll` (old) or `msvcp140_1.dll` (new), are missing on your machine, `import tensorflow` will print a warning message. without NVIDIA GPUs. `tensorflow-gpu` is still available, and CPU-only
* The `tensorflow` pip package is built with CUDA 10.1 and cuDNN 7.6. packages can be downloaded at `tensorflow-cpu` for users who are concerned
* `tf.keras` about package size.
* Experimental support for mixed precision is available on GPUs and Cloud TPUs. See [usage guide](https://www.tensorflow.org/guide/keras/mixed_precision). * **Windows users:** Officially-released `tensorflow` Pip packages are now
* Introduced the `TextVectorization` layer, which takes as input raw strings and takes care of text standardization, tokenization, n-gram generation, and vocabulary indexing. See this [end-to-end text classification example](https://colab.research.google.com/drive/1RvCnR7h0_l4Ekn5vINWToI9TNJdpUZB3). built with Visual Studio 2019 version 16.4 in order to take advantage of the
* Keras `.compile` `.fit` `.evaluate` and `.predict` are allowed to be outside of the DistributionStrategy scope, as long as the model was constructed inside of a scope. new `/d2ReducedOptimizeHugeFunctions` compiler flag. To use these new
* Experimental support for Keras `.compile`, `.fit`, `.evaluate`, and `.predict` is available for Cloud TPUs, Cloud TPU, for all types of Keras models (sequential, functional and subclassing models). packages, you must install "Microsoft Visual C++ Redistributable for Visual
* Automatic outside compilation is now enabled for Cloud TPUs. This allows `tf.summary` to be used more conveniently with Cloud TPUs. Studio 2015, 2017 and 2019", available from Microsoft's website
* Dynamic batch sizes with DistributionStrategy and Keras are supported on Cloud TPUs. [here](https://support.microsoft.com/help/2977003/the-latest-supported-visual-c-downloads).
* Support for `.fit`, `.evaluate`, `.predict` on TPU using numpy data, in addition to `tf.data.Dataset`. * This does not change the minimum required version for building
* Keras reference implementations for many popular models are available in the TensorFlow [Model Garden](https://github.com/tensorflow/models/tree/master/official). TensorFlow from source on Windows, but builds enabling
* `tf.data` `EIGEN_STRONG_INLINE` can take over 48 hours to compile without this
* Changes rebatching for `tf.data datasets` + DistributionStrategy for better performance. Note that the dataset also behaves slightly differently, in that the rebatched dataset cardinality will always be a multiple of the number of replicas. flag. Refer to `configure.py` for more information about
* `tf.data.Dataset` now supports automatic data distribution and sharding in distributed environments, including on TPU pods. `EIGEN_STRONG_INLINE` and `/d2ReducedOptimizeHugeFunctions`.
* Distribution policies for `tf.data.Dataset` can now be tuned with 1. `tf.data.experimental.AutoShardPolicy(OFF, AUTO, FILE, DATA)` 2. `tf.data.experimental.ExternalStatePolicy(WARN, IGNORE, FAIL)` * If either of the required DLLs, `msvcp140.dll` (old) or `msvcp140_1.dll`
* `tf.debugging` (new), are missing on your machine, `import tensorflow` will print a
* Add `tf.debugging.enable_check_numerics()` and `tf.debugging.disable_check_numerics()` to help debugging the root causes of issues involving infinities and `NaN`s. warning message.
* `tf.distribute` * The `tensorflow` pip package is built with CUDA 10.1 and cuDNN 7.6.
* Custom training loop support on TPUs and TPU pods is avaiable through `strategy.experimental_distribute_dataset`, `strategy.experimental_distribute_datasets_from_function`, `strategy.experimental_run_v2`, `strategy.reduce`. * `tf.keras`
* Support for a global distribution strategy through `tf.distribute.experimental_set_strategy(),` in addition to `strategy.scope()`. * Experimental support for mixed precision is available on GPUs and Cloud
* `TensorRT` TPUs. See
* [TensorRT 6.0](https://developer.nvidia.com/tensorrt#tensorrt-whats-new) is now supported and enabled by default. This adds support for more TensorFlow ops including Conv3D, Conv3DBackpropInputV2, AvgPool3D, MaxPool3D, ResizeBilinear, and ResizeNearestNeighbor. In addition, the TensorFlow-TensorRT python conversion API is exported as `tf.experimental.tensorrt.Converter`. [usage guide](https://www.tensorflow.org/guide/keras/mixed_precision).
* Environment variable `TF_DETERMINISTIC_OPS` has been added. When set to "true" or "1", this environment variable makes `tf.nn.bias_add` operate deterministically (i.e. reproducibly), but currently only when XLA JIT compilation is *not* enabled. Setting `TF_DETERMINISTIC_OPS` to "true" or "1" also makes cuDNN convolution and max-pooling operate deterministically. This makes Keras Conv\*D and MaxPool\*D layers operate deterministically in both the forward and backward directions when running on a CUDA-enabled GPU. * Introduced the `TextVectorization` layer, which takes as input raw
strings and takes care of text standardization, tokenization, n-gram
generation, and vocabulary indexing. See this
[end-to-end text classification example](https://colab.research.google.com/drive/1RvCnR7h0_l4Ekn5vINWToI9TNJdpUZB3).
* Keras `.compile` `.fit` `.evaluate` and `.predict` are allowed to be
outside of the DistributionStrategy scope, as long as the model was
constructed inside of a scope.
* Experimental support for Keras `.compile`, `.fit`, `.evaluate`, and
`.predict` is available for Cloud TPUs, Cloud TPU, for all types of
Keras models (sequential, functional and subclassing models).
* Automatic outside compilation is now enabled for Cloud TPUs. This allows
`tf.summary` to be used more conveniently with Cloud TPUs.
* Dynamic batch sizes with DistributionStrategy and Keras are supported on
Cloud TPUs.
* Support for `.fit`, `.evaluate`, `.predict` on TPU using numpy data, in
addition to `tf.data.Dataset`.
* Keras reference implementations for many popular models are available in
the TensorFlow
[Model Garden](https://github.com/tensorflow/models/tree/master/official).
* `tf.data`
* Changes rebatching for `tf.data datasets` + DistributionStrategy for
better performance. Note that the dataset also behaves slightly
differently, in that the rebatched dataset cardinality will always be a
multiple of the number of replicas.
* `tf.data.Dataset` now supports automatic data distribution and sharding
in distributed environments, including on TPU pods.
* Distribution policies for `tf.data.Dataset` can now be tuned with 1.
`tf.data.experimental.AutoShardPolicy(OFF, AUTO, FILE, DATA)` 2.
`tf.data.experimental.ExternalStatePolicy(WARN, IGNORE, FAIL)`
* `tf.debugging`
* Add `tf.debugging.enable_check_numerics()` and
`tf.debugging.disable_check_numerics()` to help debugging the root
causes of issues involving infinities and `NaN`s.
* `tf.distribute`
* Custom training loop support on TPUs and TPU pods is available through
`strategy.experimental_distribute_dataset`,
`strategy.experimental_distribute_datasets_from_function`,
`strategy.experimental_run_v2`, `strategy.reduce`.
* Support for a global distribution strategy through
`tf.distribute.experimental_set_strategy(),` in addition to
`strategy.scope()`.
* `TensorRT`
* [TensorRT 6.0](https://developer.nvidia.com/tensorrt#tensorrt-whats-new)
is now supported and enabled by default. This adds support for more
TensorFlow ops including Conv3D, Conv3DBackpropInputV2, AvgPool3D,
MaxPool3D, ResizeBilinear, and ResizeNearestNeighbor. In addition, the
TensorFlow-TensorRT python conversion API is exported as
`tf.experimental.tensorrt.Converter`.
* Environment variable `TF_DETERMINISTIC_OPS` has been added. When set to
"true" or "1", this environment variable makes `tf.nn.bias_add` operate
deterministically (i.e. reproducibly), but currently only when XLA JIT
compilation is *not* enabled. Setting `TF_DETERMINISTIC_OPS` to "true" or
"1" also makes cuDNN convolution and max-pooling operate deterministically.
This makes Keras Conv\*D and MaxPool\*D layers operate deterministically in
both the forward and backward directions when running on a CUDA-enabled GPU.
## Breaking Changes ## Breaking Changes
* Deletes `Operation.traceback_with_start_lines` for which we know of no usages. * Deletes `Operation.traceback_with_start_lines` for which we know of no usages.

View File

@ -260,6 +260,36 @@ config_setting(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
config_setting(
name = "armeabi",
values = {"cpu": "armeabi"},
visibility = ["//visibility:public"],
)
config_setting(
name = "armeabi-v7a",
values = {"cpu": "armeabi-v7a"},
visibility = ["//visibility:public"],
)
config_setting(
name = "arm64-v8a",
values = {"cpu": "arm64-v8a"},
visibility = ["//visibility:public"],
)
selects.config_setting_group(
name = "arm_any",
match_any = [
":arm",
":armeabi",
":armeabi-v7a",
":arm64-v8a",
":linux_aarch64",
":linux_armhf",
],
)
config_setting( config_setting(
name = "freebsd", name = "freebsd",
values = {"cpu": "freebsd"}, values = {"cpu": "freebsd"},
@ -532,16 +562,14 @@ selects.config_setting_group(
package_group( package_group(
name = "internal", name = "internal",
packages = [ packages = [
# To pass open source testing in the pip Kokoros.
"//bazel_pip/tensorflow/...",
"//learning/brain/swift/x10/...", "//learning/brain/swift/x10/...",
"//perftools/accelerators/xprof/api/...", "//perftools/accelerators/xprof/api/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
"//third_party/swift/tensorflow_apis/...",
"//tensorflow/...", "//tensorflow/...",
"//tensorflow_estimator/python/estimator/...", "//tensorflow_estimator/python/estimator/...",
"//tensorflow_models/official/...", "//tensorflow_models/official/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
"//third_party/swift/tensorflow_apis/...",
], ],
) )

View File

@ -137,7 +137,7 @@ if _running_from_pip_package():
# TODO(gunan): Add sanity checks to loaded modules here. # TODO(gunan): Add sanity checks to loaded modules here.
for _s in _site_packages_dirs: for _s in _site_packages_dirs:
# Load first party dynamic kernels. # Load first party dynamic kernels.
_main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels') _main_dir = _os.path.join(_s, 'tensorflow/core/kernels')
if _fi.file_exists(_main_dir): if _fi.file_exists(_main_dir):
_ll.load_library(_main_dir) _ll.load_library(_main_dir)
@ -158,4 +158,23 @@ if hasattr(_current_module, 'keras'):
setattr(_current_module, "initializers", initializers) setattr(_current_module, "initializers", initializers)
# pylint: enable=undefined-variable # pylint: enable=undefined-variable
# Delete modules that should be hidden from dir().
# Don't fail if these modules are not available.
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
# does not have 'python', 'core' directories. Then, it will be copied
# to tensorflow/ which does have these two directories.
# pylint: disable=undefined-variable
try:
del python
except NameError:
pass
try:
del core
except NameError:
pass
try:
del compiler
except NameError:
pass
# __all__ PLACEHOLDER # __all__ PLACEHOLDER

View File

@ -147,7 +147,7 @@ if _running_from_pip_package():
# TODO(gunan): Add sanity checks to loaded modules here. # TODO(gunan): Add sanity checks to loaded modules here.
for _s in _site_packages_dirs: for _s in _site_packages_dirs:
# Load first party dynamic kernels. # Load first party dynamic kernels.
_main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels') _main_dir = _os.path.join(_s, 'tensorflow/core/kernels')
if _fi.file_exists(_main_dir): if _fi.file_exists(_main_dir):
_ll.load_library(_main_dir) _ll.load_library(_main_dir)
@ -156,4 +156,25 @@ if _running_from_pip_package():
if _fi.file_exists(_plugin_dir): if _fi.file_exists(_plugin_dir):
_ll.load_library(_plugin_dir) _ll.load_library(_plugin_dir)
# Delete modules that should be hidden from dir().
# Don't fail if these modules are not available.
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
# does not have 'python', 'core' directories. Then, it will be copied
# to tensorflow/ which does have these two directories.
# pylint: disable=undefined-variable
try:
del python
except NameError:
pass
try:
del core
except NameError:
pass
try:
del compiler
except NameError:
pass
# __all__ PLACEHOLDER # __all__ PLACEHOLDER

View File

@ -213,6 +213,17 @@ tf_cuda_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "logging",
srcs = ["logging.cc"],
hdrs = ["logging.h"],
deps = [
":c_api_macros",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:stringprintf",
],
)
tf_cuda_library( tf_cuda_library(
name = "tf_status_internal", name = "tf_status_internal",
hdrs = [ hdrs = [

View File

@ -213,7 +213,6 @@ void TF_Reset(const TF_SessionOptions* opt, const char** containers,
namespace tensorflow { namespace tensorflow {
Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
TF_Buffer* out) { TF_Buffer* out) {
if (out->data != nullptr) { if (out->data != nullptr) {
@ -306,8 +305,8 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
} }
// Helpers for loading a TensorFlow plugin (a .so file). // Helpers for loading a TensorFlow plugin (a .so file).
Status LoadLibrary(const char* library_filename, void** result, Status LoadDynamicLibrary(const char* library_filename, void** result,
const void** buf, size_t* len); const void** buf, size_t* len);
// TODO(josh11b,mrry): Change Session to be able to use a Graph* // TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and // directly, instead of requiring us to serialize to a GraphDef and
@ -552,7 +551,7 @@ void TF_PRun(TF_DeprecatedSession* s, const char* handle,
TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) { TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
TF_Library* lib_handle = new TF_Library; TF_Library* lib_handle = new TF_Library;
status->status = tensorflow::LoadLibrary( status->status = tensorflow::LoadDynamicLibrary(
library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data, library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
&lib_handle->op_list.length); &lib_handle->op_list.length);
if (!status->status.ok()) { if (!status->status.ok()) {

View File

@ -125,6 +125,14 @@ TF_CAPI_EXPORT extern void TF_DeleteBuffer(TF_Buffer*);
TF_CAPI_EXPORT extern TF_Buffer TF_GetBuffer(TF_Buffer* buffer); TF_CAPI_EXPORT extern TF_Buffer TF_GetBuffer(TF_Buffer* buffer);
// --------------------------------------------------------------------------
// Used to return strings across the C API. The caller does not take ownership
// of the underlying data pointer and is not responsible for freeing it.
typedef struct TF_StringView {
const char* data;
size_t len;
} TF_StringView;
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
// TF_SessionOptions holds options that can be passed during session creation. // TF_SessionOptions holds options that can be passed during session creation.
typedef struct TF_SessionOptions TF_SessionOptions; typedef struct TF_SessionOptions TF_SessionOptions;

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor.pb.h"
@ -525,12 +526,12 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer( LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
std::move(new_server), grpc_server->worker_env()->device_mgr, std::move(new_server), grpc_server->worker_env()->device_mgr,
grpc_server->worker_env()->collective_executor_mgr)); grpc_server->worker_env()->collective_executor_mgr.get()));
} else { } else {
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def)); LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer( LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
/*new_server=*/nullptr, grpc_server->worker_env()->device_mgr, /*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
grpc_server->worker_env()->collective_executor_mgr)); grpc_server->worker_env()->collective_executor_mgr.get()));
} }
return tensorflow::Status::OK(); return tensorflow::Status::OK();
#undef LOG_AND_RETURN_IF_ERROR #undef LOG_AND_RETURN_IF_ERROR
@ -551,6 +552,14 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
status->status = EnableCollectiveOps(server_def, ctx); status->status = EnableCollectiveOps(server_def, ctx);
} }
TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto collective_executor_handle = context->GetCollectiveExecutorHandle();
collective_executor_handle->get()->StartAbort(status->status);
}
TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) { TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList; TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
result->num_items = num_items; result->num_items = num_items;

View File

@ -230,6 +230,14 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
size_t proto_len, size_t proto_len,
TF_Status* status); TF_Status* status);
// Aborts all ongoing collectives with the specified status. After abortion,
// subsequent collectives will error with this status immediately.
//
// This is intended to be used when a peer failure is detected. There's yet no
// way to reset the collectives other than restarting the program.
TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
TF_Status* status);
// Information about the shape of a Tensor and its type. // Information about the shape of a Tensor and its type.
struct TF_ShapeAndType { struct TF_ShapeAndType {
// Number of dimensions. -1 indicates unknown rank. // Number of dimensions. -1 indicates unknown rank.

View File

@ -240,6 +240,8 @@ tf_cuda_cc_test(
"//tensorflow/c:c_api", "//tensorflow/c:c_api",
"//tensorflow/c:c_test_util", "//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper", "//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/cc/profiler", "//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib", "//tensorflow/core:lib",
@ -260,6 +262,7 @@ cc_library(
], ],
deps = [ deps = [
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:refcount",
], ],
) )
@ -308,6 +311,8 @@ cc_library(
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/util:abstract_stack_trace",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
], ],
) )
@ -514,7 +519,6 @@ tf_cuda_cc_test(
extra_copts = tfe_xla_copts(), extra_copts = tfe_xla_copts(),
tags = [ tags = [
"no_windows", "no_windows",
"noasan", # leaks gRPC server instances
], ],
deps = [ deps = [
":c_api", ":c_api",
@ -581,7 +585,6 @@ tf_cuda_cc_test(
extra_copts = tfe_xla_copts(), extra_copts = tfe_xla_copts(),
tags = [ tags = [
"no_windows", "no_windows",
"noasan", # leaks gRPC server instances
], ],
deps = [ deps = [
":c_api", ":c_api",

View File

@ -18,11 +18,12 @@ limitations under the License.
#include <memory> #include <memory>
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/refcount.h"
namespace tensorflow { namespace tensorflow {
// Abstract interface to a Tensor handle in either tracing or immediate // Abstract interface to a Tensor handle in either tracing or immediate
// execution mode. // execution mode.
class AbstractTensorHandle { class AbstractTensorHandle : public core::RefCounted {
protected: protected:
enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt }; enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt };
explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {} explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}
@ -34,14 +35,6 @@ class AbstractTensorHandle {
AbstractTensorHandleKind getKind() const { return kind_; } AbstractTensorHandleKind getKind() const { return kind_; }
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage it's own
// lifetime through ref counting. Thus this must be allocated on the heap and
// clients MUST call Release() in order to destroy an instance of this class.
virtual void Release() = 0;
private: private:
const AbstractTensorHandleKind kind_; const AbstractTensorHandleKind kind_;
}; };
@ -50,7 +43,7 @@ namespace internal {
struct AbstractTensorHandleDeleter { struct AbstractTensorHandleDeleter {
void operator()(AbstractTensorHandle* p) const { void operator()(AbstractTensorHandle* p) const {
if (p != nullptr) { if (p != nullptr) {
p->Release(); p->Unref();
} }
} }
}; };

View File

@ -94,7 +94,6 @@ limitations under the License.
#include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/version.h" #include "tensorflow/core/public/version.h"
using tensorflow::int64;
using tensorflow::string; using tensorflow::string;
namespace { namespace {
@ -725,13 +724,7 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
if (opts->use_tfrt) { if (opts->use_tfrt) {
#ifdef PLATFORM_GOOGLE #ifdef PLATFORM_GOOGLE
tfrt::SmallVector<std::string, 4> op_handler_chains; return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
tfrt::SmallVector<tensorflow::DeviceAttributes, 4> device_attributes;
status->status = tfrt::ListOpHandlerChains(
opts->session_options.options, &op_handler_chains, &device_attributes);
if (!status->status.ok()) return nullptr;
return tensorflow::wrap(new tfrt::ContextInterface(
op_handler_chains, device_attributes, opts->async));
#else #else
status->status = tensorflow::errors::Unimplemented("TFRT is not supported"); status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
return nullptr; return nullptr;
@ -974,7 +967,7 @@ int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
return -1; return -1;
} }
int64 num_elements = -1; tensorflow::int64 num_elements = -1;
status->status = tensorflow::unwrap(h)->NumElements(&num_elements); status->status = tensorflow::unwrap(h)->NumElements(&num_elements);
return num_elements; return num_elements;
} }
@ -986,7 +979,7 @@ int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
return -1; return -1;
} }
int64 dim = -1; tensorflow::int64 dim = -1;
status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim); status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim);
return dim; return dim;
} }
@ -1079,11 +1072,13 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
status->status = context->FindDeviceFromName(device_name, &device); status->status = context->FindDeviceFromName(device_name, &device);
tensorflow::CustomDevice* custom_device = nullptr; tensorflow::CustomDevice* custom_device = nullptr;
if (!status->status.ok()) { if (!status->status.ok()) {
status->status = if (!context->FindCustomDeviceFromName(device_name, &custom_device)) {
context->FindCustomDeviceFromName(device_name, &custom_device);
if (!status->status.ok()) {
deallocator(data, len, deallocator_arg); deallocator(data, len, deallocator_arg);
status->status =
tensorflow::errors::InvalidArgument(device_name, " unknown device.");
return nullptr; return nullptr;
} else {
status->status = tensorflow::Status::OK();
} }
} }
std::vector<tensorflow::int64> dimvec(num_dims); std::vector<tensorflow::int64> dimvec(num_dims);

View File

@ -26,14 +26,13 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device.h"
#endif // TENSORFLOW_EAGER_USE_XLA #endif // TENSORFLOW_EAGER_USE_XLA
using tensorflow::int64;
using tensorflow::string; using tensorflow::string;
namespace { namespace {
std::vector<int64> TensorShapeAsVector(const tensorflow::TensorHandle& handle, std::vector<tensorflow::int64> TensorShapeAsVector(
tensorflow::Status* status) { const tensorflow::TensorHandle& handle, tensorflow::Status* status) {
std::vector<int64> shape; std::vector<tensorflow::int64> shape;
int rank = -1; int rank = -1;
*status = handle.NumDims(&rank); *status = handle.NumDims(&rank);
if (!status->ok()) { if (!status->ok()) {
@ -79,7 +78,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
return nullptr; return nullptr;
} }
if (VLOG_IS_ON(3)) { if (VLOG_IS_ON(3)) {
std::vector<int64> shape_to_log = std::vector<tensorflow::int64> shape_to_log =
TensorShapeAsVector(*handle, &status->status); TensorShapeAsVector(*handle, &status->status);
if (!status->status.ok()) { if (!status->status.ok()) {
// Ignore the status here as we are simply logging. // Ignore the status here as we are simply logging.
@ -128,14 +127,14 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
} }
int rank = padded_shape.dimensions_size(); int rank = padded_shape.dimensions_size();
std::vector<int64> dev_dims; std::vector<tensorflow::int64> dev_dims;
dev_dims.reserve(rank); dev_dims.reserve(rank);
if (rank == 1) { if (rank == 1) {
// Rank 1 tensors might not have padded_shape.layout.minor_to_major set, // Rank 1 tensors might not have padded_shape.layout.minor_to_major set,
dev_dims.push_back(padded_shape.dimensions(0)); dev_dims.push_back(padded_shape.dimensions(0));
} else { } else {
for (int i = rank - 1; i >= 0; --i) { for (int i = rank - 1; i >= 0; --i) {
int64 dim_index = padded_shape.layout().minor_to_major(i); tensorflow::int64 dim_index = padded_shape.layout().minor_to_major(i);
dev_dims.push_back(padded_shape.dimensions(dim_index)); dev_dims.push_back(padded_shape.dimensions(dim_index));
} }
} }
@ -146,7 +145,8 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
// If the tensor is not an XLA tensor, the device shape is // If the tensor is not an XLA tensor, the device shape is
// the same as regular tensor shape. // the same as regular tensor shape.
std::vector<int64> dev_dims = TensorShapeAsVector(*handle, &status->status); std::vector<tensorflow::int64> dev_dims =
TensorShapeAsVector(*handle, &status->status);
if (!status->status.ok()) { if (!status->status.ok()) {
return nullptr; return nullptr;
} }

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <regex> // NOLINT
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h"
@ -174,9 +176,9 @@ void TestFunctionWithPackedInput(const bool remote) {
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
// Create one variable per task. // Create one variable per task.
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name); TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task1_name);
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name); TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task2_name);
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name); TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task0_name);
// Add a sync point in order to make sure that variables have been initialized // Add a sync point in order to make sure that variables have been initialized
// before the function execution starts. // before the function execution starts.
@ -185,6 +187,9 @@ void TestFunctionWithPackedInput(const bool remote) {
VarIsInitialized(ctx, h2); VarIsInitialized(ctx, h2);
// Pack 3 variable handles into one TFE_TensorHandle. // Pack 3 variable handles into one TFE_TensorHandle.
// When remote is false, function device is placed on task0. Handle types are
// REMOTE, REMOTE, LOCAL on task0. When remote is true, function device is
// placed on task1, Handle types are LOCAL, REMOTE, LOCAL on task1.
int num_replicas = 3; int num_replicas = 3;
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2}; std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
TFE_TensorHandle* packed_handle = TFE_TensorHandle* packed_handle =
@ -259,61 +264,64 @@ TEST(CAPI, TestRemoteFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/true); TestFunctionWithPackedInput(/*remote=*/true);
} }
string VariableAddFunctionSignature() {
return " signature {"
" name: 'VariableAddFunction'"
" input_arg {"
" name: 'var0'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'var0_value'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read0:value:0'"
" device: '/job:localhost/task:1/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'identity'"
" op: 'Identity'"
" input: 'add:z:0'"
" device: '/job:localhost/task:0/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'var0_value'"
" value: 'identity:output:0'"
" }";
}
string VariableAddFunction() { string VariableAddFunction() {
tensorflow::FunctionDef def; tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString( CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {" VariableAddFunctionSignature(), &def));
" name: 'VariableAddFunction'"
" input_arg {"
" name: 'var0'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'var0_value'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read0:value:0'"
" device: '/job:localhost/task:1/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'identity'"
" op: 'Identity'"
" input: 'add:z:0'"
" device: '/job:localhost/task:0/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'var0_value'"
" value: 'identity:output:0'"
" }",
&def));
return def.SerializeAsString(); return def.SerializeAsString();
} }
@ -425,6 +433,17 @@ TEST(CAPI, DistributedFunctionGraphPassOnlyOnce) {
GraphErrorInjectionPass::enabled_ = false; GraphErrorInjectionPass::enabled_ = false;
} }
string VariableAddFunctionWithGraphError() {
string signature = VariableAddFunctionSignature();
// Replace the node 'read0' with 'read0_maybe_with_graph_error', so that the
// error injecting pass can identify and introduce graph pass errors.
signature = std::regex_replace(signature, std::regex("read0"),
"read0_maybe_with_graph_error");
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(signature, &def));
return def.SerializeAsString();
}
class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass { class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
public: public:
FunctionErrorInjectionPass(string error_node, string error_device) FunctionErrorInjectionPass(string error_node, string error_device)
@ -471,16 +490,19 @@ void TestDistributedFunctionCancellation(bool inject_error) {
const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
if (inject_error) { if (inject_error) {
// Inject a function optimization pass failure when it sees the 'read0' op // Inject a function optimization pass failure when it sees the
// having a requested device `dev2_name`. During execution: // 'read0_maybe_with_graph_error' op having a requested device `dev2_name`.
// * task:0 processes the main function `VariableAddFunction` and places // During execution:
// the read0 op on task:2 // * task:0 processes main function `VariableAddFunctionWithGraphError`
// * task:0 partitions the main function with a subgraph containing read0 // and places the 'read0_maybe_with_graph_error' op on task:2
// sent to task:2 // * task:0 partitions the main function with a subgraph containing
// * task:2 graph pass reports an error when it sees read0 with dev2_name // 'read0_maybe_with_graph_error' sent to task:2
// * task:2 graph pass reports an error when it sees
// 'read0_maybe_with_graph_error' with dev2_name
tensorflow::function_optimization_registration:: tensorflow::function_optimization_registration::
FunctionOptimizationPassRegistration register_test_pass( FunctionOptimizationPassRegistration register_test_pass(
std::make_unique<FunctionErrorInjectionPass>("read0", dev2_name)); std::make_unique<FunctionErrorInjectionPass>(
"read0_maybe_with_graph_error", dev2_name));
} }
TF_Status* status = TF_NewStatus(); TF_Status* status = TF_NewStatus();
@ -496,7 +518,7 @@ void TestDistributedFunctionCancellation(bool inject_error) {
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name); TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle, nullptr); EXPECT_NE(var_handle, nullptr);
const string function_def = VariableAddFunction(); const string function_def = VariableAddFunctionWithGraphError();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status); status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h"
@ -115,40 +116,42 @@ void TestRemoteExecute(bool async) {
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); } TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); } TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
string MatMulFunction() { string MatMulFunction(const string& matmul_device) {
tensorflow::FunctionDef def; tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString( CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {" absl::StrCat(" signature {"
" name: 'MatMulFunction'" " name: 'MatMulFunction'"
" input_arg {" " input_arg {"
" name: 'a'" " name: 'a'"
" type: DT_FLOAT" " type: DT_FLOAT"
" }" " }"
" input_arg {" " input_arg {"
" name: 'b'" " name: 'b'"
" type: DT_FLOAT" " type: DT_FLOAT"
" }" " }"
" output_arg {" " output_arg {"
" name: 'm'" " name: 'm'"
" type: DT_FLOAT" " type: DT_FLOAT"
" }" " }"
" }" " }"
" node_def {" " node_def {"
" name: 'matmul'" " name: 'matmul'"
" op: 'MatMul'" " op: 'MatMul'"
" input: 'a'" " input: 'a'"
" input: 'b'" " input: 'b'"
" attr {" " device: '",
" key: 'T'" matmul_device, "'",
" value {" " attr {"
" type: DT_FLOAT" " key: 'T'"
" }" " value {"
" }" " type: DT_FLOAT"
" }" " }"
" ret {" " }"
" key: 'm'" " }"
" value: 'matmul:product'" " ret {"
" }", " key: 'm'"
" value: 'matmul:product'"
" }"),
&def)); &def));
return def.SerializeAsString(); return def.SerializeAsString();
} }
@ -157,7 +160,8 @@ string MatMulFunction() {
// which creates a remote remote input, to simulate a scenario that the remote // which creates a remote remote input, to simulate a scenario that the remote
// input is not ready when we start running an op or a function. // input is not ready when we start running an op or a function.
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func, void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc) { bool heavy_load_on_streaming_rpc,
bool remote_func_outputs = false) {
tensorflow::ServerDef server_def = GetServerDef(3); tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0. // This server def has the task index set to 0.
@ -214,7 +218,8 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
TFE_Op* matmul = nullptr; TFE_Op* matmul = nullptr;
if (func) { if (func) {
string function_def = MatMulFunction(); const string matmul_device = remote_func_outputs ? task2_name : "";
string function_def = MatMulFunction(matmul_device);
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status); status);
CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
@ -250,7 +255,7 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// TODO(gjn): Add support for waiting on async local mirrors // TODO(gjn): Add support for waiting on async local mirrors
if (!remote && !async) { if (!remote && !async && !remote_func_outputs) {
auto remote_arg = auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2)); tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored. // The input handles should never change since they have been mirrored.
@ -329,6 +334,19 @@ TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true, TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/false); /*heavy_load_on_streaming_rpc=*/false);
} }
// TODO(b/162618595): Enable this test once we remove the check of remote
// outputs in ProcessFunctionLibraryRuntime.
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) {
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/false,
/*func=*/true,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) {
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) { TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
// A remote input may be not ready when we start running a function. Test that // A remote input may be not ready when we start running a function. Test that
// the function execution should wait until the remote input is ready. // the function execution should wait until the remote input is ready.

View File

@ -88,6 +88,20 @@ TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx) {
return th; return th;
} }
TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx,
float data[], int64_t dims[],
int num_dims) {
TF_Status* status = TF_NewStatus();
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx) { TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx) {
constexpr int64_t dims[] = {100, 100}; constexpr int64_t dims[] = {100, 100};
constexpr int num_elements = dims[0] * dims[1]; constexpr int num_elements = dims[0] * dims[1];
@ -143,7 +157,7 @@ TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
if (TF_GetCode(status) != TF_OK) return nullptr; if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op, "dtype", TF_FLOAT); TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op, "shape", {}, 0, status); TFE_OpSetAttrShape(op, "shape", {}, 0, status);
TFE_OpSetAttrString(op, "container", "", 0); TFE_OpSetAttrString(op, "container", "localhost", 0);
TFE_OpSetAttrString(op, "shared_name", "", 0); TFE_OpSetAttrString(op, "shared_name", "", 0);
if (!device_name.empty()) { if (!device_name.empty()) {
TFE_OpSetDevice(op, device_name.c_str(), status); TFE_OpSetDevice(op, device_name.c_str(), status);

View File

@ -34,6 +34,12 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx);
// Return a tensor handle containing a 2x2 matrix of floats // Return a tensor handle containing a 2x2 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx); TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx);
// Return a tensor handle containing 2D matrix containing given data and
// dimensions
TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx,
float data[], int64_t dims[],
int num_dims);
// Return a tensor handle containing a 100x100 matrix of floats // Return a tensor handle containing a 100x100 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx); TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx);

View File

@ -147,7 +147,7 @@ TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
void TF_DeleteAbstractOp(TF_AbstractOp* op) { unwrap(op)->Release(); } void TF_DeleteAbstractOp(TF_AbstractOp* op) { unwrap(op)->Release(); }
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Release(); } void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Unref(); }
TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); } TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); } void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }

View File

@ -33,6 +33,7 @@ limitations under the License.
using tensorflow::dyn_cast; using tensorflow::dyn_cast;
using tensorflow::string; using tensorflow::string;
using tensorflow::gtl::ArraySlice;
namespace tensorflow { namespace tensorflow {
namespace tracing { namespace tracing {
@ -48,7 +49,6 @@ class GraphTensor : public TracingTensorHandle {
public: public:
explicit GraphTensor(TF_Output output) explicit GraphTensor(TF_Output output)
: TracingTensorHandle(kGraph), output_(output) {} : TracingTensorHandle(kGraph), output_(output) {}
void Release() override { delete this; }
tensorflow::DataType DataType() const override { tensorflow::DataType DataType() const override {
return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_)); return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));
@ -138,20 +138,23 @@ class GraphOperation : public TracingOperation {
Status SetAttrString(const char* attr_name, const char* data, Status SetAttrString(const char* attr_name, const char* data,
size_t length) override { size_t length) override {
return tensorflow::errors::Unimplemented( tensorflow::StringPiece s(data, length);
"SetAttrString has not been implemented yet."); op_->node_builder.Attr(attr_name, s);
return Status::OK();
} }
Status SetAttrInt(const char* attr_name, int64_t value) override { Status SetAttrInt(const char* attr_name, int64_t value) override {
return tensorflow::errors::Unimplemented( static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"SetAttrInt has not been implemented yet."); "64-bit int types should match in size");
op_->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
return Status::OK();
} }
Status SetAttrFloat(const char* attr_name, float value) override { Status SetAttrFloat(const char* attr_name, float value) override {
return tensorflow::errors::Unimplemented( op_->node_builder.Attr(attr_name, value);
"SetAttrFloat has not been implemented yet."); return Status::OK();
} }
Status SetAttrBool(const char* attr_name, bool value) override { Status SetAttrBool(const char* attr_name, bool value) override {
return tensorflow::errors::Unimplemented( op_->node_builder.Attr(attr_name, value);
"SetAttrBool has not been implemented yet."); return Status::OK();
} }
Status SetAttrType(const char* const attr_name, DataType value) override { Status SetAttrType(const char* const attr_name, DataType value) override {
if (!op_) { if (!op_) {
@ -164,8 +167,15 @@ class GraphOperation : public TracingOperation {
} }
Status SetAttrShape(const char* attr_name, const int64_t* dims, Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) override { const int num_dims) override {
return tensorflow::errors::Unimplemented( PartialTensorShape shape;
"SetAttrShape has not been implemented yet."); if (num_dims >= 0) {
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
}
op_->node_builder.Attr(attr_name, shape);
return Status::OK();
} }
Status SetAttrFunction(const char* attr_name, Status SetAttrFunction(const char* attr_name,
const AbstractOperation* value) override { const AbstractOperation* value) override {
@ -174,8 +184,10 @@ class GraphOperation : public TracingOperation {
} }
Status SetAttrFunctionName(const char* attr_name, const char* value, Status SetAttrFunctionName(const char* attr_name, const char* value,
size_t length) override { size_t length) override {
return tensorflow::errors::Unimplemented( tensorflow::NameAttrList func_name;
"SetAttrFunctionName has not been implemented yet."); func_name.set_name(string(value, value + length));
op_->node_builder.Attr(attr_name, func_name);
return Status::OK();
} }
Status SetAttrTensor(const char* attr_name, Status SetAttrTensor(const char* attr_name,
AbstractTensorInterface* tensor) override { AbstractTensorInterface* tensor) override {
@ -184,33 +196,71 @@ class GraphOperation : public TracingOperation {
} }
Status SetAttrStringList(const char* attr_name, const void* const* values, Status SetAttrStringList(const char* attr_name, const void* const* values,
const size_t* lengths, int num_values) override { const size_t* lengths, int num_values) override {
return tensorflow::errors::Unimplemented( if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
"SetAttrStringList has not been implemented yet."); op_->colocation_constraints.clear();
for (int i = 0; i < num_values; ++i) {
op_->colocation_constraints.emplace(static_cast<const char*>(values[i]),
lengths[i]);
}
} else {
std::vector<tensorflow::StringPiece> v;
v.reserve(num_values);
for (int i = 0; i < num_values; ++i) {
v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
}
op_->node_builder.Attr(attr_name, v);
}
return Status::OK();
} }
Status SetAttrFloatList(const char* attr_name, const float* values, Status SetAttrFloatList(const char* attr_name, const float* values,
int num_values) override { int num_values) override {
return tensorflow::errors::Unimplemented( op_->node_builder.Attr(attr_name,
"SetAttrFloatList has not been implemented yet."); ArraySlice<const float>(values, num_values));
return Status::OK();
} }
Status SetAttrIntList(const char* attr_name, const int64_t* values, Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) override { int num_values) override {
return tensorflow::errors::Unimplemented( static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"SetAttrIntList has not been implemented yet."); "64-bit int types should match in size");
op_->node_builder.Attr(
attr_name,
ArraySlice<const tensorflow::int64>(
reinterpret_cast<const tensorflow::int64*>(values), num_values));
return Status::OK();
} }
Status SetAttrTypeList(const char* attr_name, const DataType* values, Status SetAttrTypeList(const char* attr_name, const DataType* values,
int num_values) override { int num_values) override {
return tensorflow::errors::Unimplemented( op_->node_builder.Attr(attr_name,
"SetAttrTypeList has not been implemented yet."); ArraySlice<const DataType>(values, num_values));
return Status::OK();
} }
Status SetAttrBoolList(const char* attr_name, const unsigned char* values, Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
int num_values) override { int num_values) override {
return tensorflow::errors::Unimplemented( std::unique_ptr<bool[]> b(new bool[num_values]);
"SetAttrBoolList has not been implemented yet."); for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
op_->node_builder.Attr(attr_name,
ArraySlice<const bool>(b.get(), num_values));
return Status::OK();
} }
Status SetAttrShapeList(const char* attr_name, const int64_t** dims, Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) override { const int* num_dims, int num_values) override {
return tensorflow::errors::Unimplemented( std::vector<PartialTensorShape> shapes;
"SetAttrShapeList has not been implemented yet."); shapes.reserve(num_values);
for (int i = 0; i < num_values; ++i) {
if (num_dims[i] < 0) {
shapes.emplace_back();
} else {
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
shapes.emplace_back(ArraySlice<tensorflow::int64>(
reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
}
}
op_->node_builder.Attr(attr_name, shapes);
return Status::OK();
} }
Status SetAttrFunctionList( Status SetAttrFunctionList(
const char* attr_name, const char* attr_name,

View File

@ -92,9 +92,255 @@ TEST_P(UnifiedCAPI, TestBasicEager) {
TF_DeleteExecutionContext(ctx); TF_DeleteExecutionContext(ctx);
} }
// MatMul Test
TEST_P(UnifiedCAPI, TestBasicEagerMatMul) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
/* Want to test simple MatMul example:
[[0,0], * [[0,0], = [[0,0],
[0,0]] [0,0]] [0,0]]
*/
// Build an abstract input tensor.
int64_t dims[] = {2, 2}; // Matrices will be 2 x 2
int num_dims = sizeof(dims) / sizeof(dims[0]);
float vals[] = {0.0f, 0.0f, 0.0f, 0.0f};
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx, status.get());
TFE_TensorHandle* t =
TestMatrixTensorHandleWithInput(eager_ctx, vals, dims, num_dims);
TF_AbstractTensor* at = TF_CreateAbstractTensorFromEagerTensor(
t, status.get()); // get abstract tensor
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
auto* op = TF_NewAbstractOp(ctx);
TF_AbstractOpSetOpType(op, "MatMul", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {at, at};
TF_OutputList* o = TF_NewOutputList();
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute.
TF_ExecuteOperation(op, 2, inputs, o, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(at);
// Verify the results.
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
TF_AbstractTensor* result = TF_OutputListGet(o, 0);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(result, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Tensor* result_tensor = TFE_TensorHandleResolve(result_t, status.get());
// Copy Tensor data into an array.
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(result_tensor),
TF_TensorByteSize(result_tensor));
int data_len = 4; // length of result_data
for (int i = 0; i < data_len; i++) {
EXPECT_EQ(result_data[i], 0);
}
TF_DeleteTensor(result_tensor);
TF_DeleteAbstractTensor(result);
TF_DeleteOutputList(o);
TF_DeleteExecutionContext(ctx);
}
// MatMul Test 2
TEST_P(UnifiedCAPI, TestBasicEagerMatMul2) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
/* Want to test simple MatMul example with abstract tensors:
[[1,2], * [[5,6], = [[19,22],
[3,4]] [7,8]] [43,50]]
*/
// Build 1st Matrix.
int64_t dims[] = {2, 2}; // Matrices will be 2 x 2
int num_dims = sizeof(dims) / sizeof(dims[0]);
float vals1[] = {1.0f, 2.0f, 3.0f, 4.0f};
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx, status.get());
TFE_TensorHandle* t1 =
TestMatrixTensorHandleWithInput(eager_ctx, vals1, dims, num_dims);
TF_AbstractTensor* at1 = TF_CreateAbstractTensorFromEagerTensor(
t1, status.get()); // get abstract tensor
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build 2nd Matrix.
float vals2[] = {5.0f, 6.0f, 7.0f, 8.0f};
TFE_TensorHandle* t2 =
TestMatrixTensorHandleWithInput(eager_ctx, vals2, dims, num_dims);
TF_AbstractTensor* at2 = TF_CreateAbstractTensorFromEagerTensor(
t2, status.get()); // get abstract tensor
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
auto* op = TF_NewAbstractOp(ctx);
TF_AbstractOpSetOpType(op, "MatMul", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {at1, at2};
TF_OutputList* o = TF_NewOutputList();
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute.
TF_ExecuteOperation(op, 2, inputs, o, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(at1);
TF_DeleteAbstractTensor(at2);
// Verify the results.
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
TF_AbstractTensor* result = TF_OutputListGet(o, 0);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(result, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Tensor* result_tensor = TFE_TensorHandleResolve(result_t, status.get());
// Copy Tensor data into array.
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(result_tensor),
TF_TensorByteSize(result_tensor));
// Build expected result & verify.
float e_vals[] = {19.0f, 22.0f, 43.0f, 50.0f};
int data_len = 4; // length of e_vals
for (int i = 0; i < data_len; i++) {
EXPECT_EQ(result_data[i], e_vals[i]);
}
TF_DeleteTensor(result_tensor);
TF_DeleteAbstractTensor(result);
TF_DeleteOutputList(o);
TF_DeleteExecutionContext(ctx);
}
// MatAdd
TEST_P(UnifiedCAPI, TestBasicEagerMatAdd) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
/* Want to test simple MatAdd example with abstract tensors:
[[1,2] , + [[5,6], = [[6,8],
[3,4] ] [7,8] ] [10,12]]
*/
// Build 1st Matrix.
int64_t dims[] = {2, 2}; // Matrices will be 2 x 2
int num_dims = sizeof(dims) / sizeof(dims[0]);
float vals1[] = {1.0f, 2.0f, 3.0f, 4.0f};
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx, status.get());
TFE_TensorHandle* t1 =
TestMatrixTensorHandleWithInput(eager_ctx, vals1, dims, num_dims);
TF_AbstractTensor* at1 = TF_CreateAbstractTensorFromEagerTensor(
t1, status.get()); // get abstract tensor
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build 2nd Matrix.
float vals2[] = {5.0f, 6.0f, 7.0f, 8.0f};
TFE_TensorHandle* t2 =
TestMatrixTensorHandleWithInput(eager_ctx, vals2, dims, num_dims);
TF_AbstractTensor* at2 = TF_CreateAbstractTensorFromEagerTensor(
t2, status.get()); // get abstract tensor
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
auto* op = TF_NewAbstractOp(ctx);
TF_AbstractOpSetOpType(op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {at1, at2};
TF_OutputList* o = TF_NewOutputList();
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute.
TF_ExecuteOperation(op, 2, inputs, o, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(at1);
TF_DeleteAbstractTensor(at2);
// Verify the results.
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
TF_AbstractTensor* result = TF_OutputListGet(o, 0);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(result, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Tensor* result_tensor = TFE_TensorHandleResolve(result_t, status.get());
// Copy Tensor data into array.
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(result_tensor),
TF_TensorByteSize(result_tensor));
// Build expected result & verify.
float e_vals[] = {6.0f, 8.0f, 10.0f, 12.0f};
int data_len = 4; // length of e_vals
for (int i = 0; i < data_len; i++) {
EXPECT_EQ(result_data[i], e_vals[i]);
}
TF_DeleteTensor(result_tensor);
TF_DeleteAbstractTensor(result);
TF_DeleteOutputList(o);
TF_DeleteExecutionContext(ctx);
}
TEST_P(UnifiedCAPI, TestBasicGraph) { TEST_P(UnifiedCAPI, TestBasicGraph) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);
// Start a new function / execution context. // Start a new function / execution context.
string fn_name = "double"; string fn_name = "double";
TF_ExecutionContext* graph_ctx = TF_ExecutionContext* graph_ctx =
@ -142,6 +388,7 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, status.get()); TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build the abstract op to run the function. // Build the abstract op to run the function.
TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx); TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get()); TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get());
@ -180,6 +427,111 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
TF_DeleteExecutionContext(eager_execution_ctx); TF_DeleteExecutionContext(eager_execution_ctx);
} }
// Graph Tracing for MatMul
TEST_P(UnifiedCAPI, TestBasicGraphMatMul) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
// Start a new function / execution context.
string fn_name = "matrix_multiply";
TF_ExecutionContext* graph_ctx =
TF_CreateFunction(fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
auto* placeholder_t =
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
auto* matmul_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(matmul_op, "MatMul", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(matmul_op, "my_matmul", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
TF_OutputList* mm_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(mm_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute.
TF_ExecuteOperation(matmul_op, 2, inputs, mm_outputs, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Clean up operation and inputs.
TF_DeleteAbstractOp(matmul_op);
TF_AbstractFunction* func =
TF_FinalizeFunction(graph_ctx, mm_outputs, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
/* Now that the graph is built, test graph implementation on matmul example:
[[1,1] , * [[1,1] , = [[2,2],
[1,1]] [1,1]] [2,2]]
*/
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContext* eager_execution_ctx =
TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build the abstract op to run the function.
TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(eager_execution_ctx, status.get());
float vals[] = {1.0f, 1.0f, 1.0f, 1.0f};
int64_t dims[] = {2, 2}; // Matrices will be 2 x 2
int num_dims = sizeof(dims) / sizeof(dims[0]);
TFE_TensorHandle* input_eager =
TestMatrixTensorHandleWithInput(eager_ctx, vals, dims, num_dims);
TF_AbstractTensor* input_t =
TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_OutputListSetNumOutputs(mm_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecuteOperation(fn_op, 1, &input_t, mm_outputs, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(mm_outputs));
TF_AbstractTensor* final_result = TF_OutputListGet(mm_outputs, 0);
TFE_TensorHandle* final =
TF_AbstractTensorGetEagerTensor(final_result, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Tensor* f_t = TFE_TensorHandleResolve(final, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(f_t), TF_TensorByteSize(f_t));
int data_len = 4;
for (int i = 0; i < data_len; i++) {
ASSERT_EQ(result_data[i], 2.0f);
}
TF_DeleteAbstractTensor(final_result);
TF_DeleteOutputList(mm_outputs);
TF_DeleteAbstractTensor(placeholder_t);
TF_DeleteAbstractOp(fn_op);
TF_DeleteAbstractTensor(input_t);
TF_DeleteTensor(f_t);
TF_DeleteAbstractFunction(func);
TF_DeleteExecutionContext(eager_execution_ctx);
}
TEST_P(UnifiedCAPI, TestMultiOutputGraph) { TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);
@ -336,6 +688,217 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
TF_DeleteAbstractFunction(func); TF_DeleteAbstractFunction(func);
} }
TEST_P(UnifiedCAPI, TestMultiOutputGraphMatMul) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Status* s = status.get();
// Start a new function / execution context.
string fn_name = "two_adds_and_matmul";
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Create a first "Add" computing `arg0 + arg1`.
TF_AbstractTensor* add_output1;
{
// Build an abstract operation, inputs and output.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(add_op, "my_add1", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg0, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Trace the operation now (create a node in the graph).
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(add_op);
// Extract the resulting tensor.
add_output1 = TF_OutputListGet(add_outputs, 0);
TF_DeleteOutputList(add_outputs);
}
// Same with a second "Add" computing `arg1 + arg1`.
TF_AbstractTensor* add_output2;
{
// Build an abstract operation, inputs and output.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(add_op, "my_add2", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg1, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Trace the operation now (create a node in the graph).
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(add_op);
// Extract the resulting tensor.
add_output2 = TF_OutputListGet(add_outputs, 0);
TF_DeleteOutputList(add_outputs);
}
// 3rd Output will be Matrix Multiplication of add_output1 and add_output2
TF_AbstractTensor* mm_output;
{
// Build an abstract operation, inputs and output.
auto* mm_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(mm_op, "MatMul", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(mm_op, "mm", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {add_output1, add_output2};
TF_OutputList* mm_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(mm_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Trace the operation now (create a node in the graph).
TF_ExecuteOperation(mm_op, 2, inputs, mm_outputs, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(mm_op);
// Extract the resulting tensor.
mm_output = TF_OutputListGet(mm_outputs, 0);
TF_DeleteOutputList(mm_outputs);
}
// Finalize the function by providing the returned values.
TF_AbstractFunction* func;
{
// We want to return the output of both add operations and MatMul operation,
// create a new list and populate it.
TF_OutputList* func_outputs = TF_NewOutputList();
TF_OutputListPushBack(func_outputs, add_output1, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_OutputListPushBack(func_outputs, add_output2, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_OutputListPushBack(func_outputs, mm_output, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
func = TF_FinalizeFunction(graph_ctx, func_outputs, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteOutputList(func_outputs);
}
/**
* We traced so far this function:
*
* def two_adds_and_mm(A, B):
* my_add1 = A + B
* my_add2 = B + B
* mm = tf.MatMul(my_add1,my_add2)
* return my_add1, my_add2, mm
*
* Now we will execute this function with an eager context:
*
* A =[[0, 1],[1, 0]]
* B =[[1, 0],[0, 1]]
*
* output1, output2, output3 = two_adds_and_mm(A, B)
*
* We expect outputs:
*
* output1 = [[1, 1],[1, 1]]
* output2 = [[2, 0],[0, 2]]
* output3 = [[2, 2],[2, 2]]
*
*/
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContext* eager_execution_ctx =
TF_NewEagerExecutionContext(opts, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TFE_DeleteContextOptions(opts);
TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Build the abstract op to run the function.
TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Build two abstract input tensors as function arguments.
std::vector<TF_AbstractTensor*> func_args;
{
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(eager_execution_ctx, s);
// 1st Arg
float vals1[] = {0.0f, 1.0f, 1.0f, 0.0f};
int64_t dims[] = {2, 2}; // Matrices will be 2 x 2
int num_dims = sizeof(dims) / sizeof(dims[0]);
TFE_TensorHandle* input_eager =
TestMatrixTensorHandleWithInput(eager_ctx, vals1, dims, num_dims);
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// 2nd Arg
float vals2[] = {1.0f, 0.0f, 0.0f, 1.0f};
input_eager =
TestMatrixTensorHandleWithInput(eager_ctx, vals2, dims, num_dims);
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
}
TF_OutputList* func_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(func_outputs, 3, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs,
s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(fn_op);
for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t);
ASSERT_EQ(3, TF_OutputListNumOutputs(func_outputs));
float expected_outputs[3][4] = {{1.0f, 1.0f, 1.0f, 1.0f},
{2.0f, 0.0f, 0.0f, 2.0f},
{2.0f, 2.0f, 2.0f, 2.0f}};
float result_data[4];
for (int idx = 0; idx < 3; ++idx) {
TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
TFE_TensorHandle* handle = TF_AbstractTensorGetEagerTensor(result, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Tensor* f_t = TFE_TensorHandleResolve(handle, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
memcpy(&result_data[0], TF_TensorData(f_t), TF_TensorByteSize(f_t));
// Verify results for each output
for (int j = 0; j < 4; j++) {
ASSERT_EQ(result_data[j], expected_outputs[idx][j]);
}
TF_DeleteTensor(f_t);
}
// Free memory associated with add and MatMul outputs
for (int idx = 0; idx < 3; ++idx) {
TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
TF_DeleteAbstractTensor(result);
}
TF_DeleteOutputList(func_outputs);
TF_DeleteExecutionContext(eager_execution_ctx);
TF_DeleteAbstractFunction(func);
}
TEST_P(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) { TEST_P(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);

View File

@ -51,25 +51,14 @@ int64 ToId(AbstractTensorHandle* t) {
TapeTensor::TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx) TapeTensor::TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx)
: handle_(handle), ctx_(ctx) { : handle_(handle), ctx_(ctx) {
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely handle_->Ref();
// on the client to keep this tensor live for the duration of the gradient
// computation.
// handle_->Ref();
} }
TapeTensor::TapeTensor(const TapeTensor& other) { TapeTensor::TapeTensor(const TapeTensor& other) {
handle_ = other.handle_; handle_ = other.handle_;
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely handle_->Ref();
// on the client to keep this tensor live for the duration of the gradient
// computation.
// handle_->Ref();
ctx_ = other.ctx_; ctx_ = other.ctx_;
} }
TapeTensor::~TapeTensor() { TapeTensor::~TapeTensor() { handle_->Unref(); }
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
// on the client to keep this tensor live for the duration of the gradient
// computation.
// handle_->Unref();
}
tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); } tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); }
@ -112,7 +101,7 @@ AbstractTensorHandle* TapeTensor::ZerosLike() const {
} }
if (isa<tracing::TracingOperation>(op.get())) { if (isa<tracing::TracingOperation>(op.get())) {
s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName( s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
absl::StrCat("OnesLike", ToId(handle_)).c_str()); absl::StrCat("ZerosLike", ToId(handle_)).c_str());
if (!s.ok()) { if (!s.ok()) {
return nullptr; return nullptr;
} }
@ -175,7 +164,8 @@ Status TapeVSpace::CallBackwardFunction(
gtl::ArraySlice<AbstractTensorHandle*> output_gradients, gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
std::vector<AbstractTensorHandle*>* result) const { std::vector<AbstractTensorHandle*>* result) const {
if (backward_function == nullptr) return Status::OK(); if (backward_function == nullptr) return Status::OK();
return backward_function->Compute(output_gradients, result); Context ctx = {ctx_};
return backward_function->Compute(&ctx, output_gradients, result);
} }
// Looks up the ID of a Gradient. // Looks up the ID of a Gradient.
@ -191,7 +181,7 @@ TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const {
void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {} void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {}
void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const { void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
gradient->Release(); gradient->Unref();
} }
// Helper functions which delegate to `AbstractOperation`, update // Helper functions which delegate to `AbstractOperation`, update
@ -373,6 +363,10 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx,
input_ids[i] = ToId(forward_op_->inputs[i]); input_ids[i] = ToId(forward_op_->inputs[i]);
input_dtypes[i] = forward_op_->inputs[i]->DataType(); input_dtypes[i] = forward_op_->inputs[i]->DataType();
} }
for (int i = 0; i < *num_retvals; i++) {
// TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
forward_op_->outputs.push_back(retvals[i]);
}
std::vector<TapeTensor> tape_tensors; std::vector<TapeTensor> tape_tensors;
for (auto t : retvals) { for (auto t : retvals) {
tape_tensors.push_back(TapeTensor(t, ctx)); tape_tensors.push_back(TapeTensor(t, ctx));

View File

@ -31,7 +31,8 @@ namespace gradients {
// //
// class AddGradientFunction : public GradientFunction { // class AddGradientFunction : public GradientFunction {
// public: // public:
// Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs, // Status Compute(Context* ctx,
// absl::Span<AbstractTensorHandle* const> grad_inputs,
// std::vector<AbstractTensorHandle*>* grad_outputs) override { // std::vector<AbstractTensorHandle*>* grad_outputs) override {
// grad_outputs->resize(2); // grad_outputs->resize(2);
// (*grad_outputs)[0] = grad_inputs[0]; // (*grad_outputs)[0] = grad_inputs[0];
@ -50,11 +51,16 @@ namespace gradients {
// Status RegisterGradients(GradientRegistry* registry) { // Status RegisterGradients(GradientRegistry* registry) {
// return registry->Register("Add", AddRegisterer); // return registry->Register("Add", AddRegisterer);
// } // }
struct Context {
public:
AbstractContext* ctx;
};
class GradientFunction { class GradientFunction {
public: public:
// TODO(srbs): How we support CompositeTensors e.g. IndexedSlices in // TODO(srbs): How we support CompositeTensors e.g. IndexedSlices in
// `grad_inputs`. // `grad_inputs`.
virtual Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs, virtual Status Compute(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
std::vector<AbstractTensorHandle*>* grad_outputs) = 0; std::vector<AbstractTensorHandle*>* grad_outputs) = 0;
virtual ~GradientFunction() {} virtual ~GradientFunction() {}
}; };

View File

@ -23,6 +23,8 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h" #include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
@ -42,55 +44,12 @@ class CppGradients
} }
}; };
// Creates an Identity op. Status RegisterGradients(GradientRegistry* registry) {
Status Identity(AbstractContext* ctx, TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
absl::Span<AbstractTensorHandle* const> inputs, TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr identity_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
identity_op->Reset("Identity", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(identity_op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(identity_op.get())
->SetOpName(name));
}
TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0]));
int num_retvals = 1;
TF_RETURN_IF_ERROR(identity_op->Execute(outputs, &num_retvals));
return Status::OK(); return Status::OK();
} }
// =================== Register gradients for Add ============================
class AddGradientFunction : public GradientFunction {
public:
explicit AddGradientFunction(AbstractContext* ctx) : ctx_(ctx) {}
Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
std::vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2);
std::vector<AbstractTensorHandle*> identity_outputs(1);
TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
absl::MakeSpan(identity_outputs), "Id0"));
(*grad_outputs)[0] = identity_outputs[0];
TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
absl::MakeSpan(identity_outputs), "Id1"));
(*grad_outputs)[1] = identity_outputs[0];
return Status::OK();
}
~AddGradientFunction() override {}
private:
AbstractContext* ctx_;
};
GradientFunction* AddRegisterer(const ForwardOperation& op) {
return new AddGradientFunction(op.ctx);
}
Status RegisterGradients(GradientRegistry* registry) {
return registry->Register("Add", AddRegisterer);
}
// =================== End gradient registrations ============================
// Computes `inputs[0] + inputs[1]` and records it on the tape. // Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape, Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs, absl::Span<AbstractTensorHandle* const> inputs,
@ -112,6 +71,26 @@ Status Add(AbstractContext* ctx, Tape* tape,
registry); registry);
} }
// Computes `exp(inputs[0])` and records it on the tape.
Status Exp(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractOperationPtr exp_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(exp_op.get(), "Exp", /*raw_device_name=*/nullptr, &forward_op));
if (isa<tracing::TracingOperation>(exp_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(exp_op.get())->SetOpName("my_exp"));
}
TF_RETURN_IF_ERROR(AddInput(exp_op.get(), inputs[0], &forward_op));
int num_retvals = 1;
return Execute(exp_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes // Computes
// y = inputs[0] + inputs[1] // y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]}) // return grad(y, {inputs[0], inputs[1]})
@ -136,7 +115,7 @@ Status AddGradModel(AbstractContext* ctx,
source_tensors_that_are_targets, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads)); /*output_gradients=*/{}, &out_grads));
for (auto add_output : add_outputs) { for (auto add_output : add_outputs) {
add_output->Release(); add_output->Unref();
} }
outputs[0] = out_grads[0]; outputs[0] = out_grads[0];
outputs[1] = out_grads[1]; outputs[1] = out_grads[1];
@ -144,6 +123,35 @@ Status AddGradModel(AbstractContext* ctx,
return Status::OK(); return Status::OK();
} }
// Computes
// y = exp(inputs[0])
// return grad(y, {inputs[0]})
Status ExpGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
std::vector<AbstractTensorHandle*> exp_outputs(1);
TF_RETURN_IF_ERROR(Exp(ctx, tape, inputs, absl::MakeSpan(exp_outputs),
registry)); // Compute x+y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(exp_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads));
for (auto exp_output : exp_outputs) {
exp_output->Unref();
}
outputs[0] = out_grads[0];
delete tape;
return Status::OK();
}
AbstractContext* BuildFunction(const char* fn_name) { AbstractContext* BuildFunction(const char* fn_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);
@ -187,14 +195,15 @@ Status RunModel(Model model, AbstractContext* ctx,
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs), TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
absl::MakeSpan(output_list.outputs), registry)); absl::MakeSpan(output_list.outputs), registry));
for (auto func_input : func_inputs) { for (auto func_input : func_inputs) {
func_input->Release(); func_input->Unref();
} }
AbstractFunction* func = nullptr; AbstractFunction* func = nullptr;
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get()) TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
->Finalize(&output_list, &func)); ->Finalize(&output_list, &func));
scoped_func.reset(func); scoped_func.reset(func);
output_list.outputs[0]->Release(); for (auto output : output_list.outputs) {
output_list.outputs[1]->Release(); output->Unref();
}
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func)); TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
} }
@ -295,7 +304,7 @@ TEST_P(CppGradients, TestAddGrad) {
ASSERT_EQ(errors::OK, s.code()) << s.error_message(); ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor)); auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0); EXPECT_EQ(*result_value, 1.0);
outputs[0]->Release(); outputs[0]->Unref();
TF_DeleteTensor(result_tensor); TF_DeleteTensor(result_tensor);
result_tensor = nullptr; result_tensor = nullptr;
@ -303,17 +312,61 @@ TEST_P(CppGradients, TestAddGrad) {
ASSERT_EQ(errors::OK, s.code()) << s.error_message(); ASSERT_EQ(errors::OK, s.code()) << s.error_message();
result_value = static_cast<float*>(TF_TensorData(result_tensor)); result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0); EXPECT_EQ(*result_value, 1.0);
outputs[1]->Release(); outputs[1]->Unref();
TF_DeleteTensor(result_tensor); TF_DeleteTensor(result_tensor);
} }
TEST_P(CppGradients, TestExpGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// y = exp(x)
// outputs = tape.gradient(y, x)
std::vector<AbstractTensorHandle*> outputs(1);
s = RunModel(ExpGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_NEAR(*result_value, 2.718, 0.001);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
// TODO(b/160888630): Enable this test with mlir after AddInputList is // TODO(b/160888630): Enable this test with mlir after AddInputList is
// supported. It is needed for AddN op which is used for gradient aggregation. // supported. It is needed for AddN op which is used for gradient aggregation.
#ifdef PLATFORM_GOOGLE #ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients, UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"), ::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false), /*tfrt*/ ::testing::Values(true, false),
/*executing_eagerly*/ ::testing::Values(true, false))); /*executing_eagerly*/ ::testing::Values(true, false)));
#else #else
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <memory> #include <memory>
#include "absl/types/optional.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_operation.h" #include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
@ -26,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status.h"
#include "tensorflow/core/util/abstract_stack_trace.h"
struct TFE_Op; struct TFE_Op;
@ -36,6 +38,10 @@ class ImmediateExecutionOperation : public AbstractOperation {
public: public:
virtual void Clear() = 0; virtual void Clear() = 0;
// Returns the inputs of this op.
virtual absl::Span<ImmediateExecutionTensorHandle* const> GetInputs()
const = 0;
virtual const tensorflow::OpDef* OpDef() const = 0; virtual const tensorflow::OpDef* OpDef() const = 0;
virtual Status InputLength(const char* input_name, int* length) = 0; virtual Status InputLength(const char* input_name, int* length) = 0;
@ -44,6 +50,12 @@ class ImmediateExecutionOperation : public AbstractOperation {
// Experimental // Experimental
virtual Status SetUseXla(bool enable) = 0; virtual Status SetUseXla(bool enable) = 0;
// Set stack trace to be used for potential async error reporting.
virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0;
// Returns the stack trace set by `SetStackTrace` if exists.
virtual absl::optional<AbstractStackTrace> GetStackTrace() = 0;
// For LLVM style RTTI. // For LLVM style RTTI.
static bool classof(const AbstractOperation* ptr) { static bool classof(const AbstractOperation* ptr) {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt; return ptr->getKind() == kEager || ptr->getKind() == kTfrt;

View File

@ -50,6 +50,14 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
// Return a copy of the handle. // Return a copy of the handle.
virtual ImmediateExecutionTensorHandle* Copy() = 0; virtual ImmediateExecutionTensorHandle* Copy() = 0;
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage it's own
// lifetime through ref counting. Thus this must be allocated on the heap and
// clients MUST call Release() in order to destroy an instance of this class.
virtual void Release() = 0;
// For LLVM style RTTI. // For LLVM style RTTI.
static bool classof(const AbstractTensorHandle* ptr) { static bool classof(const AbstractTensorHandle* ptr) {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt; return ptr->getKind() == kEager || ptr->getKind() == kTfrt;

View File

@ -177,12 +177,12 @@ class GradientTape {
template <typename Gradient> template <typename Gradient>
class ForwardFunction class ForwardFunction
: public std::function<Status(const std::vector<Gradient*>&, : public std::function<Status(const std::vector<Gradient*>&,
std::vector<Gradient*>*)> { std::vector<Gradient*>*, bool)> {
public: public:
template <typename lambda_type> template <typename lambda_type>
explicit ForwardFunction(lambda_type lambda) explicit ForwardFunction(lambda_type lambda)
: std::function<Status(const std::vector<Gradient*>&, : std::function<Status(const std::vector<Gradient*>&,
std::vector<Gradient*>*)>(lambda) {} std::vector<Gradient*>*, bool)>(lambda) {}
}; };
// Computes Jacobian-vector products using forward-mode automatic // Computes Jacobian-vector products using forward-mode automatic
@ -205,8 +205,9 @@ class ForwardAccumulator {
// Does not take ownership of `vspace`, which must outlive the // Does not take ownership of `vspace`, which must outlive the
// ForwardAccumulator. // ForwardAccumulator.
explicit ForwardAccumulator( explicit ForwardAccumulator(
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace) const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
: vspace_(vspace) { bool use_batch)
: vspace_(vspace), use_batch_(use_batch) {
call_state_.emplace(nullptr, false); call_state_.emplace(nullptr, false);
} }
@ -314,6 +315,9 @@ class ForwardAccumulator {
// available in language bindings (e.g. Python). // available in language bindings (e.g. Python).
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_; const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;
// Decides if tangents are vectorized or not
bool use_batch_;
struct AccumulatorCallState { struct AccumulatorCallState {
AccumulatorCallState( AccumulatorCallState(
GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape, GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape,
@ -573,7 +577,7 @@ Status InitialGradients(
gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape, gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
const OpTape<BackwardFunction, TapeTensor>& op_tape, const OpTape<BackwardFunction, TapeTensor>& op_tape,
std::unordered_map<int64, std::vector<Gradient*>>* result) { std::unordered_map<int64, std::vector<Gradient*>>* result) {
for (int i = 0; i < target_tensor_ids.size(); ++i) { for (int i = 0, end = target_tensor_ids.size(); i < end; ++i) {
const int64 id = target_tensor_ids[i]; const int64 id = target_tensor_ids[i];
if (output_gradients.empty() || output_gradients[i] == nullptr) { if (output_gradients.empty() || output_gradients[i] == nullptr) {
auto tensor_it = tensor_tape.find(id); auto tensor_it = tensor_tape.find(id);
@ -699,7 +703,7 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
std::vector<Gradient*> out_gradients; std::vector<Gradient*> out_gradients;
out_gradients.reserve(trace.output_tensor_info.size()); out_gradients.reserve(trace.output_tensor_info.size());
std::vector<int64> unneeded_gradients; std::vector<int64> unneeded_gradients;
for (int i = 0; i < trace.input_tensor_id.size(); i++) { for (int i = 0, end = trace.input_tensor_id.size(); i < end; i++) {
const auto& in_tensor_id = trace.input_tensor_id[i]; const auto& in_tensor_id = trace.input_tensor_id[i];
if (tensor_tape_.find(in_tensor_id) == tensor_tape_.end() && if (tensor_tape_.find(in_tensor_id) == tensor_tape_.end() &&
sources_set.find(in_tensor_id) == sources_set.end()) { sources_set.find(in_tensor_id) == sources_set.end()) {
@ -709,7 +713,7 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
bool any_gradient_nonzero = false; bool any_gradient_nonzero = false;
std::vector<int> zero_indices; std::vector<int> zero_indices;
for (int i = 0; i < trace.output_tensor_info.size(); ++i) { for (int i = 0, end = trace.output_tensor_info.size(); i < end; ++i) {
const int64 id = trace.output_tensor_info[i].GetID(); const int64 id = trace.output_tensor_info[i].GetID();
auto grad_it = gradients.find(id); auto grad_it = gradients.find(id);
if (grad_it == gradients.end()) { if (grad_it == gradients.end()) {
@ -775,7 +779,7 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
} }
VLOG(1) << "Got " << in_gradients.size() << " in_gradients for " VLOG(1) << "Got " << in_gradients.size() << " in_gradients for "
<< trace.input_tensor_id.size() << " sources"; << trace.input_tensor_id.size() << " sources";
for (int i = 0; i < in_gradients.size(); ++i) { for (int i = 0, end = in_gradients.size(); i < end; ++i) {
const int64 id = trace.input_tensor_id[i]; const int64 id = trace.input_tensor_id[i];
if (in_gradients[i] != nullptr) { if (in_gradients[i] != nullptr) {
auto& unaggregated_grads = gradients[id]; auto& unaggregated_grads = gradients[id];
@ -968,7 +972,7 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
targets.reserve(grad.size()); targets.reserve(grad.size());
used_in_grads.reserve(grad.size()); used_in_grads.reserve(grad.size());
std::unordered_map<int64, TapeTensor> sources_that_are_targets; std::unordered_map<int64, TapeTensor> sources_that_are_targets;
for (int grad_index = 0; grad_index < grad.size(); ++grad_index) { for (int grad_index = 0, end = grad.size(); grad_index < end; ++grad_index) {
Gradient* grad_tensor = grad[grad_index]; Gradient* grad_tensor = grad[grad_index];
if (grad_tensor != nullptr) { if (grad_tensor != nullptr) {
int64 tensor_id = vspace_.TensorId(grad_tensor); int64 tensor_id = vspace_.TensorId(grad_tensor);
@ -1062,7 +1066,8 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
output_tensors, backward_function_getter, backward_function_deleter, output_tensors, backward_function_getter, backward_function_deleter,
in_grads, &forward_grads)); in_grads, &forward_grads));
} else { } else {
TF_RETURN_IF_ERROR((*forward_function)(in_grads, &forward_grads)); TF_RETURN_IF_ERROR(
(*forward_function)(in_grads, &forward_grads, use_batch_));
} }
for (int i = 0; i < forward_grads.size(); ++i) { for (int i = 0; i < forward_grads.size(); ++i) {
if (forward_grads[i] != nullptr) { if (forward_grads[i] != nullptr) {

View File

@ -186,3 +186,22 @@ void TF_JoinThread(TF_Thread* thread) {
// ::tensorflow::Thread joins on destruction // ::tensorflow::Thread joins on destruction
delete reinterpret_cast<::tensorflow::Thread*>(thread); delete reinterpret_cast<::tensorflow::Thread*>(thread);
} }
void* TF_LoadSharedLibrary(const char* library_filename, TF_Status* status) {
void* handle = nullptr;
TF_SetStatus(status, TF_OK, "");
::tensorflow::Set_TF_Status_from_Status(
status, ::tensorflow::Env::Default()->LoadDynamicLibrary(library_filename,
&handle));
return handle;
}
void* TF_GetSymbolFromLibrary(void* handle, const char* symbol_name,
TF_Status* status) {
void* symbol = nullptr;
TF_SetStatus(status, TF_OK, "");
::tensorflow::Set_TF_Status_from_Status(
status, ::tensorflow::Env::Default()->GetSymbolFromLibrary(
handle, symbol_name, &symbol));
return symbol;
}

View File

@ -184,6 +184,26 @@ TF_CAPI_EXPORT extern TF_Thread* TF_StartThread(const TF_ThreadOptions* options,
// Waits for the given thread to finish execution, then deletes it. // Waits for the given thread to finish execution, then deletes it.
TF_CAPI_EXPORT extern void TF_JoinThread(TF_Thread* thread); TF_CAPI_EXPORT extern void TF_JoinThread(TF_Thread* thread);
// \brief Load a dynamic library.
//
// Pass "library_filename" to a platform-specific mechanism for dynamically
// loading a library. The rules for determining the exact location of the
// library are platform-specific and are not documented here.
//
// On success, place OK in status and return the newly created library handle.
// Otherwise returns nullptr and set error status.
TF_CAPI_EXPORT extern void* TF_LoadSharedLibrary(const char* library_filename,
TF_Status* status);
// \brief Get a pointer to a symbol from a dynamic library.
//
// "handle" should be a pointer returned from a previous call to
// TF_LoadLibraryFromEnv. On success, place OK in status and return a pointer to
// the located symbol. Otherwise returns nullptr and set error status.
TF_CAPI_EXPORT extern void* TF_GetSymbolFromLibrary(void* handle,
const char* symbol_name,
TF_Status* status);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -1,124 +0,0 @@
# Description:
# Experimental C APIs for TensorFlow.
load(
"//tensorflow:tensorflow.bzl",
"tf_copts",
"tf_cuda_library",
)
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
package(
licenses = ["notice"], # Apache 2.0
)
tf_cuda_library(
name = "rendezvous_internal",
srcs = [
"rendezvous.cc",
],
hdrs = [
"rendezvous.h",
"rendezvous_internal.h",
],
copts = tf_copts(),
visibility = ["//tensorflow/c:__subpackages__"],
deps = [
"//tensorflow/c:c_api_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
],
)
tf_cuda_library(
name = "rendezvous",
hdrs = [
"rendezvous.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":rendezvous_internal",
"//tensorflow/c:c_api",
],
)
tf_cuda_library(
name = "network_internal",
srcs = [
"network.cc",
],
hdrs = [
"network.h",
"network_internal.h",
],
copts = tf_copts(),
visibility = ["//tensorflow/c:__subpackages__"],
deps = [
":rendezvous_internal",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
],
)
tf_cuda_library(
name = "network",
hdrs = [
"network.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":network_internal",
":rendezvous",
"//tensorflow/c:c_api",
],
)
# -----------------------------------------------------------------------------
# Tests
tf_cuda_cc_test(
name = "network_test",
size = "medium",
srcs = ["network_test.cc"],
tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":network",
":network_internal",
":rendezvous",
":rendezvous_internal",
"//tensorflow/c:c_api",
"//tensorflow/c:env",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:session_mgr",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime:worker_session",
"//tensorflow/core/distributed_runtime/rpc:async_service_interface",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
],
)

View File

@ -78,6 +78,11 @@ typedef struct TF_Filesystem {
void* plugin_filesystem; void* plugin_filesystem;
} TF_Filesystem; } TF_Filesystem;
typedef struct TF_TransactionToken {
void* token;
TF_Filesystem* owner;
} TF_TransactionToken;
/// SECTION 2. Function tables for functionality provided by plugins /// SECTION 2. Function tables for functionality provided by plugins
/// ---------------------------------------------------------------------------- /// ----------------------------------------------------------------------------
/// ///
@ -679,6 +684,133 @@ typedef struct TF_FilesystemOps {
/// ///
/// DEFAULT IMPLEMENTATION: No op. /// DEFAULT IMPLEMENTATION: No op.
void (*flush_caches)(const TF_Filesystem* filesystem); void (*flush_caches)(const TF_Filesystem* filesystem);
/// Starts a new transaction.
///
/// An opaque transaction token is returned in `token`. Ownership of the token
/// is in filesystem. Token will be freed in `end_transaction` call and any
/// access to token after that is invalid.
///
/// In case of error, plugins must set `status` to a value different than
/// `TF_OK`, free memory allocated for `token` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if transaction successfuly started.
/// * Must set `status` to `TF_FAILED_PRECONDITION` if multiple transactions
/// are not supported
/// * Might use any other error value for `status` to signal other errors.
int (*start_transaction)(const TF_Filesystem* filesystem,
TF_TransactionToken** token, TF_Status* status);
/// Ends transaction and free the `token`. Any access to token after
/// that will be invalid.
///
/// In case of error, plugins must set `status` to a value different than
/// `TF_OK`, free memory allocated for `token` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if transaction successfuly finalized.
/// * Must set `status` to `TF_NOT_FOUND` if token is invalid/not found
/// * Might use any other error value for `status` to signal other errors.
int (*end_transaction)(const TF_Filesystem* filesystem,
TF_TransactionToken* token, TF_Status* status);
/// Adds file/directory in the `path` to transaction in `token`. It is a valid
/// operation to add a path that doesn't exist yet to a transaction.
///
/// In case of error, plugins must set `status` to a value different than
/// `TF_OK`, free memory allocated for `token` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if path added to transaction successful.
/// * Must set `status` to `TF_NOT_FOUND` if `token` is invalid.
/// * Must set `status` to `TF_FAILED_PRECONDITION` if file/directory is in
/// another transaction and multiple transactions are not supported
/// * Might use any other error value for `status` to signal other errors.
int (*add_to_transaction)(const TF_Filesystem* filesystem, const char* path,
TF_TransactionToken* token, TF_Status* status);
/// Returns transaction token for file/directory in the `path`. Note that path
/// may not exist yet but still might be part of a transaction.
///
/// Transaction token is returned in `token`. Ownership of the token is in
/// filesystem. Token will be freed in `end_transaction` call and any access
/// to token after that is invalid.
///
/// In case of error, plugins must set `status` to a value different than
/// `TF_OK`, free memory allocated for `token` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if a transaction for path is found
/// * Must set `status` to `TF_NOT_FOUND` if `path` is not part of any
/// transaction
/// * Must set `status` to `TF_FAILED_PRECONDITION` if `path` is
/// not in this filesystem.
/// * Might use any other error value for `status` to signal other errors.
int (*get_transaction_for_path)(const TF_Filesystem* filesystem,
const char* path, TF_TransactionToken** token,
TF_Status* status);
/// Returns transaction token for `path` if it is part of a transaction else
/// starts a new transaction and adds `path` to that transaction
///
/// Transaction token is returned in `token`. Ownership of the token is in
/// filesystem. Token will be freed in `end_transaction` call and any access
/// to token after that is invalid.
///
/// In case of error, plugins must set `status` to a value different than
/// `TF_OK`, free memory allocated for `token` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if transaction found or successfuly
/// started.
/// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to this
/// filesystem
/// * Must set `status` to `TF_FAILED_PRECONDITION` if file/directory is
/// not in any transaction and multiple transactions are not supported.
/// * Might use any other error value for `status` to signal other errors.
int (*get_or_start_transaction_for_path)(const TF_Filesystem* filesystem,
const char* path,
TF_TransactionToken** token,
TF_Status* status);
/// Decodes transaction token in `token` to human readable format for
/// debugging.
///
/// A new `char*` buffer must be allocated by this method. Core TensorFlow
/// manages the lifetime of the buffer after the call. Thus, all callers of
/// this method must take ownership of the returned pointer.
///
/// Plugins must not return `nullptr`. Returning empty strings is allowed.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// DEFAULT IMPLEMENTATION: Dump token and owner address.
char* (*decode_transaction_token)(const TF_Filesystem* filesystem,
const TF_TransactionToken* token);
} TF_FilesystemOps; } TF_FilesystemOps;
// LINT.ThenChange(:filesystem_ops_version) // LINT.ThenChange(:filesystem_ops_version)

View File

@ -35,7 +35,8 @@ using UniquePtrTo_TF_Status =
::std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>; ::std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;
Status ModularFileSystem::NewRandomAccessFile( Status ModularFileSystem::NewRandomAccessFile(
const std::string& fname, std::unique_ptr<RandomAccessFile>* result) { const std::string& fname, TransactionToken* token,
std::unique_ptr<RandomAccessFile>* result) {
if (ops_->new_random_access_file == nullptr) if (ops_->new_random_access_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat( return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support NewRandomAccessFile()")); "Filesystem for ", fname, " does not support NewRandomAccessFile()"));
@ -54,7 +55,8 @@ Status ModularFileSystem::NewRandomAccessFile(
} }
Status ModularFileSystem::NewWritableFile( Status ModularFileSystem::NewWritableFile(
const std::string& fname, std::unique_ptr<WritableFile>* result) { const std::string& fname, TransactionToken* token,
std::unique_ptr<WritableFile>* result) {
if (ops_->new_writable_file == nullptr) if (ops_->new_writable_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat( return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support NewWritableFile()")); "Filesystem for ", fname, " does not support NewWritableFile()"));
@ -73,7 +75,8 @@ Status ModularFileSystem::NewWritableFile(
} }
Status ModularFileSystem::NewAppendableFile( Status ModularFileSystem::NewAppendableFile(
const std::string& fname, std::unique_ptr<WritableFile>* result) { const std::string& fname, TransactionToken* token,
std::unique_ptr<WritableFile>* result) {
if (ops_->new_appendable_file == nullptr) if (ops_->new_appendable_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat( return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support NewAppendableFile()")); "Filesystem for ", fname, " does not support NewAppendableFile()"));
@ -92,7 +95,8 @@ Status ModularFileSystem::NewAppendableFile(
} }
Status ModularFileSystem::NewReadOnlyMemoryRegionFromFile( Status ModularFileSystem::NewReadOnlyMemoryRegionFromFile(
const std::string& fname, std::unique_ptr<ReadOnlyMemoryRegion>* result) { const std::string& fname, TransactionToken* token,
std::unique_ptr<ReadOnlyMemoryRegion>* result) {
if (ops_->new_read_only_memory_region_from_file == nullptr) if (ops_->new_read_only_memory_region_from_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat( return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, "Filesystem for ", fname,
@ -112,7 +116,8 @@ Status ModularFileSystem::NewReadOnlyMemoryRegionFromFile(
return StatusFromTF_Status(plugin_status.get()); return StatusFromTF_Status(plugin_status.get());
} }
Status ModularFileSystem::FileExists(const std::string& fname) { Status ModularFileSystem::FileExists(const std::string& fname,
TransactionToken* token) {
if (ops_->path_exists == nullptr) if (ops_->path_exists == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat( return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support FileExists()")); "Filesystem for ", fname, " does not support FileExists()"));
@ -125,6 +130,7 @@ Status ModularFileSystem::FileExists(const std::string& fname) {
} }
bool ModularFileSystem::FilesExist(const std::vector<std::string>& files, bool ModularFileSystem::FilesExist(const std::vector<std::string>& files,
TransactionToken* token,
std::vector<Status>* status) { std::vector<Status>* status) {
if (ops_->paths_exist == nullptr) if (ops_->paths_exist == nullptr)
return FileSystem::FilesExist(files, status); return FileSystem::FilesExist(files, status);
@ -157,6 +163,7 @@ bool ModularFileSystem::FilesExist(const std::vector<std::string>& files,
} }
Status ModularFileSystem::GetChildren(const std::string& dir, Status ModularFileSystem::GetChildren(const std::string& dir,
TransactionToken* token,
std::vector<std::string>* result) { std::vector<std::string>* result) {
if (ops_->get_children == nullptr) if (ops_->get_children == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat( return errors::Unimplemented(tensorflow::strings::StrCat(
@ -182,6 +189,7 @@ Status ModularFileSystem::GetChildren(const std::string& dir,
} }
Status ModularFileSystem::GetMatchingPaths(const std::string& pattern, Status ModularFileSystem::GetMatchingPaths(const std::string& pattern,
TransactionToken* token,
std::vector<std::string>* result) { std::vector<std::string>* result) {
if (ops_->get_matching_paths == nullptr) if (ops_->get_matching_paths == nullptr)
return internal::GetMatchingPaths(this, Env::Default(), pattern, result); return internal::GetMatchingPaths(this, Env::Default(), pattern, result);
@ -203,7 +211,8 @@ Status ModularFileSystem::GetMatchingPaths(const std::string& pattern,
return StatusFromTF_Status(plugin_status.get()); return StatusFromTF_Status(plugin_status.get());
} }
Status ModularFileSystem::DeleteFile(const std::string& fname) { Status ModularFileSystem::DeleteFile(const std::string& fname,
TransactionToken* token) {
if (ops_->delete_file == nullptr) if (ops_->delete_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat( return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support DeleteFile()")); "Filesystem for ", fname, " does not support DeleteFile()"));
@ -216,6 +225,7 @@ Status ModularFileSystem::DeleteFile(const std::string& fname) {
} }
Status ModularFileSystem::DeleteRecursively(const std::string& dirname, Status ModularFileSystem::DeleteRecursively(const std::string& dirname,
TransactionToken* token,
int64* undeleted_files, int64* undeleted_files,
int64* undeleted_dirs) { int64* undeleted_dirs) {
if (undeleted_files == nullptr || undeleted_dirs == nullptr) if (undeleted_files == nullptr || undeleted_dirs == nullptr)
@ -238,7 +248,8 @@ Status ModularFileSystem::DeleteRecursively(const std::string& dirname,
return StatusFromTF_Status(plugin_status.get()); return StatusFromTF_Status(plugin_status.get());
} }
Status ModularFileSystem::DeleteDir(const std::string& dirname) { Status ModularFileSystem::DeleteDir(const std::string& dirname,
TransactionToken* token) {
if (ops_->delete_dir == nullptr) if (ops_->delete_dir == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat( return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", dirname, " does not support DeleteDir()")); "Filesystem for ", dirname, " does not support DeleteDir()"));
@ -250,7 +261,8 @@ Status ModularFileSystem::DeleteDir(const std::string& dirname) {
return StatusFromTF_Status(plugin_status.get()); return StatusFromTF_Status(plugin_status.get());
} }
Status ModularFileSystem::RecursivelyCreateDir(const std::string& dirname) { Status ModularFileSystem::RecursivelyCreateDir(const std::string& dirname,
TransactionToken* token) {
if (ops_->recursively_create_dir == nullptr) if (ops_->recursively_create_dir == nullptr)
return FileSystem::RecursivelyCreateDir(dirname); return FileSystem::RecursivelyCreateDir(dirname);
@ -261,7 +273,8 @@ Status ModularFileSystem::RecursivelyCreateDir(const std::string& dirname) {
return StatusFromTF_Status(plugin_status.get()); return StatusFromTF_Status(plugin_status.get());
} }
Status ModularFileSystem::CreateDir(const std::string& dirname) { Status ModularFileSystem::CreateDir(const std::string& dirname,
TransactionToken* token) {
if (ops_->create_dir == nullptr) if (ops_->create_dir == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat( return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", dirname, " does not support CreateDir()")); "Filesystem for ", dirname, " does not support CreateDir()"));
@ -273,7 +286,8 @@ Status ModularFileSystem::CreateDir(const std::string& dirname) {
return StatusFromTF_Status(plugin_status.get()); return StatusFromTF_Status(plugin_status.get());
} }
Status ModularFileSystem::Stat(const std::string& fname, FileStatistics* stat) { Status ModularFileSystem::Stat(const std::string& fname,
TransactionToken* token, FileStatistics* stat) {
if (ops_->stat == nullptr) if (ops_->stat == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat( return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support Stat()")); "Filesystem for ", fname, " does not support Stat()"));
@ -296,7 +310,8 @@ Status ModularFileSystem::Stat(const std::string& fname, FileStatistics* stat) {
return StatusFromTF_Status(plugin_status.get()); return StatusFromTF_Status(plugin_status.get());
} }
Status ModularFileSystem::IsDirectory(const std::string& name) { Status ModularFileSystem::IsDirectory(const std::string& name,
TransactionToken* token) {
if (ops_->is_directory == nullptr) return FileSystem::IsDirectory(name); if (ops_->is_directory == nullptr) return FileSystem::IsDirectory(name);
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus); UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
@ -307,6 +322,7 @@ Status ModularFileSystem::IsDirectory(const std::string& name) {
} }
Status ModularFileSystem::GetFileSize(const std::string& fname, Status ModularFileSystem::GetFileSize(const std::string& fname,
TransactionToken* token,
uint64* file_size) { uint64* file_size) {
if (ops_->get_file_size == nullptr) { if (ops_->get_file_size == nullptr) {
FileStatistics stat; FileStatistics stat;
@ -327,7 +343,8 @@ Status ModularFileSystem::GetFileSize(const std::string& fname,
} }
Status ModularFileSystem::RenameFile(const std::string& src, Status ModularFileSystem::RenameFile(const std::string& src,
const std::string& target) { const std::string& target,
TransactionToken* token) {
if (ops_->rename_file == nullptr) { if (ops_->rename_file == nullptr) {
Status status = CopyFile(src, target); Status status = CopyFile(src, target);
if (status.ok()) status = DeleteFile(src); if (status.ok()) status = DeleteFile(src);
@ -343,7 +360,8 @@ Status ModularFileSystem::RenameFile(const std::string& src,
} }
Status ModularFileSystem::CopyFile(const std::string& src, Status ModularFileSystem::CopyFile(const std::string& src,
const std::string& target) { const std::string& target,
TransactionToken* token) {
if (ops_->copy_file == nullptr) return FileSystem::CopyFile(src, target); if (ops_->copy_file == nullptr) return FileSystem::CopyFile(src, target);
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus); UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
@ -366,7 +384,7 @@ std::string ModularFileSystem::TranslateName(const std::string& name) const {
return ret; return ret;
} }
void ModularFileSystem::FlushCaches() { void ModularFileSystem::FlushCaches(TransactionToken* token) {
if (ops_->flush_caches != nullptr) ops_->flush_caches(filesystem_.get()); if (ops_->flush_caches != nullptr) ops_->flush_caches(filesystem_.get());
} }
@ -443,7 +461,7 @@ Status RegisterFilesystemPlugin(const std::string& dso_path) {
// Step 1: Load plugin // Step 1: Load plugin
Env* env = Env::Default(); Env* env = Env::Default();
void* dso_handle; void* dso_handle;
TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle)); TF_RETURN_IF_ERROR(env->LoadDynamicLibrary(dso_path.c_str(), &dso_handle));
// Step 2: Load symbol for `TF_InitPlugin` // Step 2: Load symbol for `TF_InitPlugin`
void* dso_symbol; void* dso_symbol;

View File

@ -59,36 +59,48 @@ class ModularFileSystem final : public FileSystem {
~ModularFileSystem() override { ops_->cleanup(filesystem_.get()); } ~ModularFileSystem() override { ops_->cleanup(filesystem_.get()); }
TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT;
Status NewRandomAccessFile( Status NewRandomAccessFile(
const std::string& fname, const std::string& fname, TransactionToken* token,
std::unique_ptr<RandomAccessFile>* result) override; std::unique_ptr<RandomAccessFile>* result) override;
Status NewWritableFile(const std::string& fname, Status NewWritableFile(const std::string& fname, TransactionToken* token,
std::unique_ptr<WritableFile>* result) override; std::unique_ptr<WritableFile>* result) override;
Status NewAppendableFile(const std::string& fname, Status NewAppendableFile(const std::string& fname, TransactionToken* token,
std::unique_ptr<WritableFile>* result) override; std::unique_ptr<WritableFile>* result) override;
Status NewReadOnlyMemoryRegionFromFile( Status NewReadOnlyMemoryRegionFromFile(
const std::string& fname, const std::string& fname, TransactionToken* token,
std::unique_ptr<ReadOnlyMemoryRegion>* result) override; std::unique_ptr<ReadOnlyMemoryRegion>* result) override;
Status FileExists(const std::string& fname) override; Status FileExists(const std::string& fname, TransactionToken* token) override;
bool FilesExist(const std::vector<std::string>& files, bool FilesExist(const std::vector<std::string>& files,
TransactionToken* token,
std::vector<Status>* status) override; std::vector<Status>* status) override;
Status GetChildren(const std::string& dir, Status GetChildren(const std::string& dir, TransactionToken* token,
std::vector<std::string>* result) override; std::vector<std::string>* result) override;
Status GetMatchingPaths(const std::string& pattern, Status GetMatchingPaths(const std::string& pattern, TransactionToken* token,
std::vector<std::string>* results) override; std::vector<std::string>* results) override;
Status DeleteFile(const std::string& fname) override; Status DeleteFile(const std::string& fname, TransactionToken* token) override;
Status DeleteRecursively(const std::string& dirname, int64* undeleted_files, Status DeleteRecursively(const std::string& dirname, TransactionToken* token,
int64* undeleted_files,
int64* undeleted_dirs) override; int64* undeleted_dirs) override;
Status DeleteDir(const std::string& dirname) override; Status DeleteDir(const std::string& dirname,
Status RecursivelyCreateDir(const std::string& dirname) override; TransactionToken* token) override;
Status CreateDir(const std::string& dirname) override; Status RecursivelyCreateDir(const std::string& dirname,
Status Stat(const std::string& fname, FileStatistics* stat) override; TransactionToken* token) override;
Status IsDirectory(const std::string& fname) override; Status CreateDir(const std::string& dirname,
Status GetFileSize(const std::string& fname, uint64* file_size) override; TransactionToken* token) override;
Status RenameFile(const std::string& src, const std::string& target) override; Status Stat(const std::string& fname, TransactionToken* token,
Status CopyFile(const std::string& src, const std::string& target) override; FileStatistics* stat) override;
Status IsDirectory(const std::string& fname,
TransactionToken* token) override;
Status GetFileSize(const std::string& fname, TransactionToken* token,
uint64* file_size) override;
Status RenameFile(const std::string& src, const std::string& target,
TransactionToken* token) override;
Status CopyFile(const std::string& src, const std::string& target,
TransactionToken* token) override;
std::string TranslateName(const std::string& name) const override; std::string TranslateName(const std::string& name) const override;
void FlushCaches() override; void FlushCaches(TransactionToken* token) override;
private: private:
std::unique_ptr<TF_Filesystem> filesystem_; std::unique_ptr<TF_Filesystem> filesystem_;

View File

@ -33,7 +33,6 @@ limitations under the License.
// Windows defines the following macros to convert foo to fooA or fooW, // Windows defines the following macros to convert foo to fooA or fooW,
// depending on the type of the string argument. We don't use these macros, so // depending on the type of the string argument. We don't use these macros, so
// undefine them here. // undefine them here.
#undef LoadLibrary
#undef CopyFile #undef CopyFile
#undef DeleteFile #undef DeleteFile
#undef TranslateName #undef TranslateName

View File

@ -25,12 +25,15 @@ cc_library(
"//tensorflow:windows": get_win_copts(), "//tensorflow:windows": get_win_copts(),
}), }),
deps = [ deps = [
":expiring_lru_cache",
":gcs_helper", ":gcs_helper",
":ram_file_block_cache",
"//tensorflow/c:env", "//tensorflow/c:env",
"//tensorflow/c:tf_status", "//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface", "//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client", "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:variant",
], ],
) )
@ -44,14 +47,6 @@ cc_library(
], ],
) )
cc_library(
name = "file_block_cache",
hdrs = ["file_block_cache.h"],
deps = [
"//tensorflow/c:tf_status",
],
)
cc_library( cc_library(
name = "cleanup", name = "cleanup",
hdrs = ["cleanup.h"], hdrs = ["cleanup.h"],
@ -63,7 +58,6 @@ cc_library(
hdrs = ["ram_file_block_cache.h"], hdrs = ["ram_file_block_cache.h"],
deps = [ deps = [
":cleanup", ":cleanup",
":file_block_cache",
"//tensorflow/c:env", "//tensorflow/c:env",
"//tensorflow/c:tf_status", "//tensorflow/c:tf_status",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",

View File

@ -1,140 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_FILE_BLOCK_CACHE_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_FILE_BLOCK_CACHE_H_
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/c/tf_status.h"
namespace tf_gcs_filesystem {
class FileBlockCache;
/// FileBlockCacheStatsInterface allows for instrumentation of the block cache.
///
/// FileBlockCacheStatsInterface and its subclasses must be safe to use from
/// multiple threads concurrently.
///
/// WARNING! This is an experimental interface that may change or go away at any
/// time.
class FileBlockCacheStatsInterface {
public:
/// Configure is called to provide instrumentation hooks.
///
/// Note: Configure can be called multiple times (e.g. if the block cache is
/// re-initialized).
virtual void Configure(const FileBlockCache* block_cache) = 0;
/// RecordBlockLoadRequest is called to record the size of a hit block.
virtual void RecordCacheHitBlockSize(size_t bytes_transferred) = 0;
/// RecordBlockLoadRequest is called to record the size of a missed block.
virtual void RecordCacheMissBlockSize(size_t bytes_transferred) = 0;
virtual ~FileBlockCacheStatsInterface() = default;
};
/// \brief A block cache of file contents, keyed by {filename, offset}.
///
/// This class should be shared by read-only random access files on a remote
/// filesystem (e.g. GCS).
class FileBlockCache {
public:
/// The callback executed when a block is not found in the cache, and needs to
/// be fetched from the backing filesystem. This callback is provided when the
/// cache is constructed. The `status` should be `TF_OK` as long as the
/// read from the remote filesystem succeeded (similar to the semantics of the
/// read(2) system call).
typedef std::function<void(const std::string& filename, size_t offset,
size_t buffer_size, char* buffer,
size_t* bytes_transferred, TF_Status* status)>
BlockFetcher;
virtual ~FileBlockCache() {}
/// Read `n` bytes from `filename` starting at `offset` into `buffer`. This
/// method will set `status` to:
///
/// 1) The error from the remote filesystem, if the read from the remote
/// filesystem failed.
/// 2) `TF_FAILED_PRECONDITION` if the read from the remote filesystem
/// succeeded,
/// but the read returned a partial block, and the LRU cache contained a
/// block at a higher offset (indicating that the partial block should have
/// been a full block).
/// 3) `TF_OUT_OF_RANGE` if the read from the remote filesystem succeeded, but
/// the file contents do not extend past `offset` and thus nothing was
/// placed in `out`.
/// 4) `TF_OK` otherwise (i.e. the read succeeded, and at least one byte was
/// placed
/// in `buffer`).
///
/// Caller is responsible for allocating memory for `buffer`.
/// `buffer` will be left unchanged in case of errors.
virtual void Read(const std::string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) = 0;
// Validate the given file signature with the existing file signature in the
// cache. Returns true if the signature doesn't change or the file did not
// exist before. If the signature changes, update the existing signature with
// the new one and remove the file from cache.
virtual bool ValidateAndUpdateFileSignature(const std::string& filename,
int64_t file_signature) = 0;
/// Remove all cached blocks for `filename`.
virtual void RemoveFile(const std::string& filename) = 0;
/// Remove all cached data.
virtual void Flush() = 0;
/// Accessors for cache parameters.
virtual size_t block_size() const = 0;
virtual size_t max_bytes() const = 0;
virtual uint64_t max_staleness() const = 0;
/// The current size (in bytes) of the cache.
virtual size_t CacheSize() const = 0;
// Returns true if the cache is enabled. If false, the BlockFetcher callback
// is always executed during Read.
virtual bool IsCacheEnabled() const = 0;
void SetStats(FileBlockCacheStatsInterface* stats) {
if (stats == nullptr) {
std::cerr
<< "Attempted to monitor a NULL stats object. This may prevent the "
"corresponding monitoring data from being exported";
return;
}
cache_stats_ = stats;
cache_stats_->Configure(this);
}
protected:
FileBlockCacheStatsInterface* cache_stats_ = nullptr; // Not owned.
};
} // namespace tf_gcs_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_FILE_BLOCK_CACHE_H_

View File

@ -18,6 +18,8 @@ limitations under the License.
#include <string.h> #include <string.h>
#include "absl/strings/numbers.h" #include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/types/variant.h"
#include "google/cloud/storage/client.h" #include "google/cloud/storage/client.h"
#include "tensorflow/c/env.h" #include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h" #include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
@ -27,6 +29,27 @@ limitations under the License.
// This filesystem will support `gs://` URI schemes. // This filesystem will support `gs://` URI schemes.
namespace gcs = google::cloud::storage; namespace gcs = google::cloud::storage;
// The environment variable that overrides the block size for aligned reads from
// GCS. Specified in MB (e.g. "16" = 16 x 1024 x 1024 = 16777216 bytes).
constexpr char kBlockSize[] = "GCS_READ_CACHE_BLOCK_SIZE_MB";
constexpr size_t kDefaultBlockSize = 64 * 1024 * 1024;
// The environment variable that overrides the max size of the LRU cache of
// blocks read from GCS. Specified in MB.
constexpr char kMaxCacheSize[] = "GCS_READ_CACHE_MAX_SIZE_MB";
constexpr size_t kDefaultMaxCacheSize = 0;
// The environment variable that overrides the maximum staleness of cached file
// contents. Once any block of a file reaches this staleness, all cached blocks
// will be evicted on the next read.
constexpr char kMaxStaleness[] = "GCS_READ_CACHE_MAX_STALENESS";
constexpr uint64_t kDefaultMaxStaleness = 0;
constexpr char kStatCacheMaxAge[] = "GCS_STAT_CACHE_MAX_AGE";
constexpr uint64_t kStatCacheDefaultMaxAge = 5;
// The environment variable that overrides the maximum number of entries in the
// Stat cache.
constexpr char kStatCacheMaxEntries[] = "GCS_STAT_CACHE_MAX_ENTRIES";
constexpr size_t kStatCacheDefaultMaxEntries = 1024;
// How to upload new data when Flush() is called multiple times. // How to upload new data when Flush() is called multiple times.
// By default the entire file is reuploaded. // By default the entire file is reuploaded.
constexpr char kAppendMode[] = "GCS_APPEND_MODE"; constexpr char kAppendMode[] = "GCS_APPEND_MODE";
@ -81,28 +104,16 @@ static void MaybeAppendSlash(std::string* name) {
name->push_back('/'); name->push_back('/');
} }
// SECTION 1. Implementation for `TF_RandomAccessFile` // A helper function to actually read the data from GCS.
// ---------------------------------------------------------------------------- static int64_t LoadBufferFromGCS(const std::string& path, size_t offset,
namespace tf_random_access_file { size_t buffer_size, char* buffer,
typedef struct GCSFile { tf_gcs_filesystem::GCSFile* gcs_file,
const std::string bucket; TF_Status* status) {
const std::string object; std::string bucket, object;
gcs::Client* gcs_client; // not owned ParseGCSPath(path, false, &bucket, &object, status);
} GCSFile; if (TF_GetCode(status) != TF_OK) return -1;
auto stream = gcs_file->gcs_client.ReadObject(
void Cleanup(TF_RandomAccessFile* file) { bucket, object, gcs::ReadRange(offset, offset + buffer_size));
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
delete gcs_file;
}
// TODO(vnvo2409): Adding cache.
// `google-cloud-cpp` is working on a feature that we may want to use.
// See https://github.com/googleapis/google-cloud-cpp/issues/4013.
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
auto stream = gcs_file->gcs_client->ReadObject(
gcs_file->bucket, gcs_file->object, gcs::ReadRange(offset, offset + n));
TF_SetStatusFromGCSStatus(stream.status(), status); TF_SetStatusFromGCSStatus(stream.status(), status);
if ((TF_GetCode(status) != TF_OK) && if ((TF_GetCode(status) != TF_OK) &&
(TF_GetCode(status) != TF_OUT_OF_RANGE)) { (TF_GetCode(status) != TF_OUT_OF_RANGE)) {
@ -111,16 +122,119 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
int64_t read; int64_t read;
if (!absl::SimpleAtoi(stream.headers().find("content-length")->second, if (!absl::SimpleAtoi(stream.headers().find("content-length")->second,
&read)) { &read)) {
TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header"); // When we read a file with offset that is bigger than the actual file size.
return -1; // GCS will return an empty header (e.g no `content-length` header). In this
} // case, we will set read to `0` and continue.
if (read != n) { if (TF_GetCode(status) == TF_OUT_OF_RANGE) {
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested"); read = 0;
} else {
TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header");
return -1;
}
} }
// `TF_OUT_OF_RANGE` isn't considered as an error. So we clear it here.
TF_SetStatus(status, TF_OK, "");
stream.read(buffer, read); stream.read(buffer, read);
read = stream.gcount();
if (read < buffer_size) {
// Check stat cache to see if we encountered an interrupted read.
tf_gcs_filesystem::GcsFileStat stat;
if (gcs_file->stat_cache->Lookup(path, &stat)) {
if (offset + read < stat.base.length) {
TF_SetStatus(status, TF_INTERNAL,
absl::StrCat("File contents are inconsistent for file: ",
path, " @ ", offset)
.c_str());
}
}
}
return read; return read;
} }
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
using ReadFn =
std::function<int64_t(const std::string& path, uint64_t offset, size_t n,
char* buffer, TF_Status* status)>;
typedef struct GCSFile {
const std::string path;
const bool is_cache_enable;
const uint64_t buffer_size;
ReadFn read_fn;
absl::Mutex buffer_mutex;
uint64_t buffer_start ABSL_GUARDED_BY(buffer_mutex);
bool buffer_end_is_past_eof ABSL_GUARDED_BY(buffer_mutex);
std::string buffer ABSL_GUARDED_BY(buffer_mutex);
GCSFile(std::string path, bool is_cache_enable, uint64_t buffer_size,
ReadFn read_fn)
: path(path),
is_cache_enable(is_cache_enable),
buffer_size(buffer_size),
read_fn(std::move(read_fn)),
buffer_mutex(),
buffer_start(0),
buffer_end_is_past_eof(false),
buffer() {}
} GCSFile;
void Cleanup(TF_RandomAccessFile* file) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
delete gcs_file;
}
// `google-cloud-cpp` is working on a feature that we may want to use.
// See https://github.com/googleapis/google-cloud-cpp/issues/4013.
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
if (gcs_file->is_cache_enable || n > gcs_file->buffer_size) {
return gcs_file->read_fn(gcs_file->path, offset, n, buffer, status);
} else {
absl::MutexLock l(&gcs_file->buffer_mutex);
size_t buffer_end = gcs_file->buffer_start + gcs_file->buffer.size();
size_t copy_size = 0;
if (offset < buffer_end && gcs_file->buffer_start) {
copy_size = (std::min)(n, static_cast<size_t>(buffer_end - offset));
memcpy(buffer,
gcs_file->buffer.data() + (offset - gcs_file->buffer_start),
copy_size);
}
bool consumed_buffer_to_eof =
offset + copy_size >= buffer_end && gcs_file->buffer_end_is_past_eof;
if (copy_size < n && !consumed_buffer_to_eof) {
gcs_file->buffer_start = offset + copy_size;
gcs_file->buffer.resize(gcs_file->buffer_size);
auto read_fill_buffer = gcs_file->read_fn(
gcs_file->path, gcs_file->buffer_start, gcs_file->buffer_size,
&(gcs_file->buffer[0]), status);
gcs_file->buffer_end_is_past_eof =
(TF_GetCode(status) == TF_OUT_OF_RANGE);
if (read_fill_buffer >= 0) gcs_file->buffer.resize(read_fill_buffer);
if (TF_GetCode(status) != TF_OK &&
TF_GetCode(status) != TF_OUT_OF_RANGE) {
// Empty the buffer to avoid caching bad reads.
gcs_file->buffer.resize(0);
return -1;
}
size_t remaining_copy =
(std::min)(n - copy_size, gcs_file->buffer.size());
memcpy(buffer + copy_size, gcs_file->buffer.data(), remaining_copy);
copy_size += remaining_copy;
}
if (copy_size < n) {
// Forget the end-of-file flag to allow for clients that poll on the
// same file.
gcs_file->buffer_end_is_past_eof = false;
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
return copy_size;
}
TF_SetStatus(status, TF_OK, "");
return copy_size;
}
}
} // namespace tf_random_access_file } // namespace tf_random_access_file
// SECTION 2. Implementation for `TF_WritableFile` // SECTION 2. Implementation for `TF_WritableFile`
@ -289,11 +403,87 @@ uint64_t Length(const TF_ReadOnlyMemoryRegion* region) {
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem // SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
namespace tf_gcs_filesystem { namespace tf_gcs_filesystem {
// TODO(vnvo2409): Add lazy-loading and customizing parameters.
// TODO(vnvo2409): Use partial reponse for better performance. // TODO(vnvo2409): Use partial reponse for better performance.
// TODO(vnvo2409): We could do some cleanups like `return TF_SetStatus`. // TODO(vnvo2409): We could do some cleanups like `return TF_SetStatus`.
// TODO(vnvo2409): Refactor the filesystem implementation when // TODO(vnvo2409): Refactor the filesystem implementation when
// https://github.com/googleapis/google-cloud-cpp/issues/4482 is done. // https://github.com/googleapis/google-cloud-cpp/issues/4482 is done.
GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client)
: gcs_client(gcs_client), block_cache_lock() {
const char* append_mode = std::getenv(kAppendMode);
compose = (append_mode != nullptr) && (!strcmp(kAppendMode, append_mode));
uint64_t value;
block_size = kDefaultBlockSize;
size_t max_bytes = kDefaultMaxCacheSize;
uint64_t max_staleness = kDefaultMaxStaleness;
// Apply the overrides for the block size (MB), max bytes (MB), and max
// staleness (seconds) if provided.
if (absl::SimpleAtoi(std::getenv(kBlockSize), &value)) {
block_size = value * 1024 * 1024;
}
if (absl::SimpleAtoi(std::getenv(kMaxCacheSize), &value)) {
max_bytes = static_cast<size_t>(value * 1024 * 1024);
}
if (absl::SimpleAtoi(std::getenv(kMaxStaleness), &value)) {
max_staleness = value;
}
file_block_cache = std::make_unique<RamFileBlockCache>(
block_size, max_bytes, max_staleness,
[this](const std::string& filename, size_t offset, size_t buffer_size,
char* buffer, TF_Status* status) {
return LoadBufferFromGCS(filename, offset, buffer_size, buffer, this,
status);
});
uint64_t stat_cache_max_age = kStatCacheDefaultMaxAge;
size_t stat_cache_max_entries = kStatCacheDefaultMaxEntries;
if (absl::SimpleAtoi(std::getenv(kStatCacheMaxAge), &value)) {
stat_cache_max_age = value;
}
if (absl::SimpleAtoi(std::getenv(kStatCacheMaxEntries), &value)) {
stat_cache_max_entries = static_cast<size_t>(value);
}
stat_cache = std::make_unique<ExpiringLRUCache<GcsFileStat>>(
stat_cache_max_age, stat_cache_max_entries);
}
GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client, bool compose,
uint64_t block_size, size_t max_bytes, uint64_t max_staleness,
uint64_t stat_cache_max_age, size_t stat_cache_max_entries)
: gcs_client(gcs_client),
compose(compose),
block_cache_lock(),
block_size(block_size) {
file_block_cache = std::make_unique<RamFileBlockCache>(
block_size, max_bytes, max_staleness,
[this](const std::string& filename, size_t offset, size_t buffer_size,
char* buffer, TF_Status* status) {
return LoadBufferFromGCS(filename, offset, buffer_size, buffer, this,
status);
});
stat_cache = std::make_unique<ExpiringLRUCache<GcsFileStat>>(
stat_cache_max_age, stat_cache_max_entries);
}
void InitTest(TF_Filesystem* filesystem, bool compose, uint64_t block_size,
size_t max_bytes, uint64_t max_staleness,
uint64_t stat_cache_max_age, size_t stat_cache_max_entries,
TF_Status* status) {
google::cloud::StatusOr<gcs::Client> client =
gcs::Client::CreateDefaultClient();
if (!client) {
TF_SetStatusFromGCSStatus(client.status(), status);
return;
}
filesystem->plugin_filesystem =
new GCSFile(std::move(client.value()), compose, block_size, max_bytes,
max_staleness, stat_cache_max_age, stat_cache_max_entries);
TF_SetStatus(status, TF_OK, "");
}
void Init(TF_Filesystem* filesystem, TF_Status* status) { void Init(TF_Filesystem* filesystem, TF_Status* status) {
google::cloud::StatusOr<gcs::Client> client = google::cloud::StatusOr<gcs::Client> client =
gcs::Client::CreateDefaultClient(); gcs::Client::CreateDefaultClient();
@ -302,12 +492,7 @@ void Init(TF_Filesystem* filesystem, TF_Status* status) {
return; return;
} }
const char* append_mode = std::getenv(kAppendMode); filesystem->plugin_filesystem = new GCSFile(std::move(client.value()));
bool compose =
(append_mode != nullptr) && (!strcmp(kAppendMode, append_mode));
filesystem->plugin_filesystem =
new GCSFile({std::move(client.value()), compose});
TF_SetStatus(status, TF_OK, ""); TF_SetStatus(status, TF_OK, "");
} }
@ -316,6 +501,19 @@ void Cleanup(TF_Filesystem* filesystem) {
delete gcs_file; delete gcs_file;
} }
static void UncachedStatForObject(const std::string& bucket,
const std::string& object, GcsFileStat* stat,
gcs::Client* gcs_client, TF_Status* status) {
auto metadata = gcs_client->GetObjectMetadata(bucket, object);
if (!metadata) return TF_SetStatusFromGCSStatus(metadata.status(), status);
stat->generation_number = metadata->generation();
stat->base.length = metadata->size();
stat->base.mtime_nsec =
metadata->time_storage_class_updated().time_since_epoch().count();
stat->base.is_directory = object.back() == '/';
return TF_SetStatus(status, TF_OK, "");
}
// TODO(vnvo2409): Implement later // TODO(vnvo2409): Implement later
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path, void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
TF_RandomAccessFile* file, TF_Status* status) { TF_RandomAccessFile* file, TF_Status* status) {
@ -324,8 +522,46 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
if (TF_GetCode(status) != TF_OK) return; if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem); auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
bool is_cache_enabled;
{
absl::MutexLock l(&gcs_file->block_cache_lock);
is_cache_enabled = gcs_file->file_block_cache->IsCacheEnabled();
}
auto read_fn = [gcs_file, is_cache_enabled, bucket, object](
const std::string& path, uint64_t offset, size_t n,
char* buffer, TF_Status* status) -> int64_t {
int64_t read = 0;
if (is_cache_enabled) {
absl::ReaderMutexLock l(&gcs_file->block_cache_lock);
GcsFileStat stat;
gcs_file->stat_cache->LookupOrCompute(
path, &stat,
[gcs_file, bucket, object](const std::string& path, GcsFileStat* stat,
TF_Status* status) {
UncachedStatForObject(bucket, object, stat, &gcs_file->gcs_client,
status);
},
status);
if (TF_GetCode(status) != TF_OK) return -1;
if (!gcs_file->file_block_cache->ValidateAndUpdateFileSignature(
path, stat.generation_number)) {
std::cout
<< "File signature has been changed. Refreshing the cache. Path: "
<< path;
}
read = gcs_file->file_block_cache->Read(path, offset, n, buffer, status);
} else {
read = LoadBufferFromGCS(path, offset, n, buffer, gcs_file, status);
}
if (TF_GetCode(status) != TF_OK) return -1;
if (read < n)
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
else
TF_SetStatus(status, TF_OK, "");
return read;
};
file->plugin_file = new tf_random_access_file::GCSFile( file->plugin_file = new tf_random_access_file::GCSFile(
{std::move(bucket), std::move(object), &gcs_file->gcs_client}); std::move(path), is_cache_enabled, gcs_file->block_size, read_fn);
TF_SetStatus(status, TF_OK, ""); TF_SetStatus(status, TF_OK, "");
} }
@ -428,28 +664,179 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
} }
} }
void CreateDir(const TF_Filesystem* filesystem, const char* path, static void StatForObject(GCSFile* gcs_file, const std::string& path,
TF_Status* status) { const std::string& bucket, const std::string& object,
GcsFileStat* stat, TF_Status* status) {
if (object.empty())
return TF_SetStatus(
status, TF_INVALID_ARGUMENT,
("'object' must be a non-empty string. (File: " + path + ")").c_str());
TF_SetStatus(status, TF_OK, "");
gcs_file->stat_cache->LookupOrCompute(
path, stat,
[gcs_file, bucket, object](const std::string& path, GcsFileStat* stat,
TF_Status* status) {
UncachedStatForObject(bucket, object, stat, &gcs_file->gcs_client,
status);
},
status);
}
static bool ObjectExists(GCSFile* gcs_file, const std::string& path,
const std::string& bucket, const std::string& object,
TF_Status* status) {
GcsFileStat stat;
StatForObject(gcs_file, path, bucket, object, &stat, status);
if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_NOT_FOUND)
return false;
if (TF_GetCode(status) == TF_NOT_FOUND) {
TF_SetStatus(status, TF_OK, "");
return false;
}
return !stat.base.is_directory;
}
static bool BucketExists(GCSFile* gcs_file, const std::string& bucket,
TF_Status* status) {
auto metadata = gcs_file->gcs_client.GetBucketMetadata(bucket);
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_NOT_FOUND)
return false;
if (TF_GetCode(status) == TF_NOT_FOUND) {
TF_SetStatus(status, TF_OK, "");
return false;
}
return true;
}
static std::vector<std::string> GetChildrenBounded(
GCSFile* gcs_file, std::string dir, uint64_t max_results, bool recursive,
bool include_self_directory_marker, TF_Status* status) {
std::string bucket, prefix;
MaybeAppendSlash(&dir);
ParseGCSPath(dir, true, &bucket, &prefix, status);
std::vector<std::string> result;
uint64_t count = 0;
std::string delimiter = recursive ? "" : "/";
for (auto&& item : gcs_file->gcs_client.ListObjectsAndPrefixes(
bucket, gcs::Prefix(prefix), gcs::Delimiter(delimiter))) {
if (count == max_results) {
TF_SetStatus(status, TF_OK, "");
return result;
}
if (!item) {
TF_SetStatusFromGCSStatus(item.status(), status);
return result;
}
auto value = *std::move(item);
std::string children = absl::holds_alternative<std::string>(value)
? absl::get<std::string>(value)
: absl::get<gcs::ObjectMetadata>(value).name();
auto pos = children.find(prefix);
if (pos != 0) {
TF_SetStatus(status, TF_INTERNAL,
("Unexpected response: the returned file name " + children +
" doesn't match the prefix " + prefix)
.c_str());
return result;
}
children.erase(0, prefix.length());
if (!children.empty() || include_self_directory_marker) {
result.emplace_back(children);
}
++count;
}
return result;
}
static bool FolderExists(GCSFile* gcs_file, std::string dir,
TF_Status* status) {
ExpiringLRUCache<GcsFileStat>::ComputeFunc compute_func =
[gcs_file](const std::string& dir, GcsFileStat* stat, TF_Status* status) {
auto children =
GetChildrenBounded(gcs_file, dir, 1, true, true, status);
if (TF_GetCode(status) != TF_OK) return;
if (!children.empty()) {
stat->base = {0, 0, true};
return TF_SetStatus(status, TF_OK, "");
} else {
return TF_SetStatus(status, TF_INVALID_ARGUMENT, "Not a directory!");
}
};
GcsFileStat stat;
MaybeAppendSlash(&dir);
gcs_file->stat_cache->LookupOrCompute(dir, &stat, compute_func, status);
if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_INVALID_ARGUMENT)
return false;
if (TF_GetCode(status) == TF_INVALID_ARGUMENT) {
TF_SetStatus(status, TF_OK, "");
return false;
}
return true;
}
static void ClearFileCaches(GCSFile* gcs_file, const std::string& path) {
absl::ReaderMutexLock l(&gcs_file->block_cache_lock);
gcs_file->file_block_cache->RemoveFile(path);
gcs_file->stat_cache->Delete(path);
}
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string bucket, object; std::string bucket, object;
ParseGCSPath(path, true, &bucket, &object, status); ParseGCSPath(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return; if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem); auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
if (object.empty()) { if (object.empty()) {
auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket); bool result = BucketExists(gcs_file, bucket, status);
TF_SetStatusFromGCSStatus(bucket_metadata.status(), status); if (result) return TF_SetStatus(status, TF_OK, "");
}
GcsFileStat stat;
StatForObject(gcs_file, path, bucket, object, &stat, status);
if (TF_GetCode(status) != TF_NOT_FOUND) return;
bool result = FolderExists(gcs_file, path, status);
if (TF_GetCode(status) != TF_OK || (TF_GetCode(status) == TF_OK && result))
return;
return TF_SetStatus(
status, TF_NOT_FOUND,
absl::StrCat("The path ", path, " does not exist.").c_str());
}
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string dir = path;
MaybeAppendSlash(&dir);
std::string bucket, object;
ParseGCSPath(dir, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
if (object.empty()) {
bool is_directory = BucketExists(gcs_file, bucket, status);
if (TF_GetCode(status) != TF_OK) return;
if (!is_directory)
TF_SetStatus(status, TF_NOT_FOUND,
("The specified bucket " + dir + " was not found.").c_str());
return; return;
} }
MaybeAppendSlash(&object); PathExists(filesystem, dir.c_str(), status);
auto object_metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object); if (TF_GetCode(status) == TF_OK)
TF_SetStatusFromGCSStatus(object_metadata.status(), status); return TF_SetStatus(status, TF_ALREADY_EXISTS, path);
if (TF_GetCode(status) == TF_NOT_FOUND) {
auto insert_metadata = auto metadata = gcs_file->gcs_client.InsertObject(
gcs_file->gcs_client.InsertObject(bucket, object, ""); bucket, object, "",
TF_SetStatusFromGCSStatus(insert_metadata.status(), status); // Adding this parameter means HTTP_CODE_PRECONDITION_FAILED
} else if (TF_GetCode(status) == TF_OK) { // will be returned if the object already exists, so avoid reuploading.
gcs::IfGenerationMatch(0));
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) == TF_FAILED_PRECONDITION)
TF_SetStatus(status, TF_ALREADY_EXISTS, path); TF_SetStatus(status, TF_ALREADY_EXISTS, path);
}
} }
// TODO(vnvo2409): `RecursivelyCreateDir` should use `CreateDir` instead of the // TODO(vnvo2409): `RecursivelyCreateDir` should use `CreateDir` instead of the
@ -465,79 +852,31 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path,
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem); auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket, object); auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket, object);
TF_SetStatusFromGCSStatus(gcs_status, status); TF_SetStatusFromGCSStatus(gcs_status, status);
if (TF_GetCode(status) == TF_OK) ClearFileCaches(gcs_file, path);
} }
// Checks that the directory is empty (i.e no objects with this prefix exist).
// Deletes the GCS directory marker if it exists.
void DeleteDir(const TF_Filesystem* filesystem, const char* path, void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) { TF_Status* status) {
std::string bucket, object; // A directory is considered empty either if there are no matching objects
ParseGCSPath(path, false, &bucket, &object, status); // with the corresponding name prefix or if there is exactly one matching
if (TF_GetCode(status) != TF_OK) return; // object and it is the directory marker. Therefore we need to retrieve
MaybeAppendSlash(&object); // at most two children for the prefix to detect if a directory is empty.
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem); auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
int object_count = 0; auto childrens = GetChildrenBounded(gcs_file, path, 2, true, true, status);
for (auto&& metadata :
gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) {
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
}
++object_count;
// We consider a path is a non-empty directory in two cases:
// - There are more than two objects whose keys start with the name of this
// directory.
// - There is one object whose key contains the name of this directory ( but
// not equal ).
if (object_count > 1 || metadata->name() != object) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Cannot delete a non-empty directory.");
return;
}
}
auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket, object);
TF_SetStatusFromGCSStatus(gcs_status, status);
}
// TODO(vnvo2409): `DeleteRecursively` needs `GetChildrens` but there will be
// some differents compared to the default implementation. Will be refactored.
static void DeleteRecursively(const TF_Filesystem* filesystem, const char* path,
uint64_t* undeleted_files,
uint64_t* undeleted_dirs, TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return; if (TF_GetCode(status) != TF_OK) return;
if (childrens.size() > 1 || (childrens.size() == 1 && !childrens[0].empty()))
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem); return TF_SetStatus(status, TF_FAILED_PRECONDITION,
auto gcs_status = gcs::DeleteByPrefix(gcs_file->gcs_client, bucket, object); "Cannot delete a non-empty directory.");
TF_SetStatusFromGCSStatus(gcs_status, status); if (childrens.size() == 1 && childrens[0].empty()) {
if (TF_GetCode(status) != TF_OK) return; // This is the directory marker object. Delete it.
*undeleted_dirs = 0; std::string dir = path;
*undeleted_files = 0; MaybeAppendSlash(&dir);
} DeleteFile(filesystem, dir.c_str(), status);
// TODO(vnvo2409): `RewriteObjectBlocking` will set `status` to `TF_NOT_FOUND`
// if the object does not exist. In that case, we will have to check if the
// `src` is a directory or not to set the correspondent `status` (i.e
// `TF_NOT_FOUND` if path `src` does not exist, `TF_FAILED_PRECONDITION` if
// path `src` is a directory).
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
std::string bucket_src, object_src;
ParseGCSPath(src, false, &bucket_src, &object_src, status);
if (TF_GetCode(status) != TF_OK) return;
std::string bucket_dst, object_dst;
ParseGCSPath(dst, false, &bucket_dst, &object_dst, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto metadata = gcs_file->gcs_client.RewriteObjectBlocking(
bucket_src, object_src, bucket_dst, object_dst);
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return; return;
} }
auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket_src, object_src); TF_SetStatus(status, TF_OK, "");
TF_SetStatusFromGCSStatus(gcs_status, status);
} }
void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst, void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
@ -556,6 +895,183 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
TF_SetStatusFromGCSStatus(metadata.status(), status); TF_SetStatusFromGCSStatus(metadata.status(), status);
} }
bool IsDirectory(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return false;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
if (object.empty()) {
bool result = BucketExists(gcs_file, bucket, status);
if (TF_GetCode(status) != TF_OK) return false;
if (!result)
TF_SetStatus(
status, TF_NOT_FOUND,
("The specified bucket gs://" + bucket + " was not found.").c_str());
return result;
}
bool is_folder = FolderExists(gcs_file, path, status);
if (TF_GetCode(status) != TF_OK) return false;
if (is_folder) return true;
bool is_object = ObjectExists(gcs_file, path, bucket, object, status);
if (TF_GetCode(status) != TF_OK) return false;
if (is_object) {
TF_SetStatus(
status, TF_FAILED_PRECONDITION,
absl::StrCat("The specified path ", path, " is not a directory.")
.c_str());
return false;
}
TF_SetStatus(status, TF_NOT_FOUND,
absl::StrCat("The path ", path, " does not exist.").c_str());
return false;
}
static void RenameObject(const TF_Filesystem* filesystem,
const std::string& src, const std::string& dst,
TF_Status* status) {
std::string bucket_src, object_src;
ParseGCSPath(src, false, &bucket_src, &object_src, status);
if (TF_GetCode(status) != TF_OK) return;
std::string bucket_dst, object_dst;
ParseGCSPath(dst, false, &bucket_dst, &object_dst, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto metadata = gcs_file->gcs_client.RewriteObjectBlocking(
bucket_src, object_src, bucket_dst, object_dst);
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) != TF_OK) return;
ClearFileCaches(gcs_file, dst);
DeleteFile(filesystem, src.c_str(), status);
}
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
if (!IsDirectory(filesystem, src, status)) {
if (TF_GetCode(status) == TF_FAILED_PRECONDITION)
RenameObject(filesystem, src, dst, status);
return;
}
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
std::vector<std::string> childrens =
GetChildrenBounded(gcs_file, src, UINT64_MAX, true, true, status);
if (TF_GetCode(status) != TF_OK) return;
std::string src_dir = src;
std::string dst_dir = dst;
MaybeAppendSlash(&src_dir);
MaybeAppendSlash(&dst_dir);
for (const std::string& children : childrens) {
RenameObject(filesystem, src_dir + children, dst_dir + children, status);
if (TF_GetCode(status) != TF_OK) return;
}
TF_SetStatus(status, TF_OK, "");
}
void DeleteRecursively(const TF_Filesystem* filesystem, const char* path,
uint64_t* undeleted_files, uint64_t* undeleted_dirs,
TF_Status* status) {
if (!undeleted_files || !undeleted_dirs)
return TF_SetStatus(
status, TF_INTERNAL,
"'undeleted_files' and 'undeleted_dirs' cannot be nullptr.");
*undeleted_files = 0;
*undeleted_dirs = 0;
if (!IsDirectory(filesystem, path, status)) {
*undeleted_dirs = 1;
return;
}
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
std::vector<std::string> childrens =
GetChildrenBounded(gcs_file, path, UINT64_MAX, true, true, status);
if (TF_GetCode(status) != TF_OK) return;
std::string dir = path;
MaybeAppendSlash(&dir);
for (const std::string& children : childrens) {
const std::string& full_path = dir + children;
DeleteFile(filesystem, full_path.c_str(), status);
if (TF_GetCode(status) != TF_OK) {
if (IsDirectory(filesystem, full_path.c_str(), status))
// The object is a directory marker.
(*undeleted_dirs)++;
else
(*undeleted_files)++;
}
}
}
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
std::vector<std::string> childrens =
GetChildrenBounded(gcs_file, path, UINT64_MAX, false, false, status);
if (TF_GetCode(status) != TF_OK) return -1;
int num_entries = childrens.size();
*entries = static_cast<char**>(
plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
for (int i = 0; i < num_entries; i++)
(*entries)[i] = strdup(childrens[i].c_str());
TF_SetStatus(status, TF_OK, "");
return num_entries;
}
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
if (object.empty()) {
auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket);
TF_SetStatusFromGCSStatus(bucket_metadata.status(), status);
if (TF_GetCode(status) == TF_OK) {
stats->is_directory = true;
stats->length = 0;
stats->mtime_nsec = 0;
}
return;
}
if (IsDirectory(filesystem, path, status)) {
stats->is_directory = true;
stats->length = 0;
stats->mtime_nsec = 0;
return TF_SetStatus(status, TF_OK, "");
}
if (TF_GetCode(status) == TF_OK) {
auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
if (metadata) {
stats->is_directory = false;
stats->length = metadata.value().size();
stats->mtime_nsec = metadata.value()
.time_storage_class_updated()
.time_since_epoch()
.count();
}
TF_SetStatusFromGCSStatus(metadata.status(), status);
}
}
static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) {
return strdup(uri);
}
static void FlushCaches(const TF_Filesystem* filesystem) {
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
absl::ReaderMutexLock l(&gcs_file->block_cache_lock);
gcs_file->file_block_cache->Flush();
gcs_file->stat_cache->Clear();
}
} // namespace tf_gcs_filesystem } // namespace tf_gcs_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
@ -572,6 +1088,13 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE)); plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup; ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->read_only_memory_region_ops = static_cast<TF_ReadOnlyMemoryRegionOps*>(
plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
ops->read_only_memory_region_ops->cleanup =
tf_read_only_memory_region::Cleanup;
ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data;
ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>( ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE)); plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_gcs_filesystem::Init; ops->filesystem_ops->init = tf_gcs_filesystem::Init;
@ -581,6 +1104,20 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile; ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file = ops->filesystem_ops->new_appendable_file =
tf_gcs_filesystem::NewAppendableFile; tf_gcs_filesystem::NewAppendableFile;
ops->filesystem_ops->new_read_only_memory_region_from_file =
tf_gcs_filesystem::NewReadOnlyMemoryRegionFromFile;
ops->filesystem_ops->create_dir = tf_gcs_filesystem::CreateDir;
ops->filesystem_ops->delete_file = tf_gcs_filesystem::DeleteFile;
ops->filesystem_ops->delete_dir = tf_gcs_filesystem::DeleteDir;
ops->filesystem_ops->delete_recursively =
tf_gcs_filesystem::DeleteRecursively;
ops->filesystem_ops->copy_file = tf_gcs_filesystem::CopyFile;
ops->filesystem_ops->path_exists = tf_gcs_filesystem::PathExists;
ops->filesystem_ops->is_directory = tf_gcs_filesystem::IsDirectory;
ops->filesystem_ops->stat = tf_gcs_filesystem::Stat;
ops->filesystem_ops->get_children = tf_gcs_filesystem::GetChildren;
ops->filesystem_ops->translate_name = tf_gcs_filesystem::TranslateName;
ops->filesystem_ops->flush_caches = tf_gcs_filesystem::FlushCaches;
} }
void TF_InitPlugin(TF_FilesystemPluginInfo* info) { void TF_InitPlugin(TF_FilesystemPluginInfo* info) {

View File

@ -17,6 +17,8 @@
#include "google/cloud/storage/client.h" #include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" #include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
void ParseGCSPath(const std::string& fname, bool object_empty_ok, void ParseGCSPath(const std::string& fname, bool object_empty_ok,
@ -45,10 +47,34 @@ uint64_t Length(const TF_ReadOnlyMemoryRegion* region);
} // namespace tf_read_only_memory_region } // namespace tf_read_only_memory_region
namespace tf_gcs_filesystem { namespace tf_gcs_filesystem {
typedef struct GcsFileStat {
TF_FileStatistics base;
int64_t generation_number;
} GcsFileStat;
typedef struct GCSFile { typedef struct GCSFile {
google::cloud::storage::Client gcs_client; // owned google::cloud::storage::Client gcs_client; // owned
bool compose; bool compose;
absl::Mutex block_cache_lock;
std::shared_ptr<RamFileBlockCache> file_block_cache
ABSL_GUARDED_BY(block_cache_lock);
uint64_t block_size; // Reads smaller than block_size will trigger a read
// of block_size.
std::unique_ptr<ExpiringLRUCache<GcsFileStat>> stat_cache;
GCSFile(google::cloud::storage::Client&& gcs_client);
// This constructor is used for testing purpose only.
GCSFile(google::cloud::storage::Client&& gcs_client, bool compose,
uint64_t block_size, size_t max_bytes, uint64_t max_staleness,
uint64_t stat_cache_max_age, size_t stat_cache_max_entries);
} GCSFile; } GCSFile;
// This function is used to initialize a filesystem without the need of setting
// manually environement variables.
void InitTest(TF_Filesystem* filesystem, bool compose, uint64_t block_size,
size_t max_bytes, uint64_t max_staleness,
uint64_t stat_cache_max_age, size_t stat_cache_max_entries,
TF_Status* status);
void Init(TF_Filesystem* filesystem, TF_Status* status); void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem); void Cleanup(TF_Filesystem* filesystem);
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path, void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,

View File

@ -66,6 +66,9 @@ static std::string* GetTmpDir() {
namespace tensorflow { namespace tensorflow {
namespace { namespace {
// TODO(vnvo2409): Refactor `gcs_filesystem_test` to remove unnecessary tests
// after porting all tests from
// `//tensorflow/core/platform/cloud:gcs_file_system_test`.
class GCSFilesystemTest : public ::testing::Test { class GCSFilesystemTest : public ::testing::Test {
public: public:
void SetUp() override { void SetUp() override {
@ -74,13 +77,14 @@ class GCSFilesystemTest : public ::testing::Test {
::testing::UnitTest::GetInstance()->current_test_info()->name()); ::testing::UnitTest::GetInstance()->current_test_info()->name());
status_ = TF_NewStatus(); status_ = TF_NewStatus();
filesystem_ = new TF_Filesystem; filesystem_ = new TF_Filesystem;
tf_gcs_filesystem::Init(filesystem_, status_); filesystem_->plugin_filesystem = nullptr;
ASSERT_TF_OK(status_) << "Could not initialize filesystem. " // Because different tests requires different setup for filesystem. We
<< TF_Message(status_); // initialize filesystem in each testcase.
} }
void TearDown() override { void TearDown() override {
TF_DeleteStatus(status_); TF_DeleteStatus(status_);
tf_gcs_filesystem::Cleanup(filesystem_); if (filesystem_->plugin_filesystem != nullptr)
tf_gcs_filesystem::Cleanup(filesystem_);
delete filesystem_; delete filesystem_;
} }
@ -117,6 +121,21 @@ class GCSFilesystemTest : public ::testing::Test {
} }
} }
::testing::AssertionResult InsertObject(const std::string& path,
const std::string& content,
gcs::Client* gcs_client,
TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK)
return ::testing::AssertionFailure() << TF_Message(status);
auto metadata = gcs_client->InsertObject(bucket, object, content);
if (metadata)
return ::testing::AssertionSuccess();
else
return ::testing::AssertionFailure() << metadata.status().message();
}
::testing::AssertionResult CompareSubString(int64_t offset, size_t length, ::testing::AssertionResult CompareSubString(int64_t offset, size_t length,
absl::string_view result, absl::string_view result,
size_t read) { size_t read) {
@ -172,6 +191,9 @@ TEST_F(GCSFilesystemTest, ParseGCSPath) {
} }
TEST_F(GCSFilesystemTest, RandomAccessFile) { TEST_F(GCSFilesystemTest, RandomAccessFile) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
std::string filepath = GetURIForPath("a_file"); std::string filepath = GetURIForPath("a_file");
TF_RandomAccessFile* file = new TF_RandomAccessFile; TF_RandomAccessFile* file = new TF_RandomAccessFile;
tf_gcs_filesystem::NewRandomAccessFile(filesystem_, filepath.c_str(), file, tf_gcs_filesystem::NewRandomAccessFile(filesystem_, filepath.c_str(), file,
@ -208,6 +230,9 @@ TEST_F(GCSFilesystemTest, RandomAccessFile) {
} }
TEST_F(GCSFilesystemTest, WritableFile) { TEST_F(GCSFilesystemTest, WritableFile) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
std::string filepath = GetURIForPath("a_file"); std::string filepath = GetURIForPath("a_file");
TF_WritableFile* file = new TF_WritableFile; TF_WritableFile* file = new TF_WritableFile;
tf_gcs_filesystem::NewWritableFile(filesystem_, filepath.c_str(), file, tf_gcs_filesystem::NewWritableFile(filesystem_, filepath.c_str(), file,
@ -273,6 +298,9 @@ TEST_F(GCSFilesystemTest, WritableFile) {
} }
TEST_F(GCSFilesystemTest, ReadOnlyMemoryRegion) { TEST_F(GCSFilesystemTest, ReadOnlyMemoryRegion) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
std::string path = GetURIForPath("a_file"); std::string path = GetURIForPath("a_file");
auto gcs_file = auto gcs_file =
static_cast<tf_gcs_filesystem::GCSFile*>(filesystem_->plugin_filesystem); static_cast<tf_gcs_filesystem::GCSFile*>(filesystem_->plugin_filesystem);
@ -298,6 +326,131 @@ TEST_F(GCSFilesystemTest, ReadOnlyMemoryRegion) {
delete region; delete region;
} }
// These tests below are ported from
// `//tensorflow/core/platform/cloud:gcs_file_system_test`
TEST_F(GCSFilesystemTest, NewRandomAccessFile_NoBlockCache) {
tf_gcs_filesystem::InitTest(filesystem_, false, 0, 0, 0, 0, 0, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
std::string path = GetURIForPath("a_file");
auto gcs_file =
static_cast<tf_gcs_filesystem::GCSFile*>(filesystem_->plugin_filesystem);
ASSERT_TRUE(InsertObject(path, "0123456789", &gcs_file->gcs_client, status_));
TF_RandomAccessFile* file = new TF_RandomAccessFile;
tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), file,
status_);
ASSERT_TF_OK(status_);
std::string result;
result.resize(6);
int64_t read = tf_random_access_file::Read(file, 0, 6, &result[0], status_);
ASSERT_EQ(read, 6) << "Read: " << read << "\n";
ASSERT_TF_OK(status_);
ASSERT_EQ(result, "012345") << "Result: " << result << "\n";
read = tf_random_access_file::Read(file, 6, 6, &result[0], status_);
ASSERT_EQ(read, 4) << "Read: " << read << "\n";
ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_);
result.resize(read);
ASSERT_EQ(result, "6789") << "Result: " << result << "\n";
}
TEST_F(GCSFilesystemTest, NewRandomAccessFile_Buffered) {
tf_gcs_filesystem::InitTest(filesystem_, false, 10, 0, 0, 0, 0, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
std::string path = GetURIForPath("a_file");
auto gcs_file =
static_cast<tf_gcs_filesystem::GCSFile*>(filesystem_->plugin_filesystem);
ASSERT_TRUE(InsertObject(path, "0123456789", &gcs_file->gcs_client, status_));
TF_RandomAccessFile* file = new TF_RandomAccessFile;
tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), file,
status_);
ASSERT_TF_OK(status_);
std::string result;
result.resize(6);
int64_t read = tf_random_access_file::Read(file, 0, 6, &result[0], status_);
ASSERT_EQ(read, 6) << "Read: " << read << "\n";
ASSERT_TF_OK(status_);
ASSERT_EQ(result, "012345") << "Result: " << result << "\n";
read = tf_random_access_file::Read(file, 6, 6, &result[0], status_);
ASSERT_EQ(read, 4) << "Read: " << read << "\n";
ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_);
result.resize(read);
ASSERT_EQ(result, "6789") << "Result: " << result << "\n";
}
TEST_F(GCSFilesystemTest, NewRandomAccessFile_Buffered_ReadAtEOF) {
tf_gcs_filesystem::InitTest(filesystem_, false, 10, 0, 0, 0, 0, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
std::string path = GetURIForPath("a_file");
auto gcs_file =
static_cast<tf_gcs_filesystem::GCSFile*>(filesystem_->plugin_filesystem);
ASSERT_TRUE(InsertObject(path, "0123456789", &gcs_file->gcs_client, status_));
TF_RandomAccessFile* file = new TF_RandomAccessFile;
tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), file,
status_);
ASSERT_TF_OK(status_);
std::string result;
result.resize(10);
int64_t read = tf_random_access_file::Read(file, 0, result.length(),
&result[0], status_);
ASSERT_EQ(read, 10) << "Read: " << read << "\n";
ASSERT_TF_OK(status_);
ASSERT_EQ(result, "0123456789") << "Result: " << result << "\n";
read = tf_random_access_file::Read(file, result.length(), result.length(),
&result[0], status_);
ASSERT_EQ(read, 0) << "Read: " << read << "\n";
ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_);
result.resize(read);
ASSERT_EQ(result, "") << "Result: " << result << "\n";
}
TEST_F(GCSFilesystemTest, NewRandomAccessFile_Buffered_CachedOutOfRange) {
tf_gcs_filesystem::InitTest(filesystem_, false, 10, 0, 0, 0, 0, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
std::string path = GetURIForPath("a_file");
auto gcs_file =
static_cast<tf_gcs_filesystem::GCSFile*>(filesystem_->plugin_filesystem);
ASSERT_TRUE(InsertObject(path, "012345678", &gcs_file->gcs_client, status_));
TF_RandomAccessFile* file = new TF_RandomAccessFile;
tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), file,
status_);
ASSERT_TF_OK(status_);
std::string result;
result.resize(5);
int64_t read = tf_random_access_file::Read(file, 0, result.length(),
&result[0], status_);
ASSERT_EQ(read, 5) << "Read: " << read << "\n";
ASSERT_TF_OK(status_);
ASSERT_EQ(result, "01234") << "Result: " << result << "\n";
read = tf_random_access_file::Read(file, 4, result.length(), &result[0],
status_);
ASSERT_EQ(read, 5) << "Read: " << read << "\n";
ASSERT_TF_OK(status_);
result.resize(read);
ASSERT_EQ(result, "45678") << "Result: " << result << "\n";
read = tf_random_access_file::Read(file, 5, result.length(), &result[0],
status_);
ASSERT_EQ(read, 4) << "Read: " << read << "\n";
ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_);
result.resize(read);
ASSERT_EQ(result, "5678") << "Result: " << result << "\n";
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -39,9 +39,6 @@ std::shared_ptr<RamFileBlockCache::Block> RamFileBlockCache::Lookup(
auto entry = block_map_.find(key); auto entry = block_map_.find(key);
if (entry != block_map_.end()) { if (entry != block_map_.end()) {
if (BlockNotStale(entry->second)) { if (BlockNotStale(entry->second)) {
if (cache_stats_ != nullptr) {
cache_stats_->RecordCacheHitBlockSize(entry->second->data.size());
}
return entry->second; return entry->second;
} else { } else {
// Remove the stale block and continue. // Remove the stale block and continue.
@ -136,12 +133,9 @@ void RamFileBlockCache::MaybeFetch(const Key& key,
block->mu.Unlock(); // Release the lock while making the API call. block->mu.Unlock(); // Release the lock while making the API call.
block->data.clear(); block->data.clear();
block->data.resize(block_size_, 0); block->data.resize(block_size_, 0);
size_t bytes_transferred; int64_t bytes_transferred;
block_fetcher_(key.first, key.second, block_size_, block->data.data(), bytes_transferred = block_fetcher_(key.first, key.second, block_size_,
&bytes_transferred, status); block->data.data(), status);
if (cache_stats_ != nullptr) {
cache_stats_->RecordCacheMissBlockSize(bytes_transferred);
}
block->mu.Lock(); // Reacquire the lock immediately afterwards block->mu.Lock(); // Reacquire the lock immediately afterwards
if (TF_GetCode(status) == TF_OK) { if (TF_GetCode(status) == TF_OK) {
block->data.resize(bytes_transferred, 0); block->data.resize(bytes_transferred, 0);
@ -171,18 +165,16 @@ void RamFileBlockCache::MaybeFetch(const Key& key,
"Control flow should never reach the end of RamFileBlockCache::Fetch."); "Control flow should never reach the end of RamFileBlockCache::Fetch.");
} }
void RamFileBlockCache::Read(const std::string& filename, size_t offset, int64_t RamFileBlockCache::Read(const std::string& filename, size_t offset,
size_t n, char* buffer, size_t* bytes_transferred, size_t n, char* buffer, TF_Status* status) {
TF_Status* status) {
*bytes_transferred = 0;
if (n == 0) { if (n == 0) {
return TF_SetStatus(status, TF_OK, ""); TF_SetStatus(status, TF_OK, "");
return 0;
} }
if (!IsCacheEnabled() || (n > max_bytes_)) { if (!IsCacheEnabled() || (n > max_bytes_)) {
// The cache is effectively disabled, so we pass the read through to the // The cache is effectively disabled, so we pass the read through to the
// fetcher without breaking it up into blocks. // fetcher without breaking it up into blocks.
return block_fetcher_(filename, offset, n, buffer, bytes_transferred, return block_fetcher_(filename, offset, n, buffer, status);
status);
} }
// Calculate the block-aligned start and end of the read. // Calculate the block-aligned start and end of the read.
size_t start = block_size_ * (offset / block_size_); size_t start = block_size_ * (offset / block_size_);
@ -202,20 +194,20 @@ void RamFileBlockCache::Read(const std::string& filename, size_t offset,
abort(); abort();
} }
MaybeFetch(key, block, status); MaybeFetch(key, block, status);
if (TF_GetCode(status) != TF_OK) return; if (TF_GetCode(status) != TF_OK) return -1;
UpdateLRU(key, block, status); UpdateLRU(key, block, status);
if (TF_GetCode(status) != TF_OK) return; if (TF_GetCode(status) != TF_OK) return -1;
// Copy the relevant portion of the block into the result buffer. // Copy the relevant portion of the block into the result buffer.
const auto& data = block->data; const auto& data = block->data;
if (offset >= pos + data.size()) { if (offset >= pos + data.size()) {
// The requested offset is at or beyond the end of the file. This can // The requested offset is at or beyond the end of the file. This can
// happen if `offset` is not block-aligned, and the read returns the last // happen if `offset` is not block-aligned, and the read returns the last
// block in the file, which does not extend all the way out to `offset`. // block in the file, which does not extend all the way out to `offset`.
*bytes_transferred = total_bytes_transferred;
std::stringstream os; std::stringstream os;
os << "EOF at offset " << offset << " in file " << filename os << "EOF at offset " << offset << " in file " << filename
<< " at position " << pos << " with data size " << data.size(); << " at position " << pos << " with data size " << data.size();
return TF_SetStatus(status, TF_OUT_OF_RANGE, std::move(os).str().c_str()); TF_SetStatus(status, TF_OUT_OF_RANGE, std::move(os).str().c_str());
return total_bytes_transferred;
} }
auto begin = data.begin(); auto begin = data.begin();
if (offset > pos) { if (offset > pos) {
@ -237,8 +229,8 @@ void RamFileBlockCache::Read(const std::string& filename, size_t offset,
break; break;
} }
} }
*bytes_transferred = total_bytes_transferred; TF_SetStatus(status, TF_OK, "");
return TF_SetStatus(status, TF_OK, ""); return total_bytes_transferred;
} }
bool RamFileBlockCache::ValidateAndUpdateFileSignature( bool RamFileBlockCache::ValidateAndUpdateFileSignature(

View File

@ -28,7 +28,6 @@ limitations under the License.
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h" #include "absl/synchronization/notification.h"
#include "tensorflow/c/env.h" #include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/file_block_cache.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
namespace tf_gcs_filesystem { namespace tf_gcs_filesystem {
@ -37,16 +36,17 @@ namespace tf_gcs_filesystem {
/// ///
/// This class should be shared by read-only random access files on a remote /// This class should be shared by read-only random access files on a remote
/// filesystem (e.g. GCS). /// filesystem (e.g. GCS).
class RamFileBlockCache : public FileBlockCache { class RamFileBlockCache {
public: public:
/// The callback executed when a block is not found in the cache, and needs to /// The callback executed when a block is not found in the cache, and needs to
/// be fetched from the backing filesystem. This callback is provided when the /// be fetched from the backing filesystem. This callback is provided when the
/// cache is constructed. The `status` should be `TF_OK` as long as the /// cache is constructed. It returns total bytes read ( -1 in case of errors
/// read from the remote filesystem succeeded (similar to the semantics of the /// ). The `status` should be `TF_OK` as long as the read from the remote
/// read(2) system call). /// filesystem succeeded (similar to the semantics of the read(2) system
typedef std::function<void(const std::string& filename, size_t offset, /// call).
size_t buffer_size, char* buffer, typedef std::function<int64_t(const std::string& filename, size_t offset,
size_t* bytes_transferred, TF_Status* status)> size_t buffer_size, char* buffer,
TF_Status* status)>
BlockFetcher; BlockFetcher;
RamFileBlockCache(size_t block_size, size_t max_bytes, uint64_t max_staleness, RamFileBlockCache(size_t block_size, size_t max_bytes, uint64_t max_staleness,
@ -66,10 +66,10 @@ class RamFileBlockCache : public FileBlockCache {
TF_StartThread(&thread_options, "TF_prune_FBC", PruneThread, this)); TF_StartThread(&thread_options, "TF_prune_FBC", PruneThread, this));
} }
std::cout << "GCS file block cache is " std::cout << "GCS file block cache is "
<< (IsCacheEnabled() ? "enabled" : "disabled"); << (IsCacheEnabled() ? "enabled" : "disabled") << ".\n";
} }
~RamFileBlockCache() override { ~RamFileBlockCache() {
if (pruning_thread_) { if (pruning_thread_) {
stop_pruning_thread_.Notify(); stop_pruning_thread_.Notify();
// Destroying pruning_thread_ will block until Prune() receives the above // Destroying pruning_thread_ will block until Prune() receives the above
@ -78,8 +78,9 @@ class RamFileBlockCache : public FileBlockCache {
} }
} }
/// Read `n` bytes from `filename` starting at `offset` into `buffer`. This /// Read `n` bytes from `filename` starting at `offset` into `buffer`. It
/// method will set `status` to: /// returns total bytes read ( -1 in case of errors ). This method will set
/// `status` to:
/// ///
/// 1) The error from the remote filesystem, if the read from the remote /// 1) The error from the remote filesystem, if the read from the remote
/// filesystem failed. /// filesystem failed.
@ -97,37 +98,34 @@ class RamFileBlockCache : public FileBlockCache {
/// ///
/// Caller is responsible for allocating memory for `buffer`. /// Caller is responsible for allocating memory for `buffer`.
/// `buffer` will be left unchanged in case of errors. /// `buffer` will be left unchanged in case of errors.
void Read(const std::string& filename, size_t offset, size_t n, char* buffer, int64_t Read(const std::string& filename, size_t offset, size_t n,
size_t* bytes_transferred, TF_Status* status) override; char* buffer, TF_Status* status);
// Validate the given file signature with the existing file signature in the // Validate the given file signature with the existing file signature in the
// cache. Returns true if the signature doesn't change or the file doesn't // cache. Returns true if the signature doesn't change or the file doesn't
// exist before. If the signature changes, update the existing signature with // exist before. If the signature changes, update the existing signature with
// the new one and remove the file from cache. // the new one and remove the file from cache.
bool ValidateAndUpdateFileSignature(const std::string& filename, bool ValidateAndUpdateFileSignature(const std::string& filename,
int64_t file_signature) override int64_t file_signature)
ABSL_LOCKS_EXCLUDED(mu_); ABSL_LOCKS_EXCLUDED(mu_);
/// Remove all cached blocks for `filename`. /// Remove all cached blocks for `filename`.
void RemoveFile(const std::string& filename) override void RemoveFile(const std::string& filename) ABSL_LOCKS_EXCLUDED(mu_);
ABSL_LOCKS_EXCLUDED(mu_);
/// Remove all cached data. /// Remove all cached data.
void Flush() override ABSL_LOCKS_EXCLUDED(mu_); void Flush() ABSL_LOCKS_EXCLUDED(mu_);
/// Accessors for cache parameters. /// Accessors for cache parameters.
size_t block_size() const override { return block_size_; } size_t block_size() const { return block_size_; }
size_t max_bytes() const override { return max_bytes_; } size_t max_bytes() const { return max_bytes_; }
uint64_t max_staleness() const override { return max_staleness_; } uint64_t max_staleness() const { return max_staleness_; }
/// The current size (in bytes) of the cache. /// The current size (in bytes) of the cache.
size_t CacheSize() const override ABSL_LOCKS_EXCLUDED(mu_); size_t CacheSize() const ABSL_LOCKS_EXCLUDED(mu_);
// Returns true if the cache is enabled. If false, the BlockFetcher callback // Returns true if the cache is enabled. If false, the BlockFetcher callback
// is always executed during Read. // is always executed during Read.
bool IsCacheEnabled() const override { bool IsCacheEnabled() const { return block_size_ > 0 && max_bytes_ > 0; }
return block_size_ > 0 && max_bytes_ > 0;
}
// We can not pass a lambda with capture as a function pointer to // We can not pass a lambda with capture as a function pointer to
// `TF_StartThread`, so we have to wrap `Prune` inside a static function. // `TF_StartThread`, so we have to wrap `Prune` inside a static function.

View File

@ -33,20 +33,22 @@ Status ReadCache(tf_gcs_filesystem::RamFileBlockCache* cache,
std::vector<char>* out) { std::vector<char>* out) {
out->clear(); out->clear();
out->resize(n, 0); out->resize(n, 0);
size_t bytes_transferred = 0;
TF_Status status; TF_Status status;
cache->Read(filename, offset, n, out->data(), &bytes_transferred, &status); auto bytes_transferred =
EXPECT_LE(bytes_transferred, n); cache->Read(filename, offset, n, out->data(), &status);
out->resize(bytes_transferred, n); if (bytes_transferred >= 0) {
EXPECT_LE(bytes_transferred, n);
out->resize(bytes_transferred, n);
}
return status.status; return status.status;
} }
TEST(RamFileBlockCacheTest, IsCacheEnabled) { TEST(RamFileBlockCacheTest, IsCacheEnabled) {
auto fetcher = [](const string& filename, size_t offset, size_t n, auto fetcher = [](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred, char* buffer, TF_Status* status) -> int64_t {
TF_Status* status) {
// Do nothing. // Do nothing.
return TF_SetStatus(status, TF_OK, ""); TF_SetStatus(status, TF_OK, "");
return 0;
}; };
tf_gcs_filesystem::RamFileBlockCache cache1(0, 0, 0, fetcher); tf_gcs_filesystem::RamFileBlockCache cache1(0, 0, 0, fetcher);
tf_gcs_filesystem::RamFileBlockCache cache2(16, 0, 0, fetcher); tf_gcs_filesystem::RamFileBlockCache cache2(16, 0, 0, fetcher);
@ -62,12 +64,11 @@ TEST(RamFileBlockCacheTest, IsCacheEnabled) {
TEST(RamFileBlockCacheTest, ValidateAndUpdateFileSignature) { TEST(RamFileBlockCacheTest, ValidateAndUpdateFileSignature) {
int calls = 0; int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n, auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred, char* buffer, TF_Status* status) -> int64_t {
TF_Status* status) {
calls++; calls++;
memset(buffer, 'x', n); memset(buffer, 'x', n);
*bytes_transferred = n; TF_SetStatus(status, TF_OK, "");
return TF_SetStatus(status, TF_OK, ""); return n;
}; };
string filename = "file"; string filename = "file";
tf_gcs_filesystem::RamFileBlockCache cache(16, 32, 0, fetcher); tf_gcs_filesystem::RamFileBlockCache cache(16, 32, 0, fetcher);
@ -96,15 +97,14 @@ TEST(RamFileBlockCacheTest, PassThrough) {
int calls = 0; int calls = 0;
auto fetcher = [&calls, want_filename, want_offset, want_n]( auto fetcher = [&calls, want_filename, want_offset, want_n](
const string& got_filename, size_t got_offset, const string& got_filename, size_t got_offset,
size_t got_n, char* buffer, size_t* bytes_transferred, size_t got_n, char* buffer, TF_Status* status) -> int64_t {
TF_Status* status) {
EXPECT_EQ(got_filename, want_filename); EXPECT_EQ(got_filename, want_filename);
EXPECT_EQ(got_offset, want_offset); EXPECT_EQ(got_offset, want_offset);
EXPECT_EQ(got_n, want_n); EXPECT_EQ(got_n, want_n);
calls++; calls++;
memset(buffer, 'x', got_n); memset(buffer, 'x', got_n);
*bytes_transferred = got_n; TF_SetStatus(status, TF_OK, "");
return TF_SetStatus(status, TF_OK, ""); return got_n;
}; };
// If block_size, max_bytes, or both are zero, or want_n is larger than // If block_size, max_bytes, or both are zero, or want_n is larger than
// max_bytes the cache is a pass-through. // max_bytes the cache is a pass-through.
@ -133,16 +133,17 @@ TEST(RamFileBlockCacheTest, BlockAlignment) {
} }
// The fetcher just fetches slices of the buffer. // The fetcher just fetches slices of the buffer.
auto fetcher = [&buf](const string& filename, size_t offset, size_t n, auto fetcher = [&buf](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred, char* buffer, TF_Status* status) -> int64_t {
TF_Status* status) { int64_t bytes_transferred;
if (offset < buf.size()) { if (offset < buf.size()) {
size_t bytes_to_copy = std::min<size_t>(buf.size() - offset, n); size_t bytes_to_copy = std::min<size_t>(buf.size() - offset, n);
memcpy(buffer, buf.data() + offset, bytes_to_copy); memcpy(buffer, buf.data() + offset, bytes_to_copy);
*bytes_transferred = bytes_to_copy; bytes_transferred = bytes_to_copy;
} else { } else {
*bytes_transferred = 0; bytes_transferred = 0;
} }
return TF_SetStatus(status, TF_OK, ""); TF_SetStatus(status, TF_OK, "");
return bytes_transferred;
}; };
for (size_t block_size = 2; block_size <= 4; block_size++) { for (size_t block_size = 2; block_size <= 4; block_size++) {
// Make a cache of N-byte block size (1 block) and verify that reads of // Make a cache of N-byte block size (1 block) and verify that reads of
@ -181,15 +182,14 @@ TEST(RamFileBlockCacheTest, CacheHits) {
std::set<size_t> calls; std::set<size_t> calls;
auto fetcher = [&calls, block_size](const string& filename, size_t offset, auto fetcher = [&calls, block_size](const string& filename, size_t offset,
size_t n, char* buffer, size_t n, char* buffer,
size_t* bytes_transferred, TF_Status* status) -> int64_t {
TF_Status* status) {
EXPECT_EQ(n, block_size); EXPECT_EQ(n, block_size);
EXPECT_EQ(offset % block_size, 0); EXPECT_EQ(offset % block_size, 0);
EXPECT_EQ(calls.find(offset), calls.end()) << "at offset " << offset; EXPECT_EQ(calls.find(offset), calls.end()) << "at offset " << offset;
calls.insert(offset); calls.insert(offset);
memset(buffer, 'x', n); memset(buffer, 'x', n);
*bytes_transferred = n; TF_SetStatus(status, TF_OK, "");
return TF_SetStatus(status, TF_OK, ""); return n;
}; };
const uint32 block_count = 256; const uint32 block_count = 256;
tf_gcs_filesystem::RamFileBlockCache cache( tf_gcs_filesystem::RamFileBlockCache cache(
@ -215,8 +215,7 @@ TEST(RamFileBlockCacheTest, OutOfRange) {
bool second_block = false; bool second_block = false;
auto fetcher = [block_size, file_size, &first_block, &second_block]( auto fetcher = [block_size, file_size, &first_block, &second_block](
const string& filename, size_t offset, size_t n, const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred, char* buffer, TF_Status* status) -> int64_t {
TF_Status* status) {
EXPECT_EQ(n, block_size); EXPECT_EQ(n, block_size);
EXPECT_EQ(offset % block_size, 0); EXPECT_EQ(offset % block_size, 0);
size_t bytes_to_copy = 0; size_t bytes_to_copy = 0;
@ -231,8 +230,8 @@ TEST(RamFileBlockCacheTest, OutOfRange) {
memset(buffer, 'x', bytes_to_copy); memset(buffer, 'x', bytes_to_copy);
second_block = true; second_block = true;
} }
*bytes_transferred = bytes_to_copy; TF_SetStatus(status, TF_OK, "");
return TF_SetStatus(status, TF_OK, ""); return bytes_to_copy;
}; };
tf_gcs_filesystem::RamFileBlockCache cache(block_size, block_size, 0, tf_gcs_filesystem::RamFileBlockCache cache(block_size, block_size, 0,
fetcher); fetcher);
@ -260,14 +259,13 @@ TEST(RamFileBlockCacheTest, Inconsistent) {
const size_t block_size = 16; const size_t block_size = 16;
// This fetcher returns OK but only fills in one byte for any offset. // This fetcher returns OK but only fills in one byte for any offset.
auto fetcher = [block_size](const string& filename, size_t offset, size_t n, auto fetcher = [block_size](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred, char* buffer, TF_Status* status) -> int64_t {
TF_Status* status) {
EXPECT_EQ(n, block_size); EXPECT_EQ(n, block_size);
EXPECT_EQ(offset % block_size, 0); EXPECT_EQ(offset % block_size, 0);
EXPECT_GE(n, 1); EXPECT_GE(n, 1);
memset(buffer, 'x', 1); memset(buffer, 'x', 1);
*bytes_transferred = 1; TF_SetStatus(status, TF_OK, "");
return TF_SetStatus(status, TF_OK, ""); return 1;
}; };
tf_gcs_filesystem::RamFileBlockCache cache(block_size, 2 * block_size, 0, tf_gcs_filesystem::RamFileBlockCache cache(block_size, 2 * block_size, 0,
fetcher); fetcher);
@ -286,8 +284,7 @@ TEST(RamFileBlockCacheTest, LRU) {
std::list<size_t> calls; std::list<size_t> calls;
auto fetcher = [&calls, block_size](const string& filename, size_t offset, auto fetcher = [&calls, block_size](const string& filename, size_t offset,
size_t n, char* buffer, size_t n, char* buffer,
size_t* bytes_transferred, TF_Status* status) -> int64_t {
TF_Status* status) {
EXPECT_EQ(n, block_size); EXPECT_EQ(n, block_size);
EXPECT_FALSE(calls.empty()) << "at offset = " << offset; EXPECT_FALSE(calls.empty()) << "at offset = " << offset;
if (!calls.empty()) { if (!calls.empty()) {
@ -295,8 +292,8 @@ TEST(RamFileBlockCacheTest, LRU) {
calls.pop_front(); calls.pop_front();
} }
memset(buffer, 'x', n); memset(buffer, 'x', n);
*bytes_transferred = n; TF_SetStatus(status, TF_OK, "");
return TF_SetStatus(status, TF_OK, ""); return n;
}; };
const uint32 block_count = 2; const uint32 block_count = 2;
tf_gcs_filesystem::RamFileBlockCache cache( tf_gcs_filesystem::RamFileBlockCache cache(
@ -335,12 +332,11 @@ TEST(RamFileBlockCacheTest, LRU) {
TEST(RamFileBlockCacheTest, MaxStaleness) { TEST(RamFileBlockCacheTest, MaxStaleness) {
int calls = 0; int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n, auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred, char* buffer, TF_Status* status) -> int64_t {
TF_Status* status) {
calls++; calls++;
memset(buffer, 'x', n); memset(buffer, 'x', n);
*bytes_transferred = n; TF_SetStatus(status, TF_OK, "");
return TF_SetStatus(status, TF_OK, ""); return n;
}; };
std::vector<char> out; std::vector<char> out;
std::unique_ptr<NowSecondsEnv> env(new NowSecondsEnv); std::unique_ptr<NowSecondsEnv> env(new NowSecondsEnv);
@ -380,8 +376,7 @@ TEST(RamFileBlockCacheTest, MaxStaleness) {
TEST(RamFileBlockCacheTest, RemoveFile) { TEST(RamFileBlockCacheTest, RemoveFile) {
int calls = 0; int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n, auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred, char* buffer, TF_Status* status) -> int64_t {
TF_Status* status) {
calls++; calls++;
char c = (filename == "a") ? 'a' : (filename == "b") ? 'b' : 'x'; char c = (filename == "a") ? 'a' : (filename == "b") ? 'b' : 'x';
if (offset > 0) { if (offset > 0) {
@ -389,8 +384,8 @@ TEST(RamFileBlockCacheTest, RemoveFile) {
c = toupper(c); c = toupper(c);
} }
memset(buffer, c, n); memset(buffer, c, n);
*bytes_transferred = n; TF_SetStatus(status, TF_OK, "");
return TF_SetStatus(status, TF_OK, ""); return n;
}; };
// This cache has space for 4 blocks; we'll read from two files. // This cache has space for 4 blocks; we'll read from two files.
const size_t n = 3; const size_t n = 3;
@ -443,12 +438,11 @@ TEST(RamFileBlockCacheTest, RemoveFile) {
TEST(RamFileBlockCacheTest, Prune) { TEST(RamFileBlockCacheTest, Prune) {
int calls = 0; int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n, auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred, char* buffer, TF_Status* status) -> int64_t {
TF_Status* status) {
calls++; calls++;
memset(buffer, 'x', n); memset(buffer, 'x', n);
*bytes_transferred = n; TF_SetStatus(status, TF_OK, "");
return TF_SetStatus(status, TF_OK, ""); return n;
}; };
std::vector<char> out; std::vector<char> out;
// Our fake environment is initialized with the current timestamp. // Our fake environment is initialized with the current timestamp.
@ -509,17 +503,17 @@ TEST(RamFileBlockCacheTest, ParallelReads) {
const int callers = 4; const int callers = 4;
BlockingCounter counter(callers); BlockingCounter counter(callers);
auto fetcher = [&counter](const string& filename, size_t offset, size_t n, auto fetcher = [&counter](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred, char* buffer, TF_Status* status) -> int64_t {
TF_Status* status) {
counter.DecrementCount(); counter.DecrementCount();
if (!counter.WaitFor(std::chrono::seconds(10))) { if (!counter.WaitFor(std::chrono::seconds(10))) {
// This avoids having the test time out, which is harder to debug. // This avoids having the test time out, which is harder to debug.
return TF_SetStatus(status, TF_FAILED_PRECONDITION, TF_SetStatus(status, TF_FAILED_PRECONDITION,
"desired concurrency not reached"); "desired concurrency not reached");
return -1;
} }
memset(buffer, 'x', n); memset(buffer, 'x', n);
*bytes_transferred = n; TF_SetStatus(status, TF_OK, "");
return TF_SetStatus(status, TF_OK, ""); return n;
}; };
const int block_size = 8; const int block_size = 8;
tf_gcs_filesystem::RamFileBlockCache cache( tf_gcs_filesystem::RamFileBlockCache cache(
@ -548,17 +542,16 @@ TEST(RamFileBlockCacheTest, CoalesceConcurrentReads) {
Notification notification; Notification notification;
auto fetcher = [&num_requests, &notification, block_size]( auto fetcher = [&num_requests, &notification, block_size](
const string& filename, size_t offset, size_t n, const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred, char* buffer, TF_Status* status) -> int64_t {
TF_Status* status) {
EXPECT_EQ(n, block_size); EXPECT_EQ(n, block_size);
EXPECT_EQ(offset, 0); EXPECT_EQ(offset, 0);
num_requests++; num_requests++;
memset(buffer, 'x', n); memset(buffer, 'x', n);
*bytes_transferred = n;
notification.Notify(); notification.Notify();
// Wait for other thread to issue read. // Wait for other thread to issue read.
Env::Default()->SleepForMicroseconds(100000); // 0.1 secs Env::Default()->SleepForMicroseconds(100000); // 0.1 secs
return TF_SetStatus(status, TF_OK, ""); TF_SetStatus(status, TF_OK, "");
return n;
}; };
tf_gcs_filesystem::RamFileBlockCache cache(block_size, block_size, 0, tf_gcs_filesystem::RamFileBlockCache cache(block_size, block_size, 0,
fetcher); fetcher);
@ -580,12 +573,11 @@ TEST(RamFileBlockCacheTest, CoalesceConcurrentReads) {
TEST(RamFileBlockCacheTest, Flush) { TEST(RamFileBlockCacheTest, Flush) {
int calls = 0; int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n, auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred, char* buffer, TF_Status* status) -> int64_t {
TF_Status* status) {
calls++; calls++;
memset(buffer, 'x', n); memset(buffer, 'x', n);
*bytes_transferred = n; TF_SetStatus(status, TF_OK, "");
return TF_SetStatus(status, TF_OK, ""); return n;
}; };
tf_gcs_filesystem::RamFileBlockCache cache(16, 32, 0, fetcher); tf_gcs_filesystem::RamFileBlockCache cache(16, 32, 0, fetcher);
std::vector<char> out; std::vector<char> out;

View File

@ -0,0 +1,35 @@
# Experimental hadoop filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
package(
licenses = ["notice"], # Apache 2.0
)
# Filesystem implementation for HADOOP environments
tf_cc_shared_object(
name = "hadoop_filesystem",
framework_so = [],
linkstatic = False,
per_os_targets = 1,
visibility = ["//visibility:public"],
deps = [":hadoop_filesystem_impl"],
)
# The real implementation of the filesystem.
cc_library(
name = "hadoop_filesystem_impl",
srcs = ["hadoop_filesystem.cc"],
hdrs = ["hadoop_filesystem.h"],
copts = select({
"//conditions:default": [],
"//tensorflow:windows": get_win_copts(),
}),
deps = [
"//tensorflow/c:env",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"//third_party/hadoop:hdfs",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization",
],
)

View File

@ -0,0 +1,660 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h"
#include <stdlib.h>
#include <string.h>
#include <functional>
#include <iostream>
#include <sstream>
#include <string>
#include "absl/synchronization/mutex.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
#include "third_party/hadoop/hdfs.h"
// Implementation of a filesystem for HADOOP environments.
// This filesystem will support `hdfs://`, `viewfs://` and `har://` URI schemes.
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
void ParseHadoopPath(const std::string& fname, std::string* scheme,
std::string* namenode, std::string* path) {
size_t scheme_end = fname.find("://") + 2;
*scheme = fname.substr(0, scheme_end + 1);
size_t nn_end = fname.find("/", scheme_end + 1);
if (nn_end == std::string::npos) return;
*namenode = fname.substr(scheme_end + 1, nn_end - scheme_end - 1);
*path = fname.substr(nn_end + 1);
}
void SplitArchiveNameAndPath(std::string* path, std::string* nn,
TF_Status* status) {
size_t index_end_archive_name = path->find(".har");
if (index_end_archive_name == path->npos) {
return TF_SetStatus(
status, TF_INVALID_ARGUMENT,
"Hadoop archive path does not contain a .har extension");
}
// Case of hadoop archive. Namenode is the path to the archive.
std::ostringstream namenodestream;
namenodestream << "har://" << nn
<< path->substr(0, index_end_archive_name + 4);
*nn = namenodestream.str();
path->erase(0, index_end_archive_name + 4);
if (path->empty())
// Root of the archive
*path = "/";
return TF_SetStatus(status, TF_OK, "");
}
template <typename R, typename... Args>
void BindFunc(void* handle, const char* name, std::function<R(Args...)>* func,
TF_Status* status) {
*func = reinterpret_cast<R (*)(Args...)>(
TF_GetSymbolFromLibrary(handle, name, status));
}
class LibHDFS {
public:
explicit LibHDFS(TF_Status* status) { LoadAndBind(status); }
std::function<hdfsFS(hdfsBuilder*)> hdfsBuilderConnect;
std::function<hdfsBuilder*()> hdfsNewBuilder;
std::function<void(hdfsBuilder*, const char*)> hdfsBuilderSetNameNode;
std::function<int(const char*, char**)> hdfsConfGetStr;
std::function<int(hdfsFS, hdfsFile)> hdfsCloseFile;
std::function<tSize(hdfsFS, hdfsFile, tOffset, void*, tSize)> hdfsPread;
std::function<tSize(hdfsFS, hdfsFile, const void*, tSize)> hdfsWrite;
std::function<int(hdfsFS, hdfsFile)> hdfsHFlush;
std::function<int(hdfsFS, hdfsFile)> hdfsHSync;
std::function<tOffset(hdfsFS, hdfsFile)> hdfsTell;
std::function<hdfsFile(hdfsFS, const char*, int, int, short, tSize)>
hdfsOpenFile;
std::function<int(hdfsFS, const char*)> hdfsExists;
std::function<hdfsFileInfo*(hdfsFS, const char*, int*)> hdfsListDirectory;
std::function<void(hdfsFileInfo*, int)> hdfsFreeFileInfo;
std::function<int(hdfsFS, const char*, int recursive)> hdfsDelete;
std::function<int(hdfsFS, const char*)> hdfsCreateDirectory;
std::function<hdfsFileInfo*(hdfsFS, const char*)> hdfsGetPathInfo;
std::function<int(hdfsFS, const char*, const char*)> hdfsRename;
private:
void LoadAndBind(TF_Status* status) {
auto TryLoadAndBind = [this](const char* name, void** handle,
TF_Status* status) {
*handle = TF_LoadSharedLibrary(name, status);
if (TF_GetCode(status) != TF_OK) return;
#define BIND_HDFS_FUNC(function) \
do { \
BindFunc(*handle, #function, &function, status); \
if (TF_GetCode(status) != TF_OK) return; \
} while (0);
BIND_HDFS_FUNC(hdfsBuilderConnect);
BIND_HDFS_FUNC(hdfsNewBuilder);
BIND_HDFS_FUNC(hdfsBuilderSetNameNode);
BIND_HDFS_FUNC(hdfsConfGetStr);
BIND_HDFS_FUNC(hdfsCloseFile);
BIND_HDFS_FUNC(hdfsPread);
BIND_HDFS_FUNC(hdfsWrite);
BIND_HDFS_FUNC(hdfsHFlush);
BIND_HDFS_FUNC(hdfsTell);
BIND_HDFS_FUNC(hdfsHSync);
BIND_HDFS_FUNC(hdfsOpenFile);
BIND_HDFS_FUNC(hdfsExists);
BIND_HDFS_FUNC(hdfsListDirectory);
BIND_HDFS_FUNC(hdfsFreeFileInfo);
BIND_HDFS_FUNC(hdfsDelete);
BIND_HDFS_FUNC(hdfsCreateDirectory);
BIND_HDFS_FUNC(hdfsGetPathInfo);
BIND_HDFS_FUNC(hdfsRename);
#undef BIND_HDFS_FUNC
};
// libhdfs.so won't be in the standard locations. Use the path as specified
// in the libhdfs documentation.
#if defined(_WIN32)
constexpr char kLibHdfsDso[] = "hdfs.dll";
#elif defined(__GNUC__) && (defined(__APPLE_CPP__) || defined(__APPLE_CC__) || \
defined(__MACOS_CLASSIC__))
constexpr char kLibHdfsDso[] = "libhdfs.dylib";
#else
constexpr char kLibHdfsDso[] = "libhdfs.so";
#endif
char* hdfs_home = getenv("HADOOP_HDFS_HOME");
if (hdfs_home != nullptr) {
auto JoinPath = [](std::string home, std::string lib) {
if (home.back() != '/') home.push_back('/');
return home + "lib/native/" + lib;
};
std::string path = JoinPath(hdfs_home, kLibHdfsDso);
TryLoadAndBind(path.c_str(), &handle_, status);
if (TF_GetCode(status) == TF_OK) {
return;
} else {
std::cerr << "HadoopFileSystem load error: " << TF_Message(status);
}
}
// Try to load the library dynamically in case it has been installed
// to a in non-standard location.
TryLoadAndBind(kLibHdfsDso, &handle_, status);
}
void* handle_;
};
// We rely on HDFS connection caching here. The HDFS client calls
// org.apache.hadoop.fs.FileSystem.get(), which caches the connection
// internally.
hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) {
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
hdfsBuilder* builder = libhdfs->hdfsNewBuilder();
if (scheme == "file") {
libhdfs->hdfsBuilderSetNameNode(builder, nullptr);
} else if (scheme == "viewfs") {
char* defaultFS = nullptr;
libhdfs->hdfsConfGetStr("fs.defaultFS", &defaultFS);
std::string defaultScheme, defaultCluster, defaultPath;
ParseHadoopPath(defaultFS, &defaultScheme, &defaultCluster, &defaultPath);
if (scheme != defaultScheme ||
(namenode.empty() && namenode != defaultCluster)) {
TF_SetStatus(status, TF_UNIMPLEMENTED,
"viewfs is only supported as a fs.defaultFS.");
return nullptr;
}
// The default NameNode configuration will be used (from the XML
// configuration files). See:
// https://github.com/tensorflow/tensorflow/blob/v1.0.0/third_party/hadoop/hdfs.h#L259
libhdfs->hdfsBuilderSetNameNode(builder, "default");
} else if (scheme == "har") {
std::string path_har = path;
SplitArchiveNameAndPath(&path_har, &namenode, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str());
} else {
libhdfs->hdfsBuilderSetNameNode(
builder, namenode.empty() ? "default" : namenode.c_str());
}
auto fs = libhdfs->hdfsBuilderConnect(builder);
if (fs == nullptr)
TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno));
else
TF_SetStatus(status, TF_OK, "");
return fs;
}
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
typedef struct HDFSFile {
std::string path;
std::string hdfs_path;
hdfsFS fs;
LibHDFS* libhdfs;
absl::Mutex mu;
hdfsFile handle ABSL_GUARDED_BY(mu);
HDFSFile(std::string path, std::string hdfs_path, hdfsFS fs, LibHDFS* libhdfs,
hdfsFile handle)
: path(std::move(path)),
hdfs_path(std::move(hdfs_path)),
fs(fs),
libhdfs(libhdfs),
mu(),
handle(handle) {}
} HDFSFile;
void Cleanup(TF_RandomAccessFile* file) {
auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
{
absl::MutexLock l(&hdfs_file->mu);
if (hdfs_file->handle != nullptr) {
hdfs_file->libhdfs->hdfsCloseFile(hdfs_file->fs, hdfs_file->handle);
}
}
delete hdfs_file;
}
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
auto libhdfs = hdfs_file->libhdfs;
auto fs = hdfs_file->fs;
auto hdfs_path = hdfs_file->hdfs_path.c_str();
auto path = hdfs_file->path.c_str();
char* dst = buffer;
bool eof_retried = false;
int64_t r = 0;
while (TF_GetCode(status) == TF_OK && !eof_retried) {
// We lock inside the loop rather than outside so we don't block other
// concurrent readers.
absl::MutexLock l(&hdfs_file->mu);
auto handle = hdfs_file->handle;
// Max read length is INT_MAX-2, for hdfsPread function take a parameter
// of int32. -2 offset can avoid JVM OutOfMemoryError.
size_t read_n =
(std::min)(n, static_cast<size_t>(std::numeric_limits<int>::max() - 2));
r = libhdfs->hdfsPread(fs, handle, static_cast<tOffset>(offset), dst,
static_cast<tSize>(read_n));
if (r > 0) {
dst += r;
n -= r;
offset += r;
} else if (!eof_retried && r == 0) {
// Always reopen the file upon reaching EOF to see if there's more data.
// If writers are streaming contents while others are concurrently
// reading, HDFS requires that we reopen the file to see updated
// contents.
//
// Fixes #5438
if (handle != nullptr && libhdfs->hdfsCloseFile(fs, handle) != 0) {
TF_SetStatusFromIOError(status, errno, path);
return -1;
}
handle = libhdfs->hdfsOpenFile(fs, hdfs_path, O_RDONLY, 0, 0, 0);
if (handle == nullptr) {
TF_SetStatusFromIOError(status, errno, path);
return -1;
}
eof_retried = true;
} else if (eof_retried && r == 0) {
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
} else if (errno == EINTR || errno == EAGAIN) {
// hdfsPread may return EINTR too. Just retry.
} else {
TF_SetStatusFromIOError(status, errno, path);
}
}
return r;
}
} // namespace tf_random_access_file
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
typedef struct HDFSFile {
std::string hdfs_path;
hdfsFS fs;
LibHDFS* libhdfs;
hdfsFile handle;
HDFSFile(std::string hdfs_path, hdfsFS fs, LibHDFS* libhdfs, hdfsFile handle)
: hdfs_path(std::move(hdfs_path)),
fs(fs),
libhdfs(libhdfs),
handle(handle) {}
} HDFSFile;
static void Cleanup(TF_WritableFile* file) {
auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
hdfs_file->libhdfs->hdfsCloseFile(hdfs_file->fs, hdfs_file->handle);
hdfs_file->fs = nullptr;
hdfs_file->handle = nullptr;
delete hdfs_file;
}
void Append(const TF_WritableFile* file, const char* buffer, size_t n,
TF_Status* status) {
auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
auto libhdfs = hdfs_file->libhdfs;
auto fs = hdfs_file->fs;
auto handle = hdfs_file->handle;
size_t cur_pos = 0, write_len = 0;
bool retry = false;
// max() - 2 can avoid OutOfMemoryError in JVM .
static const size_t max_len_once =
static_cast<size_t>(std::numeric_limits<tSize>::max() - 2);
while (cur_pos < n) {
write_len = (std::min)(n - cur_pos, max_len_once);
tSize w = libhdfs->hdfsWrite(fs, handle, buffer + cur_pos,
static_cast<tSize>(write_len));
if (w == -1) {
if (!retry && (errno == EINTR || errno == EAGAIN)) {
retry = true;
} else {
return TF_SetStatusFromIOError(status, errno,
hdfs_file->hdfs_path.c_str());
}
} else {
cur_pos += w;
}
}
TF_SetStatus(status, TF_OK, "");
}
int64_t Tell(const TF_WritableFile* file, TF_Status* status) {
auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
int64_t position =
hdfs_file->libhdfs->hdfsTell(hdfs_file->fs, hdfs_file->handle);
if (position == -1)
TF_SetStatusFromIOError(status, errno, hdfs_file->hdfs_path.c_str());
else
TF_SetStatus(status, TF_OK, "");
return position;
}
void Flush(const TF_WritableFile* file, TF_Status* status) {
auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
if (hdfs_file->libhdfs->hdfsHFlush(hdfs_file->fs, hdfs_file->handle) != 0)
TF_SetStatusFromIOError(status, errno, hdfs_file->hdfs_path.c_str());
else
TF_SetStatus(status, TF_OK, "");
}
void Sync(const TF_WritableFile* file, TF_Status* status) {
auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
if (hdfs_file->libhdfs->hdfsHSync(hdfs_file->fs, hdfs_file->handle) != 0)
TF_SetStatusFromIOError(status, errno, hdfs_file->hdfs_path.c_str());
else
TF_SetStatus(status, TF_OK, "");
}
void Close(const TF_WritableFile* file, TF_Status* status) {
auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
TF_SetStatus(status, TF_OK, "");
if (hdfs_file->libhdfs->hdfsCloseFile(hdfs_file->fs, hdfs_file->handle) != 0)
TF_SetStatusFromIOError(status, errno, hdfs_file->hdfs_path.c_str());
hdfs_file->fs = nullptr;
hdfs_file->handle = nullptr;
}
} // namespace tf_writable_file
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
// ----------------------------------------------------------------------------
namespace tf_read_only_memory_region {
// TODO(vnvo2409): Implement later
} // namespace tf_read_only_memory_region
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_hadoop_filesystem {
void Init(TF_Filesystem* filesystem, TF_Status* status) {
filesystem->plugin_filesystem = new LibHDFS(status);
if (TF_GetCode(status) != TF_OK) return;
TF_SetStatus(status, TF_OK, "");
}
void Cleanup(TF_Filesystem* filesystem) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
delete libhdfs;
}
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
TF_RandomAccessFile* file, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
auto handle = libhdfs->hdfsOpenFile(fs, hdfs_path.c_str(), O_RDONLY, 0, 0, 0);
if (handle == nullptr) return TF_SetStatusFromIOError(status, errno, path);
file->plugin_file =
new tf_random_access_file::HDFSFile(path, hdfs_path, fs, libhdfs, handle);
TF_SetStatus(status, TF_OK, "");
}
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
auto handle = libhdfs->hdfsOpenFile(fs, hdfs_path.c_str(),
O_WRONLY | O_APPEND, 0, 0, 0);
if (handle == nullptr) return TF_SetStatusFromIOError(status, errno, path);
file->plugin_file =
new tf_writable_file::HDFSFile(hdfs_path, fs, libhdfs, handle);
TF_SetStatus(status, TF_OK, "");
}
void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
const char* path,
TF_ReadOnlyMemoryRegion* region,
TF_Status* status) {
// hadoopReadZero() technically supports this call with the following
// caveats:
// - It only works up to 2 GB. We'd have to Stat() the file to ensure that
// it fits.
// - If not on the local filesystem, the entire file will be read, making
// it inefficient for callers that assume typical mmap() behavior.
TF_SetStatus(status, TF_UNIMPLEMENTED,
"HDFS does not support ReadOnlyMemoryRegion");
}
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
if (libhdfs->hdfsExists(fs, hdfs_path.c_str()) == 0)
TF_SetStatus(status, TF_OK, "");
else
TF_SetStatus(status, TF_NOT_FOUND,
(std::string(path) + " not found").c_str());
}
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
auto info = libhdfs->hdfsGetPathInfo(fs, hdfs_path.c_str());
if (info == nullptr) return TF_SetStatusFromIOError(status, errno, path);
stats->length = static_cast<int64_t>(info->mSize);
stats->mtime_nsec = static_cast<int64_t>(info->mLastMod) * 1e9;
stats->is_directory = info->mKind == kObjectKindDirectory;
libhdfs->hdfsFreeFileInfo(info, 1);
TF_SetStatus(status, TF_OK, "");
}
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return -1;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
auto info = libhdfs->hdfsGetPathInfo(fs, hdfs_path.c_str());
if (info == nullptr) {
TF_SetStatusFromIOError(status, errno, path);
return -1;
}
TF_SetStatus(status, TF_OK, "");
auto size = static_cast<int64_t>(info->mSize);
libhdfs->hdfsFreeFileInfo(info, 1);
return size;
}
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
if (libhdfs->hdfsDelete(fs, hdfs_path.c_str(), /*recursive=*/0) != 0)
TF_SetStatusFromIOError(status, errno, path);
else
TF_SetStatus(status, TF_OK, "");
}
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
if (libhdfs->hdfsCreateDirectory(fs, hdfs_path.c_str()) != 0)
TF_SetStatusFromIOError(status, errno, path);
else
TF_SetStatus(status, TF_OK, "");
}
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
// Count the number of entries in the directory, and only delete if it's
// non-empty. This is consistent with the interface, but note that there's
// a race condition where a file may be added after this check, in which
// case the directory will still be deleted.
int entries = 0;
auto info = libhdfs->hdfsListDirectory(fs, hdfs_path.c_str(), &entries);
if (info != nullptr) libhdfs->hdfsFreeFileInfo(info, entries);
// Due to HDFS bug HDFS-8407, we can't distinguish between an error and empty
// folder, especially for Kerberos enable setup, EAGAIN is quite common when
// the call is actually successful. Check again by Stat.
if (info == nullptr && errno != 0) {
TF_FileStatistics stat;
Stat(filesystem, path, &stat, status);
if (TF_GetCode(status) != TF_OK) return;
}
if (entries > 0)
return TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Cannot delete a non-empty directory.");
if (libhdfs->hdfsDelete(fs, hdfs_path.c_str(), /*recursive=*/1) != 0)
TF_SetStatusFromIOError(status, errno, path);
else
TF_SetStatus(status, TF_OK, "");
}
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, src, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path_src, hdfs_path_dst;
ParseHadoopPath(src, &scheme, &namenode, &hdfs_path_src);
ParseHadoopPath(dst, &scheme, &namenode, &hdfs_path_dst);
if (libhdfs->hdfsExists(fs, hdfs_path_dst.c_str()) == 0 &&
libhdfs->hdfsDelete(fs, hdfs_path_dst.c_str(), /*recursive=*/0) != 0)
return TF_SetStatusFromIOError(status, errno, dst);
if (libhdfs->hdfsRename(fs, hdfs_path_src.c_str(), hdfs_path_dst.c_str()) !=
0)
TF_SetStatusFromIOError(status, errno, src);
else
TF_SetStatus(status, TF_OK, "");
}
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return -1;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
// hdfsListDirectory returns nullptr if the directory is empty. Do a separate
// check to verify the directory exists first.
TF_FileStatistics stat;
Stat(filesystem, path, &stat, status);
if (TF_GetCode(status) != TF_OK) return -1;
int num_entries = 0;
auto info = libhdfs->hdfsListDirectory(fs, hdfs_path.c_str(), &num_entries);
if (info == nullptr) {
if (stat.is_directory) {
// Assume it's an empty directory.
TF_SetStatus(status, TF_OK, "");
return 0;
}
TF_SetStatusFromIOError(status, errno, path);
return -1;
}
*entries = static_cast<char**>(
plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
auto BaseName = [](const std::string& name) {
return name.substr(name.find_last_of('/') + 1);
};
for (int i = 0; i < num_entries; i++) {
(*entries)[i] = strdup(BaseName(info[i].mName).c_str());
}
libhdfs->hdfsFreeFileInfo(info, num_entries);
TF_SetStatus(status, TF_OK, "");
return num_entries;
}
// TODO(vnvo2409): Implement later
} // namespace tf_hadoop_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
info->plugin_memory_allocate = plugin_memory_allocate;
info->plugin_memory_free = plugin_memory_free;
info->num_schemes = 3;
info->ops = static_cast<TF_FilesystemPluginOps*>(
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
ProvideFilesystemSupportFor(&info->ops[0], "hdfs");
ProvideFilesystemSupportFor(&info->ops[1], "viewfs");
ProvideFilesystemSupportFor(&info->ops[2], "har");
}

View File

@ -0,0 +1,21 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_

View File

@ -0,0 +1,63 @@
# Experimental s3 filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test")
package(
licenses = ["notice"], # Apache 2.0
)
# Filesystem implementation for S3 environments
tf_cc_shared_object(
name = "s3_filesystem",
framework_so = [],
linkstatic = False,
per_os_targets = 1,
visibility = ["//visibility:public"],
deps = [":s3_filesystem_impl"],
)
# The real implementation of the filesystem.
cc_library(
name = "s3_filesystem_impl",
srcs = ["s3_filesystem.cc"],
hdrs = ["s3_filesystem.h"],
copts = select({
"//conditions:default": [],
"//tensorflow:windows": get_win_copts(),
}),
deps = [
":aws_crypto",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@aws",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
],
)
cc_library(
name = "aws_crypto",
srcs = ["aws_crypto.cc"],
hdrs = ["aws_crypto.h"],
deps = [
"@aws",
"@boringssl//:crypto",
],
alwayslink = 1,
)
tf_cc_test(
name = "s3_filesystem_test",
srcs = [
"s3_filesystem_test.cc",
],
tags = [
"manual",
"notap",
],
deps = [
":s3_filesystem_impl",
"//tensorflow/core/platform:path",
"//tensorflow/core/platform:stacktrace_handler",
"//tensorflow/core/platform:test",
],
)

View File

@ -0,0 +1,133 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h"
#include <aws/core/utils/crypto/HashResult.h>
#include <aws/s3/S3Client.h>
#include <openssl/hmac.h>
#include <openssl/rand.h>
#include <openssl/sha.h>
namespace tf_s3_filesystem {
class AWSSha256HMACOpenSSLImpl : public Aws::Utils::Crypto::HMAC {
public:
AWSSha256HMACOpenSSLImpl() {}
virtual ~AWSSha256HMACOpenSSLImpl() = default;
Aws::Utils::Crypto::HashResult Calculate(
const Aws::Utils::ByteBuffer& toSign,
const Aws::Utils::ByteBuffer& secret) override {
unsigned int length = SHA256_DIGEST_LENGTH;
Aws::Utils::ByteBuffer digest(length);
memset(digest.GetUnderlyingData(), 0, length);
HMAC_CTX ctx;
HMAC_CTX_init(&ctx);
HMAC_Init_ex(&ctx, secret.GetUnderlyingData(),
static_cast<int>(secret.GetLength()), EVP_sha256(), NULL);
HMAC_Update(&ctx, toSign.GetUnderlyingData(), toSign.GetLength());
HMAC_Final(&ctx, digest.GetUnderlyingData(), &length);
HMAC_CTX_cleanup(&ctx);
return Aws::Utils::Crypto::HashResult(std::move(digest));
}
};
class AWSSha256OpenSSLImpl : public Aws::Utils::Crypto::Hash {
public:
AWSSha256OpenSSLImpl() {}
virtual ~AWSSha256OpenSSLImpl() = default;
Aws::Utils::Crypto::HashResult Calculate(const Aws::String& str) override {
SHA256_CTX sha256;
SHA256_Init(&sha256);
SHA256_Update(&sha256, str.data(), str.size());
Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH);
SHA256_Final(hash.GetUnderlyingData(), &sha256);
return Aws::Utils::Crypto::HashResult(std::move(hash));
}
Aws::Utils::Crypto::HashResult Calculate(Aws::IStream& stream) override {
SHA256_CTX sha256;
SHA256_Init(&sha256);
auto currentPos = stream.tellg();
if (currentPos == std::streampos(std::streamoff(-1))) {
currentPos = 0;
stream.clear();
}
stream.seekg(0, stream.beg);
char streamBuffer
[Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE];
while (stream.good()) {
stream.read(streamBuffer,
Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE);
auto bytesRead = stream.gcount();
if (bytesRead > 0) {
SHA256_Update(&sha256, streamBuffer, static_cast<size_t>(bytesRead));
}
}
stream.clear();
stream.seekg(currentPos, stream.beg);
Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH);
SHA256_Final(hash.GetUnderlyingData(), &sha256);
return Aws::Utils::Crypto::HashResult(std::move(hash));
}
};
class AWSSecureRandomBytesImpl : public Aws::Utils::Crypto::SecureRandomBytes {
public:
AWSSecureRandomBytesImpl() {}
virtual ~AWSSecureRandomBytesImpl() = default;
void GetBytes(unsigned char* buffer, size_t bufferSize) override {
assert(buffer);
int success = RAND_bytes(buffer, static_cast<int>(bufferSize));
if (success != 1) {
m_failure = true;
}
}
private:
bool m_failure;
};
std::shared_ptr<Aws::Utils::Crypto::Hash>
AWSSHA256Factory::CreateImplementation() const {
return Aws::MakeShared<AWSSha256OpenSSLImpl>(AWSCryptoAllocationTag);
}
std::shared_ptr<Aws::Utils::Crypto::HMAC>
AWSSHA256HmacFactory::CreateImplementation() const {
return Aws::MakeShared<AWSSha256HMACOpenSSLImpl>(AWSCryptoAllocationTag);
}
std::shared_ptr<Aws::Utils::Crypto::SecureRandomBytes>
AWSSecureRandomFactory::CreateImplementation() const {
return Aws::MakeShared<AWSSecureRandomBytesImpl>(AWSCryptoAllocationTag);
}
} // namespace tf_s3_filesystem

View File

@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_
#include <aws/core/Aws.h>
#include <aws/core/utils/crypto/Factories.h>
#include <aws/core/utils/crypto/HMAC.h>
#include <aws/core/utils/crypto/Hash.h>
#include <aws/core/utils/crypto/SecureRandom.h>
namespace tf_s3_filesystem {
constexpr char AWSCryptoAllocationTag[] = "AWSCryptoAllocation";
class AWSSHA256Factory : public Aws::Utils::Crypto::HashFactory {
public:
std::shared_ptr<Aws::Utils::Crypto::Hash> CreateImplementation()
const override;
};
class AWSSHA256HmacFactory : public Aws::Utils::Crypto::HMACFactory {
public:
std::shared_ptr<Aws::Utils::Crypto::HMAC> CreateImplementation()
const override;
};
class AWSSecureRandomFactory : public Aws::Utils::Crypto::SecureRandomFactory {
public:
std::shared_ptr<Aws::Utils::Crypto::SecureRandomBytes> CreateImplementation()
const override;
};
} // namespace tf_s3_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,101 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_S3_FILESYSTEM_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_S3_FILESYSTEM_H_
#include <aws/core/Aws.h>
#include <aws/core/utils/StringUtils.h>
#include <aws/core/utils/memory/stl/AWSMap.h>
#include <aws/core/utils/threading/Executor.h>
#include <aws/s3/S3Client.h>
#include <aws/transfer/TransferManager.h>
#include "absl/synchronization/mutex.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
void ParseS3Path(const Aws::String& fname, bool object_empty_ok,
Aws::String* bucket, Aws::String* object, TF_Status* status);
namespace tf_random_access_file {
void Cleanup(TF_RandomAccessFile* file);
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status);
} // namespace tf_random_access_file
namespace tf_writable_file {
void Cleanup(TF_WritableFile* file);
void Append(const TF_WritableFile* file, const char* buffer, size_t n,
TF_Status* status);
int64_t Tell(const TF_WritableFile* file, TF_Status* status);
void Sync(const TF_WritableFile* file, TF_Status* status);
void Flush(const TF_WritableFile* file, TF_Status* status);
void Close(const TF_WritableFile* file, TF_Status* status);
} // namespace tf_writable_file
namespace tf_read_only_memory_region {
void Cleanup(TF_ReadOnlyMemoryRegion* region);
const void* Data(const TF_ReadOnlyMemoryRegion* region);
uint64_t Length(const TF_ReadOnlyMemoryRegion* region);
} // namespace tf_read_only_memory_region
namespace tf_s3_filesystem {
typedef struct S3File {
std::shared_ptr<Aws::S3::S3Client> s3_client;
std::shared_ptr<Aws::Utils::Threading::PooledThreadExecutor> executor;
// We need 2 `TransferManager`, for multipart upload/download.
Aws::Map<Aws::Transfer::TransferDirection,
std::shared_ptr<Aws::Transfer::TransferManager>>
transfer_managers;
// Sizes to split objects during multipart upload/download.
Aws::Map<Aws::Transfer::TransferDirection, uint64_t> multi_part_chunk_sizes;
bool use_multi_part_download;
absl::Mutex initialization_lock;
S3File();
} S3File;
void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem);
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
TF_RandomAccessFile* file, TF_Status* status);
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status);
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status);
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
const char* path,
TF_ReadOnlyMemoryRegion* region,
TF_Status* status);
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status);
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status);
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
TF_Status* status);
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status);
} // namespace tf_s3_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_S3_FILESYSTEM_H_

View File

@ -0,0 +1,540 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.h"
#include <fstream>
#include <random>
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/stacktrace_handler.h"
#include "tensorflow/core/platform/test.h"
#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)
#define EXPECT_TF_OK(x) EXPECT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)
static std::string InitializeTmpDir() {
// This env should be something like `s3://bucket/path`
const char* test_dir = getenv("S3_TEST_TMPDIR");
if (test_dir != nullptr) {
Aws::String bucket, object;
TF_Status* status = TF_NewStatus();
ParseS3Path(test_dir, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) {
TF_DeleteStatus(status);
return "";
}
TF_DeleteStatus(status);
// We add a random value into `test_dir` to ensures that two consecutive
// runs are unlikely to clash.
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> distribution;
std::string rng_val = std::to_string(distribution(gen));
return tensorflow::io::JoinPath(std::string(test_dir), rng_val);
} else {
return "";
}
}
static std::string GetLocalLargeFile() {
// This env is used when we want to test against a large file ( ~ 50MB ).
// `S3_TEST_LOCAL_LARGE_FILE` and `S3_TEST_SERVER_LARGE_FILE` must be the same
// file.
static std::string path;
if (path.empty()) {
const char* env = getenv("S3_TEST_LOCAL_LARGE_FILE");
if (env == nullptr) return "";
path = env;
}
return path;
}
static std::string GetServerLargeFile() {
// This env is used when we want to test against a large file ( ~ 50MB ).
// `S3_TEST_LOCAL_LARGE_FILE` and `S3_TEST_SERVER_LARGE_FILE` must be the same
// file.
static std::string path;
if (path.empty()) {
const char* env = getenv("S3_TEST_SERVER_LARGE_FILE");
if (env == nullptr) return "";
Aws::String bucket, object;
TF_Status* status = TF_NewStatus();
ParseS3Path(env, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) {
TF_DeleteStatus(status);
return "";
}
TF_DeleteStatus(status);
path = env;
}
return path;
}
static std::string* GetTmpDir() {
static std::string tmp_dir = InitializeTmpDir();
if (tmp_dir == "")
return nullptr;
else
return &tmp_dir;
}
namespace tensorflow {
namespace {
class S3FilesystemTest : public ::testing::Test {
public:
void SetUp() override {
root_dir_ = io::JoinPath(
*GetTmpDir(),
::testing::UnitTest::GetInstance()->current_test_info()->name());
status_ = TF_NewStatus();
filesystem_ = new TF_Filesystem;
tf_s3_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
}
void TearDown() override {
TF_DeleteStatus(status_);
tf_s3_filesystem::Cleanup(filesystem_);
delete filesystem_;
}
std::string GetURIForPath(const std::string& path) {
const std::string translated_name =
tensorflow::io::JoinPath(root_dir_, path);
return translated_name;
}
std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile* file)>
GetWriter() {
std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile * file)> writer(
new TF_WritableFile, [](TF_WritableFile* file) {
if (file != nullptr) {
if (file->plugin_file != nullptr) tf_writable_file::Cleanup(file);
delete file;
}
});
writer->plugin_file = nullptr;
return writer;
}
std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile* file)>
GetReader() {
std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile * file)>
reader(new TF_RandomAccessFile, [](TF_RandomAccessFile* file) {
if (file != nullptr) {
if (file->plugin_file != nullptr)
tf_random_access_file::Cleanup(file);
delete file;
}
});
reader->plugin_file = nullptr;
return reader;
}
void WriteString(const std::string& path, const std::string& content) {
auto writer = GetWriter();
tf_s3_filesystem::NewWritableFile(filesystem_, path.c_str(), writer.get(),
status_);
if (TF_GetCode(status_) != TF_OK) return;
tf_writable_file::Append(writer.get(), content.c_str(), content.length(),
status_);
if (TF_GetCode(status_) != TF_OK) return;
tf_writable_file::Close(writer.get(), status_);
if (TF_GetCode(status_) != TF_OK) return;
}
std::string ReadAll(const string& path) {
auto reader = GetReader();
tf_s3_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
reader.get(), status_);
if (TF_GetCode(status_) != TF_OK) return "";
auto file_size =
tf_s3_filesystem::GetFileSize(filesystem_, path.c_str(), status_);
if (TF_GetCode(status_) != TF_OK) return "";
std::string content;
content.resize(file_size);
auto read = tf_random_access_file::Read(reader.get(), 0, file_size,
&content[0], status_);
if (TF_GetCode(status_) != TF_OK) return "";
if (read >= 0) content.resize(read);
if (file_size != content.size())
TF_SetStatus(
status_, TF_DATA_LOSS,
std::string("expected " + std::to_string(file_size) + " got " +
std::to_string(content.size()) + " bytes")
.c_str());
return content;
}
std::string ReadAllInChunks(const string& path, size_t buffer_size,
bool use_multi_part_download) {
auto reader = GetReader();
auto s3_file =
static_cast<tf_s3_filesystem::S3File*>(filesystem_->plugin_filesystem);
s3_file->use_multi_part_download = use_multi_part_download;
s3_file
->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::DOWNLOAD] =
buffer_size;
tf_s3_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
reader.get(), status_);
if (TF_GetCode(status_) != TF_OK) return "";
auto file_size =
tf_s3_filesystem::GetFileSize(filesystem_, path.c_str(), status_);
if (TF_GetCode(status_) != TF_OK) return "";
std::size_t part_count = (std::max)(
static_cast<size_t>((file_size + buffer_size - 1) / buffer_size),
static_cast<size_t>(1));
std::unique_ptr<char[]> buffer{new char[buffer_size]};
std::stringstream ss;
uint64_t offset = 0;
uint64_t server_size = 0;
for (size_t i = 0; i < part_count; i++) {
offset = i * buffer_size;
buffer_size =
(i == part_count - 1) ? file_size - server_size : buffer_size;
auto read = tf_random_access_file::Read(reader.get(), offset, buffer_size,
buffer.get(), status_);
if (TF_GetCode(status_) != TF_OK) return "";
if (read > 0) {
ss.write(buffer.get(), read);
server_size += static_cast<uint64_t>(read);
}
if (server_size == file_size) break;
if (read != buffer_size) {
if (read == 0)
TF_SetStatus(status_, TF_OUT_OF_RANGE, "eof");
else
TF_SetStatus(
status_, TF_DATA_LOSS,
("truncated record at " + std::to_string(offset)).c_str());
return "";
}
}
if (file_size != server_size) {
TF_SetStatus(status_, TF_DATA_LOSS,
std::string("expected " + std::to_string(file_size) +
" got " + std::to_string(server_size) + " bytes")
.c_str());
return "";
}
TF_SetStatus(status_, TF_OK, "");
return ss.str();
}
protected:
TF_Filesystem* filesystem_;
TF_Status* status_;
private:
std::string root_dir_;
};
TEST_F(S3FilesystemTest, NewRandomAccessFile) {
const std::string path = GetURIForPath("RandomAccessFile");
const std::string content = "abcdefghijklmn";
WriteString(path, content);
ASSERT_TF_OK(status_);
auto reader = GetReader();
tf_s3_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), reader.get(),
status_);
EXPECT_TF_OK(status_);
std::string result;
result.resize(content.size());
auto read = tf_random_access_file::Read(reader.get(), 0, content.size(),
&result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(content.size(), result.size());
EXPECT_EQ(content, result);
result.clear();
result.resize(4);
read = tf_random_access_file::Read(reader.get(), 2, 4, &result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(4, result.size());
EXPECT_EQ(content.substr(2, 4), result);
}
TEST_F(S3FilesystemTest, NewWritableFile) {
auto writer = GetWriter();
const std::string path = GetURIForPath("WritableFile");
tf_s3_filesystem::NewWritableFile(filesystem_, path.c_str(), writer.get(),
status_);
EXPECT_TF_OK(status_);
tf_writable_file::Append(writer.get(), "content1,", strlen("content1,"),
status_);
EXPECT_TF_OK(status_);
tf_writable_file::Append(writer.get(), "content2", strlen("content2"),
status_);
EXPECT_TF_OK(status_);
tf_writable_file::Flush(writer.get(), status_);
EXPECT_TF_OK(status_);
tf_writable_file::Sync(writer.get(), status_);
EXPECT_TF_OK(status_);
tf_writable_file::Close(writer.get(), status_);
EXPECT_TF_OK(status_);
auto content = ReadAll(path);
EXPECT_TF_OK(status_);
EXPECT_EQ("content1,content2", content);
}
TEST_F(S3FilesystemTest, NewAppendableFile) {
const std::string path = GetURIForPath("AppendableFile");
WriteString(path, "test");
ASSERT_TF_OK(status_);
auto writer = GetWriter();
tf_s3_filesystem::NewAppendableFile(filesystem_, path.c_str(), writer.get(),
status_);
EXPECT_TF_OK(status_);
tf_writable_file::Append(writer.get(), "content", strlen("content"), status_);
EXPECT_TF_OK(status_);
tf_writable_file::Close(writer.get(), status_);
EXPECT_TF_OK(status_);
}
TEST_F(S3FilesystemTest, NewReadOnlyMemoryRegionFromFile) {
const std::string path = GetURIForPath("MemoryFile");
const std::string content = "content";
WriteString(path, content);
ASSERT_TF_OK(status_);
std::unique_ptr<TF_ReadOnlyMemoryRegion,
void (*)(TF_ReadOnlyMemoryRegion * file)>
region(new TF_ReadOnlyMemoryRegion, [](TF_ReadOnlyMemoryRegion* file) {
if (file != nullptr) {
if (file->plugin_memory_region != nullptr)
tf_read_only_memory_region::Cleanup(file);
delete file;
}
});
region->plugin_memory_region = nullptr;
tf_s3_filesystem::NewReadOnlyMemoryRegionFromFile(filesystem_, path.c_str(),
region.get(), status_);
EXPECT_TF_OK(status_);
std::string result(reinterpret_cast<const char*>(
tf_read_only_memory_region::Data(region.get())),
tf_read_only_memory_region::Length(region.get()));
EXPECT_EQ(content, result);
}
TEST_F(S3FilesystemTest, PathExists) {
const std::string path = GetURIForPath("PathExists");
tf_s3_filesystem::PathExists(filesystem_, path.c_str(), status_);
EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status_)) << TF_Message(status_);
TF_SetStatus(status_, TF_OK, "");
WriteString(path, "test");
ASSERT_TF_OK(status_);
tf_s3_filesystem::PathExists(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
}
TEST_F(S3FilesystemTest, GetChildren) {
const std::string base = GetURIForPath("GetChildren");
tf_s3_filesystem::CreateDir(filesystem_, base.c_str(), status_);
EXPECT_TF_OK(status_);
const std::string file = io::JoinPath(base, "TestFile.csv");
WriteString(file, "test");
EXPECT_TF_OK(status_);
const std::string subdir = io::JoinPath(base, "SubDir");
tf_s3_filesystem::CreateDir(filesystem_, subdir.c_str(), status_);
EXPECT_TF_OK(status_);
const std::string subfile = io::JoinPath(subdir, "TestSubFile.csv");
WriteString(subfile, "test");
EXPECT_TF_OK(status_);
char** entries;
auto num_entries = tf_s3_filesystem::GetChildren(filesystem_, base.c_str(),
&entries, status_);
EXPECT_TF_OK(status_);
std::vector<std::string> childrens;
for (int i = 0; i < num_entries; ++i) {
childrens.push_back(entries[i]);
}
std::sort(childrens.begin(), childrens.end());
EXPECT_EQ(std::vector<string>({"SubDir", "TestFile.csv"}), childrens);
}
TEST_F(S3FilesystemTest, DeleteFile) {
const std::string path = GetURIForPath("DeleteFile");
WriteString(path, "test");
ASSERT_TF_OK(status_);
tf_s3_filesystem::DeleteFile(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
}
TEST_F(S3FilesystemTest, CreateDir) {
// s3 object storage doesn't support empty directory, we create file in the
// directory
const std::string dir = GetURIForPath("CreateDir");
tf_s3_filesystem::CreateDir(filesystem_, dir.c_str(), status_);
EXPECT_TF_OK(status_);
const std::string file = io::JoinPath(dir, "CreateDirFile.csv");
WriteString(file, "test");
ASSERT_TF_OK(status_);
TF_FileStatistics stat;
tf_s3_filesystem::Stat(filesystem_, dir.c_str(), &stat, status_);
EXPECT_TF_OK(status_);
EXPECT_TRUE(stat.is_directory);
}
TEST_F(S3FilesystemTest, DeleteDir) {
// s3 object storage doesn't support empty directory, we create file in the
// directory
const std::string dir = GetURIForPath("DeleteDir");
const std::string file = io::JoinPath(dir, "DeleteDirFile.csv");
WriteString(file, "test");
ASSERT_TF_OK(status_);
tf_s3_filesystem::DeleteDir(filesystem_, dir.c_str(), status_);
EXPECT_NE(TF_GetCode(status_), TF_OK);
TF_SetStatus(status_, TF_OK, "");
tf_s3_filesystem::DeleteFile(filesystem_, file.c_str(), status_);
EXPECT_TF_OK(status_);
tf_s3_filesystem::DeleteDir(filesystem_, dir.c_str(), status_);
EXPECT_TF_OK(status_);
TF_FileStatistics stat;
tf_s3_filesystem::Stat(filesystem_, dir.c_str(), &stat, status_);
EXPECT_EQ(TF_GetCode(status_), TF_NOT_FOUND) << TF_Message(status_);
}
TEST_F(S3FilesystemTest, StatFile) {
const std::string path = GetURIForPath("StatFile");
WriteString(path, "test");
ASSERT_TF_OK(status_);
TF_FileStatistics stat;
tf_s3_filesystem::Stat(filesystem_, path.c_str(), &stat, status_);
EXPECT_TF_OK(status_);
EXPECT_EQ(4, stat.length);
EXPECT_FALSE(stat.is_directory);
}
TEST_F(S3FilesystemTest, SimpleCopyFile) {
const std::string src = GetURIForPath("SimpleCopySrc");
const std::string dst = GetURIForPath("SimpleCopyDst");
WriteString(src, "test");
ASSERT_TF_OK(status_);
tf_s3_filesystem::CopyFile(filesystem_, src.c_str(), dst.c_str(), status_);
EXPECT_TF_OK(status_);
auto result = ReadAll(dst);
EXPECT_TF_OK(status_);
EXPECT_EQ(result, "test");
}
TEST_F(S3FilesystemTest, RenameFile) {
const std::string src = GetURIForPath("RenameFileSrc");
const std::string dst = GetURIForPath("RenameFileDst");
WriteString(src, "test");
ASSERT_TF_OK(status_);
tf_s3_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), status_);
EXPECT_TF_OK(status_);
auto result = ReadAll(dst);
EXPECT_TF_OK(status_);
EXPECT_EQ("test", result);
}
TEST_F(S3FilesystemTest, RenameFileOverwrite) {
const std::string src = GetURIForPath("RenameFileOverwriteSrc");
const std::string dst = GetURIForPath("RenameFileOverwriteDst");
WriteString(src, "test_old");
ASSERT_TF_OK(status_);
WriteString(dst, "test_new");
ASSERT_TF_OK(status_);
tf_s3_filesystem::PathExists(filesystem_, dst.c_str(), status_);
EXPECT_TF_OK(status_);
tf_s3_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), status_);
EXPECT_TF_OK(status_);
auto result = ReadAll(dst);
EXPECT_TF_OK(status_);
EXPECT_EQ("test_old", result);
}
// Test against large file.
TEST_F(S3FilesystemTest, ReadLargeFile) {
auto local_path = GetLocalLargeFile();
auto server_path = GetServerLargeFile();
if (local_path.empty() || server_path.empty()) GTEST_SKIP();
std::ifstream in(local_path, std::ios::binary);
std::string local_content((std::istreambuf_iterator<char>(in)),
std::istreambuf_iterator<char>());
constexpr size_t buffer_size = 50 * 1024 * 1024;
auto server_content = ReadAllInChunks(server_path, buffer_size, true);
ASSERT_TF_OK(status_);
EXPECT_EQ(local_content, server_content);
server_content = ReadAllInChunks(server_path, buffer_size, false);
ASSERT_TF_OK(status_);
EXPECT_EQ(local_content, server_content);
}
TEST_F(S3FilesystemTest, CopyLargeFile) {
auto server_path = GetServerLargeFile();
if (server_path.empty()) GTEST_SKIP();
auto path = GetURIForPath("CopyLargeFile");
constexpr size_t buffer_size = 5 * 1024 * 1024;
auto s3_file =
static_cast<tf_s3_filesystem::S3File*>(filesystem_->plugin_filesystem);
s3_file->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD] =
buffer_size;
tf_s3_filesystem::CopyFile(filesystem_, server_path.c_str(), path.c_str(),
status_);
EXPECT_TF_OK(status_);
auto server_size =
tf_s3_filesystem::GetFileSize(filesystem_, server_path.c_str(), status_);
EXPECT_TF_OK(status_);
auto actual_size =
tf_s3_filesystem::GetFileSize(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
EXPECT_EQ(server_size, actual_size);
}
} // namespace
} // namespace tensorflow
GTEST_API_ int main(int argc, char** argv) {
tensorflow::testing::InstallStacktraceHandler();
if (!GetTmpDir()) {
std::cerr << "Could not read S3_TEST_TMPDIR env";
return -1;
}
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -0,0 +1,24 @@
# Library of gradient functions.
package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "math_grad",
srcs = ["math_grad.cc"],
hdrs = [
"math_grad.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/core/lib/llvm_rtti",
],
)

View File

@ -0,0 +1,86 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
using std::vector;
using tensorflow::ops::Conj;
using tensorflow::ops::Identity;
using tensorflow::ops::Mul;
namespace tensorflow {
namespace gradients {
namespace {
class AddGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2);
vector<AbstractTensorHandle*> identity_outputs(1);
// TODO(b/145674566): Handle name unification in tracing code.
// TODO(b/161805092): Support broadcasting.
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(identity_outputs),
"Identity0"));
(*grad_outputs)[0] = identity_outputs[0];
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(identity_outputs),
"Identity1"));
(*grad_outputs)[1] = identity_outputs[0];
return Status::OK();
}
~AddGradientFunction() override {}
};
class ExpGradientFunction : public GradientFunction {
public:
explicit ExpGradientFunction(AbstractTensorHandle* exp) : exp_(exp) {
exp->Ref();
}
Status Compute(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
vector<AbstractTensorHandle*> conj_outputs(1);
TF_RETURN_IF_ERROR(
Conj(ctx->ctx, {exp_.get()}, absl::MakeSpan(conj_outputs), "ExpConj"));
AbstractTensorHandlePtr conj_output_releaser(conj_outputs[0]);
grad_outputs->resize(1);
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {conj_outputs[0], grad_inputs[0]},
absl::MakeSpan(*grad_outputs), "ExpGradMul"));
return Status::OK();
}
~ExpGradientFunction() override {}
private:
AbstractTensorHandlePtr exp_;
};
} // namespace
GradientFunction* AddRegisterer(const ForwardOperation& op) {
return new AddGradientFunction;
}
GradientFunction* ExpRegisterer(const ForwardOperation& op) {
return new ExpGradientFunction(op.outputs[0]);
}
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
#include "tensorflow/c/eager/gradients.h"
namespace tensorflow {
namespace gradients {
GradientFunction* AddRegisterer(const ForwardOperation& op);
GradientFunction* ExpRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_

View File

@ -1,166 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/network.h"
#include <memory>
#include <string>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/experimental/network_internal.h"
#include "tensorflow/c/experimental/rendezvous_internal.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
using tensorflow::ServerFactory;
namespace tensorflow {
/* static */ Status CGrpcServer::Create(
const ServerDef& server_def,
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder,
std::unique_ptr<ServerInterface>* out_server) {
auto* grpc_server = new CGrpcServer(server_def, start_function, stop_function,
join_function, delete_function);
GrpcServerOptions options;
options.rendezvous_mgr_func = [rendezvous_builder](const WorkerEnv* env) {
return new CRendezvousMgr(env, rendezvous_builder);
};
TF_RETURN_IF_ERROR(grpc_server->Init(options));
TF_Status* tf_status = TF_NewStatus();
grpc_server->SetContext(init_function(
reinterpret_cast<const TF_GrpcServer*>(grpc_server), tf_status));
TF_RETURN_IF_ERROR(tf_status->status);
TF_DeleteStatus(tf_status);
out_server->reset(grpc_server);
return Status::OK();
}
Status CGrpcServer::Start() {
Status status = GrpcServer::Start();
TF_Status* tf_status = TF_NewStatus();
(*start_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
tf_status);
status.Update(tf_status->status);
TF_DeleteStatus(tf_status);
return status;
}
Status CGrpcServer::Stop() {
Status status = GrpcServer::Stop();
TF_Status* tf_status = TF_NewStatus();
(*stop_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
tf_status);
status.Update(tf_status->status);
TF_DeleteStatus(tf_status);
return status;
}
Status CGrpcServer::Join() {
Status status = GrpcServer::Join();
TF_Status* tf_status = TF_NewStatus();
(*join_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
tf_status);
status.Update(tf_status->status);
TF_DeleteStatus(tf_status);
return status;
}
namespace {
// Factory that creates CGrpcServer instances.
class CServerFactory : public ServerFactory {
public:
CServerFactory(bool (*accept_function)(const char*),
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*,
TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder)
: accept_function_(accept_function),
init_function_(init_function),
start_function_(start_function),
stop_function_(stop_function),
join_function_(join_function),
delete_function_(delete_function),
rendezvous_builder_(rendezvous_builder) {}
Status NewServer(const ServerDef& server_def, const Options& options,
std::unique_ptr<ServerInterface>* out_server) override {
TF_RETURN_IF_ERROR(CGrpcServer::Create(
server_def, init_function_, start_function_, stop_function_,
join_function_, delete_function_, rendezvous_builder_, out_server));
return Status::OK();
}
// Returns true if and only if this factory can create a server
// based on the given `server_def`.
bool AcceptsOptions(const ServerDef& server_def) override {
return (*accept_function_)(server_def.protocol().c_str());
}
private:
bool (*accept_function_)(const char* protocol);
void* (*init_function_)(const TF_GrpcServer*, TF_Status*);
void (*start_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*stop_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*join_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*delete_function_)(void*);
TF_RemoteRendezvousBuilder* rendezvous_builder_;
};
} // namespace
} // namespace tensorflow
// Server factory representation to use in C API.
// Holds CServerFactory pointer.
struct TF_GrpcServerFactory {
::tensorflow::CServerFactory* factory;
};
TF_GrpcServerFactory* TF_NewGrpcServerFactory(
bool (*accept_function)(const char*),
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder) {
TF_GrpcServerFactory* server_factory = new TF_GrpcServerFactory;
server_factory->factory = new ::tensorflow::CServerFactory(
accept_function, init_function, start_function, stop_function,
join_function, delete_function, rendezvous_builder);
return server_factory;
}
void TF_DeleteGrpcServerFactory(TF_GrpcServerFactory* server_factory) {
DCHECK_NE(server_factory, nullptr);
delete server_factory;
}
void TF_RegisterGrpcServerFactory(const char* server_type,
TF_GrpcServerFactory* server_factory) {
ServerFactory::Register(server_type, server_factory->factory);
}

View File

@ -1,97 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_
#define TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/rendezvous.h"
#ifdef __cplusplus
extern "C" {
#endif
// --------------------------------------------------------------------------
// C API for TensorFlow Networking.
// NOTE: This API is unstable and almost certainly will change in the near
// future.
//
// Users wishing to register a custom GrpcServer should call
// TF_NewServerFactory and then TF_RegisterGrpcServerFactory.
//
// Example:
// ```c++
// auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder(
// rendezvous_init_function,
// receive_from_remote_async_function,
// rendezvous_delete_function);
//
// TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory(
// accept_function,
// init_function,
// start_function,
// stop_function,
// join_function,
// delete_function,
// rendezvous_builder);
// TF_RegisterGrpcServerFactory("customfactory", factory);
// ...
// TF_DeleteGrpcServerFactory(factory);
// ```
typedef struct TF_GrpcServerFactory TF_GrpcServerFactory;
typedef struct TF_GrpcServerOptions TF_GrpcServerOptions;
typedef struct TF_GrpcServer TF_GrpcServer;
typedef struct TF_ServerContext {
TF_GrpcServer* const server;
void* context;
} TF_ServerContext;
// Creates a new TF_GrpcServerFactory instance. Caller takes ownership
// of TF_GrpcServerFactory instance and should deallocate it by calling
// TF_GrpcDeleteServerFactory.
// accept_function should return true if this ServerFactory can create
// server instances for the given protocol name (for e.g. grpc+verbs).
// GRPC servers created by this factory will call provided
// init_function, start_function, stop_function, join_function and
// delete_function.
//
// Note that clean shutdown is currently not implemented for GrpcServer.
// So, stop_function will never be called now but may be in the future
// when stop mechanism is supported.
TF_CAPI_EXPORT extern TF_GrpcServerFactory* TF_NewGrpcServerFactory(
bool (*accept_function)(const char*),
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder);
// Deletes TF_GrpcServerFactory instances.
// Note that this function only deletes TF_GrpcServerFactory wrapper.
// Actual underlying server factory would not be deleted and will
// remain registered.
TF_CAPI_EXPORT extern void TF_DeleteGrpcServerFactory(
TF_GrpcServerFactory* server_factory);
// Registers provided server_factory for the given server_type.
// server_type must be unique to the server factory.
TF_CAPI_EXPORT extern void TF_RegisterGrpcServerFactory(
const char* server_type, TF_GrpcServerFactory* server_factory);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_

View File

@ -1,77 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_
#define TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/network.h"
#include "tensorflow/c/experimental/rendezvous.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace tensorflow {
// GrpcServer implementation that forwards calls to callbacks.
class CGrpcServer : public GrpcServer {
protected:
CGrpcServer(const ServerDef& server_def,
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*))
: GrpcServer(server_def, ::tensorflow::Env::Default()),
start_function_(start_function),
stop_function_(stop_function),
join_function_(join_function),
delete_function_(delete_function),
context_(nullptr) {}
public:
static Status Create(
const ServerDef& server_def,
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder,
std::unique_ptr<ServerInterface>* out_server);
Status Start() override;
Status Stop() override;
Status Join() override;
~CGrpcServer() override { delete_function_(context_); }
protected:
void SetContext(void* context) { context_ = context; }
private:
void (*start_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*stop_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*join_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*delete_function_)(void*);
void* context_;
friend class NetworksTest;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_

View File

@ -1,256 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/network.h"
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <memory>
#include <string>
#include "absl/synchronization/notification.h"
#include "absl/time/time.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/network_internal.h"
#include "tensorflow/c/experimental/rendezvous.h"
#include "tensorflow/c/experimental/rendezvous_internal.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace tensorflow {
bool accept_functionA(const char* protocol_name) {
return strcmp(protocol_name, "grpc+A") == 0;
}
bool accept_functionB(const char* protocol_name) {
return strcmp(protocol_name, "grpc+B") == 0;
}
struct SomeServerData {
bool server_started = false;
};
struct SomeRendezvousData {
int test = 0;
};
void* init_function(const TF_GrpcServer* server, TF_Status* status) {
SomeServerData* server_data = new SomeServerData();
TF_SetStatus(status, TF_OK, "");
return server_data;
}
void start_function(const TF_GrpcServer* server, void* context,
TF_Status* status) {
auto* server_data = static_cast<SomeServerData*>(context);
server_data->server_started = true;
TF_SetStatus(status, TF_OK, "");
}
void stop_function(const TF_GrpcServer* server, void* context,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
}
void join_function(const TF_GrpcServer* server, void* context,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
}
void delete_function(void* context) {
auto* server_data = static_cast<SomeServerData*>(context);
delete server_data;
}
void* rendezvous_init_function(void* server_context) {
return new SomeRendezvousData();
}
void Deallocator(void* data, size_t, void* arg) {
tensorflow::cpu_allocator()->DeallocateRaw(data);
*reinterpret_cast<bool*>(arg) = true;
}
void receive_from_remote_async_function(TF_ParsedKey* key,
TF_RendezvousArgs* args,
TF_RendezvousDoneCallback* callback,
void* context) {
// Create dummy tensor
const int num_bytes = 6 * sizeof(float);
float* values =
reinterpret_cast<float*>(tensorflow::cpu_allocator()->AllocateRaw(
EIGEN_MAX_ALIGN_BYTES, num_bytes));
int64_t dims[] = {2, 3};
bool deallocator_called = false;
auto* tensor = TF_NewTensor(TF_FLOAT, dims, 2, values, num_bytes,
&Deallocator, &deallocator_called);
callback->tensor = tensor;
auto* tf_status = TF_NewStatus();
TF_SetStatus(tf_status, TF_OK, "");
callback->status = tf_status;
TF_RendezvousDone(callback);
TF_DeleteStatus(tf_status);
TF_DeleteTensor(tensor);
}
void rendezvous_delete_function(void* context) {
auto* rendezvous_data = static_cast<SomeRendezvousData*>(context);
delete rendezvous_data;
}
tensorflow::ServerDef GetServerDef(const string& protocol,
const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol(protocol);
server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost:", port)});
}
return server_def;
}
class NetworksTest : public ::testing::Test {
public:
~NetworksTest() override {}
SomeServerData* GetServerData(CGrpcServer* server) {
EXPECT_NE(server->context_, nullptr);
return static_cast<SomeServerData*>(server->context_);
}
};
Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
const string& receiver, const string& name) {
Rendezvous::ParsedKey result;
CHECK(
Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
name, FrameAndIter(0, 0)),
&result)
.ok());
return result;
}
void InitializeRendezvous(GrpcServer* grpc_server, ServerDef* server_def,
RemoteRendezvous* remote_rendezvous) {
int rendezvous_id = 0;
auto session_name = tensorflow::strings::StrCat("test_", rendezvous_id);
TF_EXPECT_OK(grpc_server->worker_env()->session_mgr->CreateSession(
session_name, *server_def, true));
std::shared_ptr<tensorflow::WorkerSession> worker_session;
TF_EXPECT_OK(grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
session_name, &worker_session));
TF_EXPECT_OK(remote_rendezvous->Initialize(worker_session.get()));
}
TEST_F(NetworksTest, TestStartServer) {
auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder(
rendezvous_init_function, receive_from_remote_async_function,
rendezvous_delete_function);
TF_Status* tf_status = TF_NewStatus();
TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory(
accept_functionA, init_function, start_function, stop_function,
join_function, delete_function, rendezvous_builder);
TF_RegisterGrpcServerFactory("testfactoryA", factory);
ServerDef server_def = GetServerDef("grpc+A", "localhost", 1);
std::unique_ptr<ServerInterface> server;
TF_EXPECT_OK(NewServer(server_def, &server));
auto* grpc_server = static_cast<CGrpcServer*>(server.get());
auto* server_data = GetServerData(grpc_server);
ASSERT_FALSE(server_data->server_started);
TF_EXPECT_OK(server->Start());
ASSERT_TRUE(server_data->server_started);
TF_DeleteStatus(tf_status);
TF_DeleteGrpcServerFactory(factory);
TF_DeleteRemoteRendezvousBuilder(rendezvous_builder);
// TODO(annarev): find a clean way to shutdown server.
server.release();
}
TEST_F(NetworksTest, TestReceiveData) {
auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder(
rendezvous_init_function, receive_from_remote_async_function,
rendezvous_delete_function);
TF_Status* tf_status = TF_NewStatus();
TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory(
accept_functionB, init_function, start_function, stop_function,
join_function, delete_function, rendezvous_builder);
TF_RegisterGrpcServerFactory("testfactoryB", factory);
ServerDef server_def = GetServerDef("grpc+B", "localhost", 1);
std::unique_ptr<ServerInterface> server;
TF_EXPECT_OK(NewServer(server_def, &server));
auto* grpc_server = static_cast<CGrpcServer*>(server.get());
TF_EXPECT_OK(server->Start());
auto* rendezvous_mgr = grpc_server->worker_env()->rendezvous_mgr;
auto* remote_rendezvous = rendezvous_mgr->Find(0);
auto key = Key("/job:localhost/replica:1/task:2/device:CPU:0", 1,
"/job:localhost/replica:0/task:0/device:CPU:0", "test");
Rendezvous::Args args;
bool done_callback_called = false;
auto* done_callback_called_ptr = &done_callback_called;
absl::Notification notification;
auto* notification_ptr = &notification;
InitializeRendezvous(grpc_server, &server_def, remote_rendezvous);
remote_rendezvous->RecvAsync(
key, args,
[done_callback_called_ptr, notification_ptr](
const Status&, const Rendezvous::Args&, const Rendezvous::Args&,
const Tensor&, const bool) mutable {
*done_callback_called_ptr = true;
notification_ptr->Notify();
});
notification.WaitForNotificationWithTimeout(absl::Seconds(10));
ASSERT_EQ(done_callback_called, true);
TF_DeleteStatus(tf_status);
TF_DeleteGrpcServerFactory(factory);
TF_DeleteRemoteRendezvousBuilder(rendezvous_builder);
// Server doesn't have a clean shutdown.
server.release();
}
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
# Experimental ops. These will eventually be replaced by machine-generated versions.
package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "array_ops",
srcs = [
"array_ops.cc",
],
hdrs = [
"array_ops.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
],
)
cc_library(
name = "math_ops",
srcs = [
"math_ops.cc",
],
hdrs = [
"math_ops.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":array_ops",
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
],
)

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace ops {
// Creates an Identity op.
Status Identity(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr identity_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
identity_op->Reset("Identity", /*raw_device_name=*/nullptr));
if (isa<tensorflow::tracing::TracingOperation>(identity_op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(identity_op.get())
->SetOpName(name));
}
TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0]));
int num_retvals = 1;
return identity_op->Execute(outputs, &num_retvals);
}
} // namespace ops
} // namespace tensorflow

View File

@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
#define TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
namespace tensorflow {
namespace ops {
Status Identity(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_

View File

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace ops {
using tensorflow::tracing::TracingOperation;
Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr mul_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(mul_op->Reset("Mul", /*raw_device_name=*/nullptr));
if (isa<TracingOperation>(mul_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[1]));
int num_retvals = 1;
return mul_op->Execute(outputs, &num_retvals);
}
Status Conj(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
auto dtype = inputs[0]->DataType();
if (DataTypeIsFloating(BaseType(dtype)) ||
DataTypeIsInteger(BaseType(dtype))) {
TF_RETURN_IF_ERROR(Identity(ctx, inputs, outputs, name));
} else {
return errors::Unimplemented("Conj does not support complex types yet.");
}
return Status::OK();
}
} // namespace ops
} // namespace tensorflow

View File

@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_
#define TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
namespace tensorflow {
namespace ops {
Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Conj(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_

View File

@ -1,127 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/rendezvous.h"
#include <functional>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/experimental/rendezvous_internal.h"
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
namespace tensorflow {
CRemoteRendezvous::CRemoteRendezvous(const WorkerEnv* env, int64 step_id,
void (*receive_from_remote_async_function)(
TF_ParsedKey*, TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context),
void (*delete_function)(void* context),
void* server_context)
: BaseRemoteRendezvous(env, step_id),
receive_from_remote_async_function_(receive_from_remote_async_function),
delete_function_(delete_function),
context_(nullptr) {}
void CRemoteRendezvous::RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
DoneCallback done) {
if (args.cancellation_manager != nullptr) {
VLOG(1) << "WARNING: CRemoteRendezvous does not support cancellation.";
}
TF_ParsedKey key;
key.src_device = parsed.src_device.data();
key.src_device_len = parsed.src_device.size();
key.dst_device = parsed.dst_device.data();
key.dst_device_len = parsed.dst_device.size();
key.full_key = parsed.FullKey().data();
key.full_key_len = parsed.FullKey().size();
TF_DeviceContext* device_context = new TF_DeviceContext();
device_context->context = args.device_context;
TF_AllocatorAttributes* alloc_attrs = new TF_AllocatorAttributes();
alloc_attrs->value = args.alloc_attrs.value;
alloc_attrs->scope_id = args.alloc_attrs.scope_id;
alloc_attrs->on_host = args.alloc_attrs.on_host();
alloc_attrs->nic_compatible = args.alloc_attrs.nic_compatible();
TF_RendezvousArgs* cargs = new TF_RendezvousArgs();
cargs->device_context = device_context;
cargs->alloc_attrs = alloc_attrs;
TF_RendezvousDoneCallback* done_callback = new TF_RendezvousDoneCallback();
done_callback->done_callback = done;
done_callback->recv_args = cargs;
receive_from_remote_async_function_(&key, cargs, done_callback, context_);
}
CRemoteRendezvous::~CRemoteRendezvous() { delete_function_(context_); }
} // namespace tensorflow
TF_RemoteRendezvousBuilder* TF_NewRemoteRendezvousBuilder(
void* (*init_function)(void* server_context),
void (*receive_from_remote_async_function)(TF_ParsedKey*,
TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context),
void (*delete_function)(void* context)) {
TF_RemoteRendezvousBuilder* builder = new TF_RemoteRendezvousBuilder();
builder->init_function = init_function;
builder->delete_function = delete_function;
builder->receive_from_remote_async_function =
receive_from_remote_async_function;
return builder;
}
void TF_DeleteRemoteRendezvousBuilder(
TF_RemoteRendezvousBuilder* rendezvous_builder) {
DCHECK_NE(rendezvous_builder, nullptr);
delete rendezvous_builder;
}
TF_CAPI_EXPORT extern void TF_RendezvousDone(
TF_RendezvousDoneCallback* callback) {
DCHECK_NE(callback, nullptr);
::tensorflow::Tensor tensor;
TF_CHECK_OK(TF_TensorToTensor(callback->tensor, &tensor));
::tensorflow::Rendezvous::Args recv_args;
recv_args.alloc_attrs.value = callback->recv_args->alloc_attrs->value;
recv_args.alloc_attrs.scope_id = callback->recv_args->alloc_attrs->scope_id;
recv_args.device_context = callback->recv_args->device_context->context;
::tensorflow::Rendezvous::Args sent_args;
callback->done_callback(callback->status->status, sent_args, recv_args,
tensor, callback->dead);
if (callback->recv_args) {
DCHECK_NE(callback->recv_args, nullptr);
DCHECK_NE(callback->recv_args->alloc_attrs, nullptr);
DCHECK_NE(callback->recv_args->device_context, nullptr);
delete callback->recv_args->alloc_attrs;
delete callback->recv_args->device_context;
delete callback->recv_args;
}
delete callback;
callback = nullptr;
}

View File

@ -1,67 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_
#define TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_
#include "tensorflow/c/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif
// --------------------------------------------------------------------------
// C API for Rendezvous.
// NOTE: This API is unstable and almost certainly will change in the near
// future.
//
// Custom rendezvous allows for custom implementations of Recv call.
//
// Users wishing to create custom rendezvous objects should call
// TF_NewRemoteRendezvousBuilder and pass returned TF_RemoteRendezvousBuilder
// to to TF_NewServerFactory.
typedef struct TF_RemoteRendezvousBuilder TF_RemoteRendezvousBuilder;
typedef struct TF_ParsedKey TF_ParsedKey;
typedef struct TF_RendezvousArgs TF_RendezvousArgs;
typedef struct TF_RendezvousDoneCallback TF_RendezvousDoneCallback;
// Creates a new TF_RemoteRendezvousBuilder instance.
// Rendezvous instances will forward calls to init_function,
// receive_from_remote_async_function and delete_function passed here.
//
// Note that receive_from_remote_async_function implementation must call
// TF_Done with the TF_DoneCallback passed as an argument.
TF_CAPI_EXPORT extern TF_RemoteRendezvousBuilder* TF_NewRemoteRendezvousBuilder(
void* (*init_function)(void* server_context),
void (*receive_from_remote_async_function)(TF_ParsedKey*,
TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context),
void (*delete_function)(void* context));
// Deletes TF_RemoteRendezvousBuilder instances.
TF_CAPI_EXPORT extern void TF_DeleteRemoteRendezvousBuilder(
TF_RemoteRendezvousBuilder* rendezvous_builder);
// Calls TF_DoneCallback and destroys callback instance and
// TF_DoneCallback members except `tensor` and `status`. Caller is
// responsible for deleting `tensor` and `status` after TF_Done returns.
TF_CAPI_EXPORT extern void TF_RendezvousDone(
TF_RendezvousDoneCallback* callback);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_

View File

@ -1,135 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_
#define TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_
#include <stddef.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/rendezvous.h"
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/platform/macros.h"
struct TF_ParsedKey {
// char* members might not be null-terminated.
const char* src_device;
size_t src_device_len;
const char* dst_device;
size_t dst_device_len;
const char* full_key;
size_t full_key_len;
};
struct TF_AllocatorAttributes {
bool on_host;
bool nic_compatible;
// NOTE: The upper 8 bits of the value are reserved for
// device-specific uses. Implementors of a device can interpret these
// upper 8 bits in device-specific ways, and ops implemented for those
// devices are responsible for setting those 8 bits appropriately.
tensorflow::uint32 value = 0;
// EXPERIMENTAL: If this is greater than zero, then allocation is delegated to
// a named special-purpose allocator on the same device.
tensorflow::int32 scope_id = 0;
};
struct TF_DeviceContext {
::tensorflow::DeviceContext* context;
};
struct TF_RendezvousArgs {
const TF_DeviceContext* device_context;
const TF_AllocatorAttributes* alloc_attrs;
};
struct TF_RendezvousDoneCallback {
::tensorflow::Rendezvous::DoneCallback done_callback;
// TODO(annarev): figure out if we should also support sent_args.
const TF_RendezvousArgs* recv_args;
TF_Tensor* tensor = nullptr;
TF_Status* status;
bool dead;
};
struct TF_RemoteRendezvousBuilder {
void* (*init_function)(void* server_context);
void (*receive_from_remote_async_function)(TF_ParsedKey*, TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context);
void (*delete_function)(void* context);
void* server_context;
};
namespace tensorflow {
class CRemoteRendezvous : public BaseRemoteRendezvous {
public:
CRemoteRendezvous(const WorkerEnv* env, int64 step_id,
void (*receive_from_remote_async_function)(
TF_ParsedKey*, TF_RendezvousArgs*,
TF_RendezvousDoneCallback*, void* context),
void (*delete_function)(void* context),
void* server_context);
void SetContext(void* context) { context_ = context; }
protected:
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
DoneCallback done) override;
private:
~CRemoteRendezvous() override;
void (*receive_from_remote_async_function_)(TF_ParsedKey*, TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context);
void (*delete_function_)(void* context);
void* context_;
TF_DISALLOW_COPY_AND_ASSIGN(CRemoteRendezvous);
};
class CRendezvousMgr : public BaseRendezvousMgr {
public:
CRendezvousMgr(const WorkerEnv* env,
const TF_RemoteRendezvousBuilder* rendezvous_builder)
: BaseRendezvousMgr(env), rendezvous_builder_(rendezvous_builder) {}
protected:
BaseRemoteRendezvous* Create(int64 step_id,
const WorkerEnv* worker_env) override {
auto* rendezvous = new CRemoteRendezvous(
worker_env, step_id,
rendezvous_builder_->receive_from_remote_async_function,
rendezvous_builder_->delete_function,
rendezvous_builder_->server_context);
rendezvous->SetContext(rendezvous_builder_->init_function(
rendezvous_builder_->server_context));
return rendezvous;
}
private:
const TF_RemoteRendezvousBuilder* rendezvous_builder_;
TF_DISALLOW_COPY_AND_ASSIGN(CRendezvousMgr);
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_

View File

@ -26,6 +26,7 @@ cc_library(
":function_metadata", ":function_metadata",
"//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/c/eager:immediate_execution_tensor_handle",
"@com_google_absl//absl/types:span",
], ],
) )
@ -113,8 +114,23 @@ cc_library(
deps = [ deps = [
":concrete_function", ":concrete_function",
":saved_model_api", ":saved_model_api",
":saved_model_utils",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core/ops:restore_ops",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible",
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
"//tensorflow/cc/saved_model:bundle_v2",
"//tensorflow/cc/saved_model:constants",
"//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
], ],
) )
@ -131,6 +147,7 @@ cc_library(
"//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"@com_google_absl//absl/types:span",
], ],
) )
@ -199,6 +216,23 @@ tf_cc_test(
], ],
) )
tf_cc_test(
name = "signature_flattening_test",
srcs = [
"signature_flattening_test.cc",
],
deps = [
":saved_model_utils",
"//tensorflow/c/experimental/saved_model/core:tf_concrete_function_test_protos",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime/eager:core",
],
)
tf_cc_test( tf_cc_test(
name = "tf_concrete_function_loading_test", name = "tf_concrete_function_loading_test",
srcs = [ srcs = [

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "absl/types/span.h"
#include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
@ -38,10 +39,9 @@ class ConcreteFunction {
virtual ~ConcreteFunction() = default; virtual ~ConcreteFunction() = default;
// This method returns the "Call" Op used to execute the function. // This method returns the "Call" Op used to execute the function.
virtual Status GetCallOp(ImmediateOpPtr* out) = 0; virtual Status GetCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) = 0;
virtual const std::vector<ImmediateExecutionTensorHandle*>& GetCaptures()
const = 0;
virtual const FunctionMetadata& GetFunctionMetadata() const = 0; virtual const FunctionMetadata& GetFunctionMetadata() const = 0;
}; };

View File

@ -69,6 +69,7 @@ cc_library(
], ],
deps = [ deps = [
":tensorhandle_convertible", ":tensorhandle_convertible",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/c/eager:immediate_execution_tensor_handle",
@ -77,5 +78,6 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:span",
], ],
) )

View File

@ -18,6 +18,8 @@ limitations under the License.
#include <memory> #include <memory>
#include <string> #include <string>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
@ -60,16 +62,12 @@ Status TFConcreteFunction::Create(
return Status(); return Status();
} }
const std::vector<ImmediateExecutionTensorHandle*>&
TFConcreteFunction::GetCaptures() const {
return captures_;
}
const FunctionMetadata& TFConcreteFunction::GetFunctionMetadata() const { const FunctionMetadata& TFConcreteFunction::GetFunctionMetadata() const {
return metadata_; return metadata_;
} }
Status TFConcreteFunction::GetCallOp(ImmediateOpPtr* out) { Status TFConcreteFunction::GetCallOp(
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) {
out->reset(ctx_->CreateOperation()); out->reset(ctx_->CreateOperation());
// In eager mode, TF2 python executes functions by constructing an op with // In eager mode, TF2 python executes functions by constructing an op with
// the name of the functiondef: // the name of the functiondef:
@ -81,6 +79,16 @@ Status TFConcreteFunction::GetCallOp(ImmediateOpPtr* out) {
// PartitionedCallOp for compatibility with "tooling that assumes functions in // PartitionedCallOp for compatibility with "tooling that assumes functions in
// graphs are PartitionedCallOps". // graphs are PartitionedCallOps".
TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr)); TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
// Adding the user-provided inputs to the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(inputs));
absl::Span<AbstractTensorHandle* const> captures(
reinterpret_cast<AbstractTensorHandle**>(captures_.data()),
captures_.size());
// Adding the captures of the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(captures));
return Status(); return Status();
} }

View File

@ -58,10 +58,8 @@ class TFConcreteFunction : public ConcreteFunction {
std::unique_ptr<TFConcreteFunction>* out); std::unique_ptr<TFConcreteFunction>* out);
// This method returns the "Call" Op used to execute the function. // This method returns the "Call" Op used to execute the function.
Status GetCallOp(ImmediateOpPtr* out) override; Status GetCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) override;
const std::vector<ImmediateExecutionTensorHandle*>& GetCaptures()
const override;
const FunctionMetadata& GetFunctionMetadata() const override; const FunctionMetadata& GetFunctionMetadata() const override;

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h" #include "tensorflow/core/protobuf/struct.pb.h"
@ -36,52 +37,8 @@ namespace tensorflow {
namespace internal { namespace internal {
namespace { namespace {
// This returns the size of `tf.nest.flatten(value)`, on values that are using StructuredValueDictEntry =
// used in tf.function's input_signatures. protobuf::MapPair<std::string, StructuredValue>;
int FlattenedSize(const tensorflow::StructuredValue& value, Status* status) {
// This follows the logic from
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2775
switch (value.kind_case()) {
case StructuredValue::kDictValue: {
const DictValue& dict = value.dict_value();
int size = 0;
for (const auto& field : dict.fields()) {
size += FlattenedSize(field.second, status);
}
return size;
}
case StructuredValue::kTupleValue: {
const TupleValue& tuple = value.tuple_value();
int size = 0;
for (const StructuredValue& value : tuple.values()) {
size += FlattenedSize(value, status);
}
return size;
}
case StructuredValue::kListValue: {
const ListValue& list = value.list_value();
int size = 0;
for (const StructuredValue& value : list.values()) {
size += FlattenedSize(value, status);
}
return size;
}
case StructuredValue::kTensorSpecValue: {
return 1;
}
case StructuredValue::kNoneValue: {
// Base case: do nothing.
// This arises, for example, as the top-level object of an output
// signature when there are no return values.
return 0;
}
default: {
status->Update(errors::Internal("Unhandled structured value kind ",
value.kind_case()));
return 0;
}
}
}
// Perform some basic sanity checks on SavedConcreteFunction's input and // Perform some basic sanity checks on SavedConcreteFunction's input and
// output signatures with respect to the corresponding FunctionDef's input // output signatures with respect to the corresponding FunctionDef's input
@ -111,34 +68,34 @@ Status ValidateSavedFunctionCompatibleWithFunctionDef(
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1974-L1979 // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1974-L1979
const std::string& name = function_def->signature().name(); const std::string& name = function_def->signature().name();
const StructuredValue& input_signature = const StructuredValue& input_signature =
saved_concrete_function.canonicalized_input_signature(); saved_concrete_function.canonicalized_input_signature();
Status status; std::vector<const TensorSpecProto*> input_specs;
int input_signature_size = FlattenedSize(input_signature, &status); TF_RETURN_IF_ERROR(FlattenSignature(input_signature, &input_specs));
TF_RETURN_IF_ERROR(status); if (input_specs.size() + saved_concrete_function.bound_inputs_size() !=
if (input_signature_size + saved_concrete_function.bound_inputs_size() !=
function_def->signature().input_arg_size()) { function_def->signature().input_arg_size()) {
return errors::FailedPrecondition( return errors::FailedPrecondition(
"FunctionDef ", name, " has ", "FunctionDef ", name, " has ",
function_def->signature().input_arg_size(), function_def->signature().input_arg_size(),
" inputs, but the SavedConcreteFunction has ", input_signature_size, " inputs, but the SavedConcreteFunction has ", input_specs.size(),
" flattened user inputs and ", " flattened user inputs and ",
saved_concrete_function.bound_inputs_size(), " captured inputs."); saved_concrete_function.bound_inputs_size(), " captured inputs.");
} }
const StructuredValue& output_signature = const StructuredValue& output_signature =
saved_concrete_function.output_signature(); saved_concrete_function.output_signature();
int output_signature_size = FlattenedSize(output_signature, &status); std::vector<const TensorSpecProto*> output_specs;
TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(FlattenSignature(output_signature, &output_specs));
if (output_signature_size != function_def->signature().output_arg_size()) { if (output_specs.size() != function_def->signature().output_arg_size()) {
return errors::FailedPrecondition( return errors::FailedPrecondition(
"FunctionDef ", name, " has ", "FunctionDef ", name, " has ",
function_def->signature().output_arg_size(), function_def->signature().output_arg_size(),
" outputs, but the SavedConcreteFunction has ", output_signature_size, " outputs, but the SavedConcreteFunction has ", output_specs.size(),
" flattened outputs."); " flattened outputs.");
} }
return status; return Status();
} }
} // namespace } // namespace
@ -197,6 +154,62 @@ Status LoadTFConcreteFunction(
out); out);
} }
Status FlattenSignature(const StructuredValue& signature,
std::vector<const TensorSpecProto*>* flattened_specs) {
// This follows the logic from
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2775
switch (signature.kind_case()) {
case StructuredValue::kDictValue: {
// Dictionaries must be sorted in order of keys
const DictValue& dict = signature.dict_value();
std::vector<const StructuredValueDictEntry*> entries;
entries.reserve(dict.fields_size());
for (const auto& field : dict.fields()) {
entries.push_back(&field);
}
std::sort(entries.begin(), entries.end(),
[](const StructuredValueDictEntry* x,
const StructuredValueDictEntry* y) {
return x->first < y->first;
});
for (const auto& entry : entries) {
TF_RETURN_IF_ERROR(FlattenSignature(entry->second, flattened_specs));
}
return Status();
}
case StructuredValue::kTupleValue: {
const TupleValue& tuple = signature.tuple_value();
for (const StructuredValue& value : tuple.values()) {
TF_RETURN_IF_ERROR(FlattenSignature(value, flattened_specs));
}
return Status();
}
case StructuredValue::kListValue: {
const ListValue& list = signature.list_value();
for (const StructuredValue& value : list.values()) {
TF_RETURN_IF_ERROR(FlattenSignature(value, flattened_specs));
}
return Status();
}
case StructuredValue::kTensorSpecValue: {
flattened_specs->push_back(&signature.tensor_spec_value());
return Status();
}
case StructuredValue::kNoneValue: {
// Base case: do nothing.
// This arises, for example, as the top-level object of an output
// signature when there are no return values.
return Status();
}
default: {
return errors::Internal("Unhandled structured value kind ",
signature.kind_case());
}
}
}
const SavedObject* FindNodeAtPath(StringPiece path, const SavedObject* FindNodeAtPath(StringPiece path,
const SavedObjectGraph& object_graph) { const SavedObjectGraph& object_graph) {
const auto& nodes = object_graph.nodes(); const auto& nodes = object_graph.nodes();

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow { namespace tensorflow {
namespace internal { namespace internal {
@ -59,10 +60,17 @@ Status LoadTFConcreteFunction(
captured_objects, captured_objects,
ImmediateExecutionContext* ctx, std::unique_ptr<TFConcreteFunction>* out); ImmediateExecutionContext* ctx, std::unique_ptr<TFConcreteFunction>* out);
// Find the SavedObject in `object_graph` at location `path`. `path` must be a // Flattens `signature` into a vector of TensorSpecProto pointers back into
// dot-delimited string of object names relative to the root object. If no // `signature`. `signature` must outlive flattened_specs. `signature` must also
// object is found, returns nullptr. Callers must ensure `object_graph` outlives // be the input or output signature of a SavedConcreteFunction (i.e. "nested
// the returned pointer. // structures of tensorspecs").
Status FlattenSignature(const StructuredValue& signature,
std::vector<const TensorSpecProto*>* flattened_specs);
// Find the SavedObject in `object_graph` at location `path`. `path` must be
// a dot-delimited string of object names relative to the root object. If no
// object is found, returns nullptr. Callers must ensure `object_graph`
// outlives the returned pointer.
const SavedObject* FindNodeAtPath(StringPiece path, const SavedObject* FindNodeAtPath(StringPiece path,
const SavedObjectGraph& object_graph); const SavedObjectGraph& object_graph);

View File

@ -0,0 +1,133 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
namespace {
// Validates names, shapes, and dtypes of two tensorspecprotos are equivalent.
bool TensorSpecsAreEqual(const TensorSpecProto& spec,
const std::string& expected_name,
const PartialTensorShape& expected_shape,
DataType expected_dtype) {
return spec.name() == expected_name &&
PartialTensorShape(spec.shape()).IsIdenticalTo(expected_shape) &&
spec.dtype() == expected_dtype;
}
// This tests the common case for a tf.function w/o inputs. This ends up
// being serialized as a tuple of an empty tuple + empty dictionary
// (corresponding to the args, kwargs) of the function.
TEST(SignatureFlatteningTest, ZeroArgInputSignature) {
std::vector<const TensorSpecProto*> flattened;
StructuredValue value = testing::ZeroArgInputSignature();
TF_EXPECT_OK(internal::FlattenSignature(value, &flattened));
EXPECT_EQ(flattened.size(), 0);
}
// This tests the common case for a tf.function w/o outputs. This ends up
// being serialized as a "NoneValue".
TEST(SignatureFlatteningTest, ZeroRetOutputSignature) {
std::vector<const TensorSpecProto*> flattened;
StructuredValue value = testing::ZeroReturnOutputSignature();
TF_EXPECT_OK(internal::FlattenSignature(value, &flattened));
EXPECT_EQ(flattened.size(), 0);
}
TEST(SignatureFlatteningTest, SingleArgInputSignature) {
std::vector<const TensorSpecProto*> flattened;
StructuredValue value = testing::SingleArgInputSignature();
TF_EXPECT_OK(internal::FlattenSignature(value, &flattened));
EXPECT_EQ(flattened.size(), 1);
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0],
/* expected_name = */ "x",
/* expected_shape = */ {1, 10},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[0]->DebugString();
}
TEST(SignatureFlatteningTest, SingleReturnOutputSignature) {
std::vector<const TensorSpecProto*> flattened;
StructuredValue value = testing::SingleReturnOutputSignature();
TF_EXPECT_OK(internal::FlattenSignature(value, &flattened));
EXPECT_EQ(flattened.size(), 1);
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0],
/* expected_name = */ "",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[0]->DebugString();
}
TEST(SignatureFlatteningTest, ThreeArgInputSignature) {
std::vector<const TensorSpecProto*> flattened;
StructuredValue value = testing::ThreeArgInputSignature();
TF_EXPECT_OK(internal::FlattenSignature(value, &flattened));
EXPECT_EQ(flattened.size(), 3);
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0],
/* expected_name = */ "x",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[0]->DebugString();
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[1],
/* expected_name = */ "y",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[1]->DebugString();
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[2],
/* expected_name = */ "z",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[2]->DebugString();
}
// This test has an exotic outputsignature of tuple of a
// dictionary<string,tensor>, tensor
TEST(SignatureFlatteningTest, ThreeReturnOutputSignature) {
std::vector<const TensorSpecProto*> flattened;
StructuredValue value = testing::ThreeReturnOutputSignature();
TF_EXPECT_OK(internal::FlattenSignature(value, &flattened));
EXPECT_EQ(flattened.size(), 3);
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0],
/* expected_name = */ "0/a",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[0]->DebugString();
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[1],
/* expected_name = */ "0/b",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[1]->DebugString();
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[2],
/* expected_name = */ "1",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[2]->DebugString();
}
} // namespace
} // namespace tensorflow

View File

@ -15,47 +15,364 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h" #include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
#include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "absl/algorithm/container.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
namespace tensorflow { namespace tensorflow {
// Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary
using FunctionDefMap =
std::unordered_map<StringPiece, const tensorflow::FunctionDef*,
StringPieceHasher>;
// Maps from a Nodedef's name to its corresponding AttrValues, for a given
// Graphdef
using NodeAttrMap =
std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher>;
// Maps from Node ID to an "Revived Object" implementing
// "TensorHandleConvertible"
using RevivedObjectMap =
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>;
// Maps from a functiondef's name to the corresponding "TFConcreteFunction"
using ConcreteFunctionMap =
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>;
namespace {
Status ConstantFromSavedConstant(
ImmediateExecutionContext* ctx,
const tensorflow::SavedConstant& saved_constant,
const NodeAttrMap& node_attr_map, std::unique_ptr<Constant>* output) {
const std::string& const_op_name = saved_constant.operation();
const auto& node_name_and_attrs = node_attr_map.find(const_op_name);
if (node_name_and_attrs == node_attr_map.end()) {
return errors::FailedPrecondition(
"Unable to find Const operation with name'", const_op_name,
"' in SavedModel graphdef");
}
const AttrValueMap* attrs = node_name_and_attrs->second;
const auto& attr_name_and_value = attrs->find("value");
if (attr_name_and_value == attrs->end()) {
return errors::FailedPrecondition("Unable to find Const operation '",
const_op_name, "'s value attribute");
}
const TensorProto& tensor_proto = attr_name_and_value->second.tensor();
return internal::TensorProtoToConstant(ctx, tensor_proto, output);
}
// Restores all non-function objects in the SavedModel's object graph.
// This function walks through the metagraph's saved object graph, and
// constructs revived versions of SavedVariable, SavedConstant, SavedAsset, and
// SavedResources. These are returned via the `out` parameter.
Status ReviveObjects(
const MetaGraphDef& metagraph, ImmediateExecutionContext* context,
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>*
revived_objects) {
// This is needed to restore "Constant" nodes by looking up their
// "Value" attribute.
NodeAttrMap node_attr_map = internal::NodeToAttrMap(metagraph.graph_def());
// Iterate through all the saved objects, restoring objects as we go.
// We don't recreate functions until all other objects have been created.
for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) {
const SavedObject& node = metagraph.object_graph_def().nodes(i);
if (node.kind_case() == SavedObject::kVariable) {
std::unique_ptr<Variable> variable;
TF_RETURN_IF_ERROR(
internal::LoadSavedVariable(context, node.variable(), &variable));
(*revived_objects)[i] = std::move(variable);
} else if (node.kind_case() == SavedObject::kConstant) {
std::unique_ptr<Constant> constant;
TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(),
node_attr_map, &constant));
(*revived_objects)[i] = std::move(constant);
} else if (node.kind_case() == SavedObject::kAsset) {
// TODO(bmzhao): Implement Asset C++ class. This should be just recreating
// the full path to the asset file:
// https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/saved_model/load.py#L395-L396
// and storing it as a string tensor:
// https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/training/tracking/tracking.py#L324-L325
return errors::Unimplemented("SavedAsset loading is not implemented yet");
} else if (node.kind_case() == SavedObject::kResource) {
// TODO(bmzhao): Figure out how resource loading works and implement it
return errors::Unimplemented(
"SavedResource loading is not implemented yet");
}
}
return Status();
}
Status ReviveFunctions(const MetaGraphDef& metagraph,
const RevivedObjectMap& revived_objects,
ImmediateExecutionContext* context,
ConcreteFunctionMap* restored_functions) {
const FunctionDefMap function_def_map =
internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library());
// Iterate through all objects, only examining functions.
for (const SavedObject& node : metagraph.object_graph_def().nodes()) {
if (node.kind_case() == SavedObject::kBareConcreteFunction) {
const std::string& function_name =
node.bare_concrete_function().concrete_function_name();
const SavedConcreteFunction& saved_concrete_function =
metagraph.object_graph_def().concrete_functions().at(function_name);
const FunctionDef* function_def = function_def_map.at(function_name);
std::unique_ptr<TFConcreteFunction> concrete_function;
TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction(
saved_concrete_function, function_def, revived_objects, context,
&concrete_function));
(*restored_functions)[function_name] = std::move(concrete_function);
} else if (node.kind_case() == SavedObject::kFunction) {
// We only allow loading functions that have an annotated input signature,
// which means there is 1:1 correspondence between tf.function
// <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. This is
// the same restriction that MLIR has:
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707
const SavedFunction& saved_function = node.function();
if (saved_function.concrete_functions_size() != 1) {
return errors::FailedPrecondition(
"Only tf.functions annotated with an input signature are supported "
"by SavedModelAPI. This means that there should only be a single "
"ConcreteFunction per tf.function");
}
const std::string& function_name = saved_function.concrete_functions(0);
const SavedConcreteFunction& saved_concrete_function =
metagraph.object_graph_def().concrete_functions().at(function_name);
const FunctionDef* function_def = function_def_map.at(function_name);
std::unique_ptr<TFConcreteFunction> concrete_function;
TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction(
saved_concrete_function, function_def, revived_objects, context,
&concrete_function));
(*restored_functions)[function_name] = std::move(concrete_function);
}
}
return Status();
}
const TrackableObjectGraph::TrackableObject::SerializedTensor*
FindSerializedTensorInTrackable(
const TrackableObjectGraph::TrackableObject& trackable_object,
absl::string_view name) {
for (const auto& maybe_serialized_tensor : trackable_object.attributes()) {
if (maybe_serialized_tensor.name() == name) {
return &maybe_serialized_tensor;
}
}
return nullptr;
}
// This function reads the Checkpoint embedded in the SavedModel, and calls the
// appropriate Restore ops on each of the variables.
// Note(bmzhao): Conceptually, objects that contain checkpointable state
// implement the "_gather_saveables_for_checkpoint" method
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/tracking/base.py#L953-L983
// which returns a dict of string key -> EITHER:
// 1. python callable (taking a checkpoint key) returning SaveableObject OR
// 2. variable (partitioned/resource/reference or otherwise)
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L58.
// The string key becomes the "name" attribute of the SerializedTensor proto
// in the TrackableObjectGraph,
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/core/protobuf/trackable_object_graph.proto#L26
// And the checkpoint_key is a globally unique string derived from this name:
// https://github.com/tensorflow/tensorflow/blob/842df9e6b516e42578a8d23b35d41176b9a6cf1d/tensorflow/python/training/tracking/graph_view.py#L236-L241
// SaveableObjects model the information needed to pass to the SaveV2/RestoreV2
// ops via their SaveSpec members
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L21,
// which contain the "real" checkpoint keys into the TensorBundle SSTable.
// They also contain the logic needed to take the restored tensors from
// RestoreV2 and load them back into the "object" they came from via their
// overridden "restore" method:
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L85
Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
const RevivedObjectMap& revived_objects,
const std::string& directory,
ImmediateExecutionContext* context) {
// TODO(bmzhao): Batch up all the restores into a single restore op per
// device, following logic in MultiDeviceSaver.
TF_RETURN_IF_ERROR(bundle->VisitObjectsToRestore(
[&revived_objects, &directory, context, bundle](
int node, const TrackableObjectGraph::TrackableObject& trackable) {
if (bundle->saved_object_graph().nodes(node).kind_case() !=
SavedObject::kVariable) {
// TODO(bmzhao): This requires using the newly added Save/Restore
// functions from
// https://github.com/tensorflow/tensorflow/commit/df6b21c13c82b5d0981642cfe18f10e60f78ea5c
LOG(WARNING) << "Restoring non-variable objects has not been "
"implemented yet. (Kind="
<< bundle->saved_object_graph().nodes(node).kind_case()
<< ")";
return Status::OK();
}
Variable* variable =
down_cast<Variable*>(revived_objects.at(node).get());
// Restore the tensor's value from the checkpoint
const TrackableObjectGraph::TrackableObject::SerializedTensor*
attribute =
FindSerializedTensorInTrackable(trackable, "VARIABLE_VALUE");
if (attribute == nullptr) {
return errors::FailedPrecondition(
"Could not find SerializedTensor with name VARIABLE_VALUE for "
"saved variable");
}
const std::string& checkpoint_key = attribute->checkpoint_key();
std::string variables_path_prefix =
io::JoinPath(directory, kSavedModelVariablesDirectory,
kSavedModelVariablesFilename);
ImmediateTensorHandlePtr restored_output;
TF_RETURN_IF_ERROR(internal::SingleRestore(
context, variables_path_prefix, checkpoint_key, variable->dtype(),
&restored_output));
// Assign the restored tensor's value to the variable
return variable->Assign(restored_output.get());
}));
return Status();
}
} // namespace
Status TFSavedModelAPI::GetFunction(const std::string& function_path, Status TFSavedModelAPI::GetFunction(const std::string& function_path,
ConcreteFunction** function) { ConcreteFunction** function) {
// TODO(bmzhao): Add support for retrieving a function. const SavedObject* object =
return errors::Unimplemented( internal::FindNodeAtPath(function_path, bundle_.saved_object_graph());
"Retrieving functions is unimplemented currently"); if (object == nullptr) {
return errors::NotFound("No saved object found at path ", function_path);
}
if (object->kind_case() == SavedObject::kBareConcreteFunction) {
*function =
concrete_functions_
.at(object->bare_concrete_function().concrete_function_name())
.get();
} else if (object->kind_case() == SavedObject::kFunction) {
*function =
concrete_functions_.at(object->function().concrete_functions(0)).get();
} else {
return errors::InvalidArgument(function_path,
" is not a path to a Function.");
}
return Status();
} }
Status TFSavedModelAPI::GetSignatureDefFunction( Status TFSavedModelAPI::GetSignatureDefFunction(
const std::string& signature_def_key, ConcreteFunction** function) { const std::string& signature_def_key, ConcreteFunction** function) {
// TODO(bmzhao): Add support for retrieving a signaturedef function. // TODO(bmzhao): Add support for retrieving a signaturedef function.
return errors::Unimplemented( return errors::Unimplemented(
"Retrieving functions is unimplemented currently"); "Retrieving SignatureDef functions is unimplemented currently");
} }
std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() { std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
std::vector<ConcreteFunction*> result; std::vector<ConcreteFunction*> result;
result.reserve(functions_.size()); result.reserve(concrete_functions_.size());
for (ConcreteFunction& function : functions_) { for (auto& index_and_function : concrete_functions_) {
result.push_back(&function); result.push_back(index_and_function.second.get());
} }
return result; return result;
} }
TFSavedModelAPI::TFSavedModelAPI(
const std::string& directory, SavedModelV2Bundle bundle,
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
revived_objects,
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
concrete_functions)
: directory_(directory),
bundle_(std::move(bundle)),
revived_objects_(std::move(revived_objects)),
concrete_functions_(std::move(concrete_functions)) {}
Status TFSavedModelAPI::Load( Status TFSavedModelAPI::Load(
const std::string& directory, const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags, const absl::optional<std::unordered_set<std::string>>& tags,
ImmediateExecutionContext* context, std::unique_ptr<TFSavedModelAPI>* out) { ImmediateExecutionContext* context, std::unique_ptr<TFSavedModelAPI>* out) {
// TODO(bmzhao): Add support for loading a TFSavedModelImpl. // TODO(bmzhao): Add support for loading a TF1 SavedModel.
return errors::Unimplemented( if (tags) {
"TFSavedModelAPIImpl loading is unimplemented currently"); return errors::Unimplemented(
"Loading saved models with explicit tags will be supported in the "
"future");
}
SavedModelV2Bundle bundle;
TF_RETURN_IF_ERROR(SavedModelV2Bundle::Load(directory, &bundle));
// TODO(bmzhao): Mangle loaded function names so that different
// models loaded in the same runtime Context don't clobber eachother.
// This occurs in python here:
// https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454
RevivedObjectMap revived_objects;
TF_RETURN_IF_ERROR(
ReviveObjects(bundle.meta_graph_def(), context, &revived_objects));
// TODO(bmzhao): When we later add support for loading resources, we need to
// handle the case where materializing a function's captures requires invoking
// other functions. This occurs when retrieving the resource handle for a
// TrackableResource:
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/saved_model/load.py#L240
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/training/tracking/tracking.py#L233
// This requires restoring functions in a topological sort order by capture
// dependencies.
ConcreteFunctionMap function_map;
TF_RETURN_IF_ERROR(ReviveFunctions(bundle.meta_graph_def(), revived_objects,
context, &function_map));
TF_RETURN_IF_ERROR(
RestoreCheckpoint(&bundle, revived_objects, directory, context));
out->reset(new TFSavedModelAPI(directory, std::move(bundle),
std::move(revived_objects),
std::move(function_map)));
return Status();
} }
} // namespace tensorflow } // namespace tensorflow

View File

@ -16,14 +16,19 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_ #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_ #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
#include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" #include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status.h"
namespace tensorflow { namespace tensorflow {
@ -63,8 +68,19 @@ class TFSavedModelAPI : public SavedModelAPI {
~TFSavedModelAPI() override = default; ~TFSavedModelAPI() override = default;
private: private:
TFSavedModelAPI() = default; TFSavedModelAPI(
std::vector<ConcreteFunction> functions_; const std::string& directory, SavedModelV2Bundle bundle,
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
revived_objects,
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
concrete_functions);
std::string directory_;
SavedModelV2Bundle bundle_;
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
revived_objects_;
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
concrete_functions_;
}; };
} // namespace tensorflow } // namespace tensorflow

View File

@ -38,16 +38,17 @@ cc_library(
":concrete_function_type", ":concrete_function_type",
":function_metadata", ":function_metadata",
":function_metadata_type", ":function_metadata_type",
":tensorhandle_list",
":tensorhandle_list_type",
"//tensorflow/c:c_api_macros", "//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status_internal", "//tensorflow/c:tf_status_internal",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:tfe_op_internal", "//tensorflow/c/eager:tfe_op_internal",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
"//tensorflow/c/experimental/saved_model/core:concrete_function", "//tensorflow/c/experimental/saved_model/core:concrete_function",
"//tensorflow/c/experimental/saved_model/core:function_metadata", "//tensorflow/c/experimental/saved_model/core:function_metadata",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"@com_google_absl//absl/types:span",
], ],
) )
@ -164,38 +165,6 @@ cc_library(
], ],
) )
cc_library(
name = "tensorhandle_list",
srcs = [
"tensorhandle_list.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:tensorhandle_list.h",
],
copts = tf_copts(),
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":tensorhandle_list_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
],
)
cc_library(
name = "tensorhandle_list_type",
hdrs = [
"tensorhandle_list_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
],
)
tf_cc_test( tf_cc_test(
name = "saved_model_api_test", name = "saved_model_api_test",
size = "small", size = "small",
@ -213,7 +182,6 @@ tf_cc_test(
"//tensorflow/c/eager:c_api_test_util", "//tensorflow/c/eager:c_api_test_util",
"//tensorflow/c/experimental/saved_model/public:concrete_function", "//tensorflow/c/experimental/saved_model/public:concrete_function",
"//tensorflow/c/experimental/saved_model/public:saved_model_api", "//tensorflow/c/experimental/saved_model/public:saved_model_api",
"//tensorflow/c/experimental/saved_model/public:tensorhandle_list",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",

View File

@ -15,13 +15,15 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h" #include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h" #include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h" #include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
#include "tensorflow/c/tf_status_internal.h" #include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status.h"
@ -32,15 +34,18 @@ TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
&tensorflow::unwrap(func)->GetFunctionMetadata())); &tensorflow::unwrap(func)->GetFunctionMetadata()));
} }
const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures(
TF_ConcreteFunction* func) {
return tensorflow::wrap(&tensorflow::unwrap(func)->GetCaptures());
}
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func, TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func,
TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) { TF_Status* status) {
tensorflow::ImmediateOpPtr call_op(nullptr); tensorflow::ImmediateOpPtr call_op;
status->status = tensorflow::unwrap(func)->GetCallOp(&call_op); absl::Span<tensorflow::AbstractTensorHandle* const> input_span(
reinterpret_cast<tensorflow::AbstractTensorHandle**>(
tensorflow::unwrap(inputs)),
static_cast<size_t>(num_inputs));
status->status = tensorflow::unwrap(func)->GetCallOp(input_span, &call_op);
if (!status->status.ok()) {
return nullptr;
}
return tensorflow::wrap(call_op.release()); return tensorflow::wrap(call_op.release());
} }

View File

@ -16,10 +16,14 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" #include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include <string> #include <string>
#include <vector>
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
@ -92,12 +96,42 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) {
TF_SavedModel* saved_model = TF_SavedModel* saved_model =
TF_LoadSavedModel(model_dir.c_str(), ctx, status); TF_LoadSavedModel(model_dir.c_str(), ctx, status);
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented. EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// That unblocks writing other tests that require a TF_SavedModel*, TF_ConcreteFunction* compute_fn =
// like loading a ConcreteFunction. This test at least checks that the TF_GetSavedModelConcreteFunction(saved_model, "compute", status);
// C API builds and can be minimally run. EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED);
std::vector<TFE_TensorHandle*> compute_fn_inputs;
TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f);
TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f);
compute_fn_inputs.push_back(input_a);
compute_fn_inputs.push_back(input_b);
TFE_Op* compute_fn_op = TF_ConcreteFunctionGetCallOp(
compute_fn, compute_fn_inputs.data(), compute_fn_inputs.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
// inputs + outputs a function has.
TFE_TensorHandle* compute_fn_outputs[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(compute_fn_op, &compute_fn_outputs[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TF_NumDims(result), 0);
float output_value = *static_cast<float*>(TF_TensorData(result));
// (1 + 2) * (2 + 1) / 3 + 5 should be 8
EXPECT_FLOAT_EQ(output_value, 8.0);
TF_DeleteTensor(result);
TFE_DeleteTensorHandle(compute_fn_outputs[0]);
TFE_DeleteTensorHandle(input_a);
TFE_DeleteTensorHandle(input_b);
TFE_DeleteOp(compute_fn_op);
TF_DeleteSavedModel(saved_model); TF_DeleteSavedModel(saved_model);
TF_DeleteStatus(status); TF_DeleteStatus(status);
TFE_DeleteContext(ctx); TFE_DeleteContext(ctx);

View File

@ -24,7 +24,6 @@ exports_files(
"concrete_function_list.h", "concrete_function_list.h",
"function_metadata.h", "function_metadata.h",
"saved_model_api.h", "saved_model_api.h",
"tensorhandle_list.h",
], ],
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"], visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
) )
@ -40,7 +39,6 @@ cc_library(
":concrete_function_list", ":concrete_function_list",
":function_metadata", ":function_metadata",
":saved_model_api", ":saved_model_api",
":tensorhandle_list",
], ],
) )
@ -63,8 +61,3 @@ alias(
name = "saved_model_api", name = "saved_model_api",
actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api", actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api",
) )
alias(
name = "tensorhandle_list",
actual = "//tensorflow/c/experimental/saved_model/internal:tensorhandle_list",
)

View File

@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h" #include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h" #include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" #include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
// IWYU pragma: end_exports // IWYU pragma: end_exports
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_ #endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_

View File

@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h" #include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
@ -35,13 +34,15 @@ typedef struct TF_ConcreteFunction TF_ConcreteFunction;
TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata( TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
TF_ConcreteFunction* func); TF_ConcreteFunction* func);
// Returns a list of TensorHandles implicitly captured by this function. // Returns a TFE_Op suitable for executing this function. Caller must provide
TF_CAPI_EXPORT extern const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures( // all function inputs in `inputs`, and must not add any additional inputs on
TF_ConcreteFunction* func); // the returned op. (i.e. don't call TFE_OpAddInput or TFE_OpAddInputList).
// The caller is responsible for deleting the returned TFE_Op. If op
// Returns a TFE_Op suitable for executing this function. // construction fails, `status` will be non-OK and the returned pointer will be
// null.
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp( TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp(
TF_ConcreteFunction* func, TF_Status* status); TF_ConcreteFunction* func, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status);
#ifdef __cplusplus #ifdef __cplusplus
} // end extern "C" } // end extern "C"

View File

@ -1,43 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
#include <stddef.h>
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/eager/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type that is acts like a list of TF_ConcreteFunction pointers.
typedef struct TF_TensorHandleList TF_TensorHandleList;
// Returns the size of `list`.
TF_CAPI_EXPORT extern size_t TF_TensorHandleListSize(
const TF_TensorHandleList* list);
// Returns the `i`th TFE_TensorHandle in the list.
TF_CAPI_EXPORT extern TFE_TensorHandle* TF_TensorHandleListGet(
const TF_TensorHandleList* list, int i);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_

View File

@ -97,6 +97,11 @@ void TF_KernelBuilder_HostMemory(TF_KernelBuilder* kernel_builder,
kernel_builder->cc_builder->HostMemory(arg_name); kernel_builder->cc_builder->HostMemory(arg_name);
} }
void TF_KernelBuilder_Priority(TF_KernelBuilder* kernel_builder,
int32_t priority_number) {
kernel_builder->cc_builder->Priority(priority_number);
}
namespace tensorflow { namespace tensorflow {
namespace { namespace {
@ -234,6 +239,14 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType) DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
DEFINE_TF_GETATTR(Int32, tensorflow::int32, int32_t) DEFINE_TF_GETATTR(Int32, tensorflow::int32, int32_t)
TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) {
auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx);
TF_StringView string_view_of_name;
string_view_of_name.data = cc_ctx->def().name().data();
string_view_of_name.len = cc_ctx->def().name().length();
return string_view_of_name;
}
TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) { TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
return static_cast<TF_DataType>(cc_ctx->expected_output_dtype(i)); return static_cast<TF_DataType>(cc_ctx->expected_output_dtype(i));
@ -266,4 +279,4 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
return nullptr; return nullptr;
} }
return tf_tensor; return tf_tensor;
} }

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <stdint.h> #include <stdint.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
@ -107,6 +108,10 @@ TF_CAPI_EXPORT extern void TF_KernelBuilder_TypeConstraint(
TF_CAPI_EXPORT extern void TF_KernelBuilder_HostMemory( TF_CAPI_EXPORT extern void TF_KernelBuilder_HostMemory(
TF_KernelBuilder* kernel_builder, const char* arg_name); TF_KernelBuilder* kernel_builder, const char* arg_name);
// Specify a priority number for this kernel.
TF_CAPI_EXPORT extern void TF_KernelBuilder_Priority(
TF_KernelBuilder* kernel_builder, int32_t priority_number);
// Register the given kernel builder with the TensorFlow runtime. If // Register the given kernel builder with the TensorFlow runtime. If
// registration fails, the given status will be populated. // registration fails, the given status will be populated.
// //
@ -180,6 +185,10 @@ TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32(
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val, TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val,
TF_Status* status); TF_Status* status);
// Returns the unique operation name for this OpKernel.
TF_CAPI_EXPORT extern TF_StringView TF_OpKernelConstruction_GetName(
TF_OpKernelConstruction* ctx);
// Allocates Tensor for output at given index. Caller takes ownership of // Allocates Tensor for output at given index. Caller takes ownership of
// returned TF_Tensor and should deallocate it using TF_DeleteTensor(tensor). // returned TF_Tensor and should deallocate it using TF_DeleteTensor(tensor).
// //

View File

@ -24,6 +24,21 @@ tf_kernel_library(
], ],
) )
tf_kernel_library(
name = "summary_op",
prefix = "summary_op",
deps = [
"//tensorflow/c:kernels",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_tensor",
"//tensorflow/c/kernels:tensor_shape_utils",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//third_party/eigen3",
],
)
tf_gen_op_libs( tf_gen_op_libs(
op_lib_names = ["bitcast"], op_lib_names = ["bitcast"],
deps = [ deps = [
@ -35,6 +50,15 @@ tf_gen_op_libs(
], ],
) )
tf_gen_op_libs(
op_lib_names = ["summary"],
deps = [
"//tensorflow/c:ops",
"//tensorflow/c:tf_status",
"//tensorflow/core:lib",
],
)
tf_cc_test( tf_cc_test(
name = "bitcast_op_test", name = "bitcast_op_test",
srcs = ["bitcast_op_test.cc"], srcs = ["bitcast_op_test.cc"],
@ -48,6 +72,62 @@ tf_cc_test(
], ],
) )
tf_cc_test(
name = "summary_op_test",
srcs = ["summary_op_test.cc"],
deps = [
":summary_op",
"//tensorflow/c:kernels",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_cc_test(
name = "summary_op_benchmark_test",
size = "small",
srcs = ["summary_op_benchmark_test.cc"],
deps = [
":summary_op",
"//tensorflow/c:kernels",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "tensor_shape_utils",
srcs = ["tensor_shape_utils.cc"],
hdrs = ["tensor_shape_utils.h"],
visibility = ["//visibility:private"],
deps = [
"//tensorflow/c:tf_tensor",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "tensor_shape_utils_test",
srcs = ["tensor_shape_utils_test.cc"],
deps = [
":tensor_shape_utils",
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
# Changes to the Android srcs here should be replicated in # Changes to the Android srcs here should be replicated in
# tensorflow/contrib/makefile/tf_op_files.txt. # tensorflow/contrib/makefile/tf_op_files.txt.
# #
@ -59,11 +139,17 @@ filegroup(
name = "android_all_op_kernels", name = "android_all_op_kernels",
srcs = [ srcs = [
"bitcast_op.cc", "bitcast_op.cc",
"summary_op.cc",
"tensor_shape_utils.cc",
"tensor_shape_utils.h",
], ],
) )
# LINT.ThenChange(//tensorflow/contrib/makefile/tf_op_files.txt) # LINT.ThenChange(//tensorflow/contrib/makefile/tf_op_files.txt)
filegroup( filegroup(
name = "android_all_ops", name = "android_all_ops",
srcs = ["ops/bitcast.cc"], srcs = [
"ops/bitcast.cc",
"ops/summary.cc",
],
) )

View File

@ -22,8 +22,19 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
static void ComputeNewShape(TF_ShapeInferenceContext* ctx, static void ComputeNewShape(TF_ShapeInferenceContext* ctx,
TF_ShapeHandle* shape, size_t input_type_size, TF_ShapeHandle* shape, TF_DataType input_type,
size_t output_type_size, TF_Status* status) { TF_DataType output_type, TF_Status* status) {
size_t input_type_size = TF_DataTypeSize(input_type);
size_t output_type_size = TF_DataTypeSize(output_type);
if (input_type_size == 0 || output_type_size == 0) {
std::ostringstream err;
err << "Cannot bitcast type " << input_type << " to " << output_type
<< " because one of the type sizes is zero";
TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
return;
}
TF_SetStatus(status, TF_OK, ""); TF_SetStatus(status, TF_OK, "");
if (input_type_size < output_type_size) { if (input_type_size < output_type_size) {
TF_ShapeInferenceContextWithRankAtLeast(ctx, shape, 1, shape, status); TF_ShapeInferenceContextWithRankAtLeast(ctx, shape, 1, shape, status);
@ -37,9 +48,9 @@ static void ComputeNewShape(TF_ShapeInferenceContext* ctx,
TF_ShapeInferenceContextSubshape(ctx, shape, 0, -1, shape, status); TF_ShapeInferenceContextSubshape(ctx, shape, 0, -1, shape, status);
} else { } else {
std::ostringstream err; std::ostringstream err;
err << "Cannot bitcast due to shape. " err << "Cannot bitcast from " << input_type << " to " << output_type
<< TF_DimensionHandleValue(last_dim) << " does not match " << " due to shape. " << TF_DimensionHandleValue(last_dim)
<< divisor_val; << " does not match " << divisor_val;
TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str()); TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
} }
TF_DeleteDimensionHandle(last_dim); TF_DeleteDimensionHandle(last_dim);
@ -78,23 +89,8 @@ static void bitcast_shape_inference_fn(TF_ShapeInferenceContext* ctx,
TF_ShapeInferenceContext_GetAttrType(ctx, "type", &output_type, status); TF_ShapeInferenceContext_GetAttrType(ctx, "type", &output_type, status);
} }
size_t input_type_size;
size_t output_type_size;
if (TF_GetCode(status) == TF_OK) { if (TF_GetCode(status) == TF_OK) {
input_type_size = TF_DataTypeSize(input_type); ComputeNewShape(ctx, result, input_type, output_type, status);
output_type_size = TF_DataTypeSize(output_type);
if (input_type_size == 0 || output_type_size == 0) {
std::ostringstream err;
err << "Cannot bitcast type " << input_type << " to " << output_type
<< " because one of the type sizes is zero";
TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
}
}
if (TF_GetCode(status) == TF_OK) {
ComputeNewShape(ctx, result, input_type_size, output_type_size, status);
} }
if (TF_GetCode(status) == TF_OK) { if (TF_GetCode(status) == TF_OK) {

View File

@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/ops.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
static void scalar_summary_shape_inference_fn(TF_ShapeInferenceContext* ctx,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
TF_ShapeHandle* result = TF_ShapeInferenceContextScalar(ctx);
TF_ShapeInferenceContextSetOutput(ctx, 0, result, status);
TF_DeleteShapeHandle(result);
}
void Register_ScalarSummaryOp() {
TF_Status* status = TF_NewStatus();
TF_OpDefinitionBuilder* op_builder =
TF_NewOpDefinitionBuilder("ScalarSummary");
TF_OpDefinitionBuilderAddInput(op_builder, "tags: string");
TF_OpDefinitionBuilderAddInput(op_builder, "values: T");
TF_OpDefinitionBuilderAddOutput(op_builder, "summary: string");
TF_OpDefinitionBuilderAddAttr(op_builder, "T: realnumbertype");
TF_OpDefinitionBuilderSetShapeInferenceFunction(
op_builder, &scalar_summary_shape_inference_fn);
TF_RegisterOpDefinition(op_builder, status);
CHECK_EQ(TF_GetCode(status), TF_OK)
<< "ScalarSummary op registration failed: " << TF_Message(status);
TF_DeleteStatus(status);
}
TF_ATTRIBUTE_UNUSED static bool SummaryScalarOpRegistered = []() {
if (SHOULD_REGISTER_OP("ScalarSummary")) {
Register_ScalarSummaryOp();
}
return true;
}();

View File

@ -0,0 +1,172 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <sstream>
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/c/kernels.h"
#include "tensorflow/c/kernels/tensor_shape_utils.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
namespace {
// Struct that stores the status and TF_Tensor inputs to the opkernel.
// Used to delete tensor and status in its destructor upon kernel return.
struct Params {
TF_Tensor* tags;
TF_Tensor* values;
TF_Status* status;
explicit Params(TF_OpKernelContext* ctx)
: tags(nullptr), values(nullptr), status(nullptr) {
status = TF_NewStatus();
TF_GetInput(ctx, 0, &tags, status);
if (TF_GetCode(status) == TF_OK) {
TF_GetInput(ctx, 1, &values, status);
}
}
~Params() {
TF_DeleteStatus(status);
TF_DeleteTensor(tags);
TF_DeleteTensor(values);
}
};
// dummy functions used for kernel registration
void* ScalarSummaryOp_Create(TF_OpKernelConstruction* ctx) { return nullptr; }
void ScalarSummaryOp_Delete(void* kernel) {}
// Helper functions for compute method
bool IsSameSize(TF_Tensor* tensor1, TF_Tensor* tensor2);
// Returns a string representation of a single tag or empty string if there
// are multiple tags
std::string SingleTag(TF_Tensor* tags);
template <typename T>
void ScalarSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
Params params(ctx);
if (TF_GetCode(params.status) != TF_OK) {
TF_OpKernelContext_Failure(ctx, params.status);
return;
}
if (!IsSameSize(params.tags, params.values)) {
std::ostringstream err;
err << "tags and values are not the same shape: "
<< tensorflow::ShapeDebugString(params.tags)
<< " != " << tensorflow::ShapeDebugString(params.values)
<< SingleTag(params.tags);
TF_SetStatus(params.status, TF_INVALID_ARGUMENT, err.str().c_str());
TF_OpKernelContext_Failure(ctx, params.status);
return;
}
// Convert tags and values tensor to array to access elements by index
tensorflow::Summary s;
auto tags_array =
static_cast<tensorflow::tstring*>(TF_TensorData(params.tags));
auto values_array = static_cast<T*>(TF_TensorData(params.values));
// Copy tags and values into summary protobuf
for (int i = 0; i < TF_TensorElementCount(params.tags); ++i) {
tensorflow::Summary::Value* v = s.add_value();
const tensorflow::tstring& Ttags_i = tags_array[i];
v->set_tag(Ttags_i.data(), Ttags_i.size());
v->set_simple_value(static_cast<float>(values_array[i]));
}
TF_Tensor* summary_tensor =
TF_AllocateOutput(ctx, 0, TF_ExpectedOutputDataType(ctx, 0), nullptr, 0,
sizeof(tensorflow::tstring), params.status);
if (TF_GetCode(params.status) != TF_OK) {
TF_DeleteTensor(summary_tensor);
TF_OpKernelContext_Failure(ctx, params.status);
return;
}
tensorflow::tstring* output_tstring =
reinterpret_cast<tensorflow::tstring*>(TF_TensorData(summary_tensor));
CHECK(SerializeToTString(s, output_tstring));
TF_DeleteTensor(summary_tensor);
}
bool IsSameSize(TF_Tensor* tensor1, TF_Tensor* tensor2) {
if (TF_NumDims(tensor1) != TF_NumDims(tensor2)) {
return false;
}
for (int d = 0; d < TF_NumDims(tensor1); d++) {
if (TF_Dim(tensor1, d) != TF_Dim(tensor2, d)) {
return false;
}
}
return true;
}
std::string SingleTag(TF_Tensor* tags) {
if (TF_TensorElementCount(tags) == 1) {
const char* single_tag =
static_cast<tensorflow::tstring*>(TF_TensorData(tags))->c_str();
return tensorflow::strings::StrCat(" (tag '", single_tag, "')");
} else {
return "";
}
}
template <typename T>
void RegisterScalarSummaryOpKernel() {
TF_Status* status = TF_NewStatus();
{
auto* builder = TF_NewKernelBuilder(
"ScalarSummary", tensorflow::DEVICE_CPU, &ScalarSummaryOp_Create,
&ScalarSummaryOp_Compute<T>, &ScalarSummaryOp_Delete);
TF_KernelBuilder_TypeConstraint(
builder, "T",
static_cast<TF_DataType>(tensorflow::DataTypeToEnum<T>::v()), status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << "Error while adding type constraint";
TF_RegisterKernelBuilder("ScalarSummary", builder, status);
CHECK_EQ(TF_OK, TF_GetCode(status))
<< "Error while registering Scalar Summmary kernel";
}
TF_DeleteStatus(status);
}
// A dummy static variable initialized by a lambda whose side-effect is to
// register the ScalarSummary kernel.
TF_ATTRIBUTE_UNUSED bool IsScalarSummaryOpKernelRegistered = []() {
if (SHOULD_REGISTER_OP_KERNEL("ScalarSummary")) {
RegisterScalarSummaryOpKernel<tensorflow::int64>();
RegisterScalarSummaryOpKernel<tensorflow::uint64>();
RegisterScalarSummaryOpKernel<tensorflow::int32>();
RegisterScalarSummaryOpKernel<tensorflow::uint32>();
RegisterScalarSummaryOpKernel<tensorflow::uint16>();
RegisterScalarSummaryOpKernel<tensorflow::int16>();
RegisterScalarSummaryOpKernel<tensorflow::int8>();
RegisterScalarSummaryOpKernel<tensorflow::uint8>();
RegisterScalarSummaryOpKernel<Eigen::half>();
RegisterScalarSummaryOpKernel<tensorflow::bfloat16>();
RegisterScalarSummaryOpKernel<float>();
RegisterScalarSummaryOpKernel<double>();
}
return true;
}();
} // namespace

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <string>
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/node_builder.h"
@ -20,19 +22,20 @@ limitations under the License.
#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow { namespace tensorflow {
namespace {
static Graph* BM_ScalarSummaryOp(TensorShape shape, const char* tag, Graph* BM_ScalarSummaryOp(TensorShape shape, std::string tag,
float value) { float value) {
Graph* g = new Graph(OpRegistry::Global()); Graph* g = new Graph(OpRegistry::Global());
Tensor tags(DT_STRING, shape); Tensor tags(DT_STRING, shape);
Tensor values(DT_FLOAT, shape); Tensor values(DT_FLOAT, shape);
for (int i = 0; i < tags.NumElements(); ++i){ for (int i = 0; i < tags.NumElements(); ++i){
tags.flat<tstring>()(i) = tag; tags.flat<tstring>()(i) = tag;
values.flat<float>()(i) = value; values.flat<float>()(i) = value;
} }
Node* ret; Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("dummy"), "SummaryScalar") TF_CHECK_OK(NodeBuilder(g->NewName("dummy"), "ScalarSummary")
.Input(test::graph::Constant(g, tags)) .Input(test::graph::Constant(g, tags))
.Input(test::graph::Constant(g, values)) .Input(test::graph::Constant(g, values))
.Attr("T", DT_FLOAT) .Attr("T", DT_FLOAT)
@ -42,23 +45,27 @@ static Graph* BM_ScalarSummaryOp(TensorShape shape, const char* tag,
// Macro used to parse initializer list for tensorshape // Macro used to parse initializer list for tensorshape
#define DIMARGS(...) {__VA_ARGS__} #define DIMARGS(...) {__VA_ARGS__}
// Random parameters for testing // // Random parameters for testing
constexpr char longTagParam = "LONGTAG____________________________"; constexpr char longTagParam[] = "LONGTAG____________________________";
constexpr float largeValueParam = 2352352.2623433; constexpr float largeValueParam = 2352352.2623433;
#define BM_ScalarSummaryDev(device, dims, name, tag, value) \ #define BM_ScalarSummaryDev(device, dims, name, tag, value) \
static void BM_ScalarSummary_##name##_##device(int iters) { \ void BM_ScalarSummary##name##device(int iters) { \
TensorShape tensorshape(DIMARGS(dims)); \ testing::StopTiming(); \
test::Benchmark(#device, BM_ScalarSummaryOp( \ TensorShape tensorshape(DIMARGS dims); \
tensorshape, #tag, value)).Run(iters); \ auto g = BM_ScalarSummaryOp(tensorshape, #tag, value); \
} \ testing::StartTiming(); \
BENCHMARK(BM_ScalarSummary_##name##_##device); test::Benchmark("cpu", g).Run(iters); \
} \
BENCHMARK(BM_ScalarSummary##name##device);
BM_ScalarSummaryDev(cpu, (5, 10, 100), Base, tag, 5.2); BM_ScalarSummaryDev(Cpu, (5, 10, 100), Base, Tag, 5.2);
// Benchmark for large shapes // Benchmark for large shapes
BM_ScalarSummaryDev(cpu, (500, 1000, 10000), Large_Shape, tag, 5.2); BM_ScalarSummaryDev(Cpu, (500, 100, 100), LargeShape, Tag, 5.2);
// Benchmark for large tag tstring // Benchmark for large tag tstring
BM_ScalarSummaryDev(cpu, (5, 10, 100), Long_Tag, longTagParam, 5.2); BM_ScalarSummaryDev(Cpu, (5, 10, 100), LongTag, longTagParam, 5.2);
// Benchmark for large values // Benchmark for large values
BM_ScalarSummaryDev(cpu, (500, 1000, 10000), Large_Value, tag, largeValueParam); BM_ScalarSummaryDev(Cpu, (500, 100, 100), LargeValue, Tag, largeValueParam);
} // namespace tensorflow
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,186 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/kernels.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
namespace tensorflow {
namespace {
class DummyDevice : public DeviceBase {
public:
explicit DummyDevice(Env* env) : DeviceBase(env) {}
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
return cpu_allocator();
}
};
// Helper for comparing ouput and expected output
void ExpectSummaryMatches(const Summary& actual, const string& expected_str) {
Summary expected;
ASSERT_TRUE(protobuf::TextFormat::ParseFromString(expected_str, &expected));
EXPECT_EQ(expected.DebugString(), actual.DebugString());
}
void TestScalarSummaryOp(Tensor* tags, Tensor* values, string expected_output,
error::Code expected_code) {
// Initialize node used to fetch OpKernel
Status status;
NodeDef def;
def.set_op("ScalarSummary");
def.set_device(DEVICE_CPU);
AttrValue valuesTypeAttr;
SetAttrValue(values->dtype(), &valuesTypeAttr);
(*def.mutable_attr())["T"] = valuesTypeAttr;
def.add_input(strings::StrCat("input1: ", DataTypeString(tags->dtype())));
def.add_input(strings::StrCat("input2: ", DataTypeString(values->dtype())));
std::unique_ptr<OpKernel> kernel =
CreateOpKernel(DeviceType(DEVICE_CPU), nullptr, nullptr, def, 1, &status);
ASSERT_TRUE(status.ok()) << status.ToString();
OpKernelContext::Params params;
DummyDevice dummy_device(nullptr);
params.device = &dummy_device;
params.op_kernel = kernel.get();
AllocatorAttributes alloc_attrs;
params.output_attr_array = &alloc_attrs;
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.emplace_back(tags);
inputs.emplace_back(values);
params.inputs = &inputs;
OpKernelContext ctx(&params, 1);
kernel->Compute(&ctx);
ASSERT_EQ(expected_code, ctx.status().code());
if (expected_code == error::OK) {
Summary summary;
ASSERT_TRUE(ParseProtoUnlimited(
&summary, ctx.mutable_output(0)->scalar<tstring>()()));
ExpectSummaryMatches(summary, expected_output);
} else {
EXPECT_TRUE(absl::StrContains(ctx.status().ToString(), expected_output))
<< ctx.status();
}
}
TEST(ScalarSummaryOpTest, SimpleFloat) {
int vectorSize = 3;
Tensor tags(DT_STRING, {vectorSize});
Tensor values(DT_FLOAT, {vectorSize});
tags.vec<tstring>()(0) = "tag1";
tags.vec<tstring>()(1) = "tag2";
tags.vec<tstring>()(2) = "tag3";
values.vec<float>()(0) = 1.0f;
values.vec<float>()(1) = -0.73f;
values.vec<float>()(2) = 10000.0f;
TestScalarSummaryOp(&tags, &values, R"(
value { tag: 'tag1' simple_value: 1.0 }
value { tag: 'tag2' simple_value: -0.73}
value { tag: 'tag3' simple_value: 10000.0})",
error::OK);
}
TEST(ScalarSummaryOpTest, SimpleDouble) {
int vectorSize = 3;
Tensor tags(DT_STRING, {vectorSize});
Tensor values(DT_DOUBLE, {vectorSize});
tags.vec<tstring>()(0) = "tag1";
tags.vec<tstring>()(1) = "tag2";
tags.vec<tstring>()(2) = "tag3";
values.vec<double>()(0) = 1.0;
values.vec<double>()(1) = -0.73;
values.vec<double>()(2) = 10000.0;
TestScalarSummaryOp(&tags, &values, R"(
value { tag: 'tag1' simple_value: 1.0 }
value { tag: 'tag2' simple_value: -0.73}
value { tag: 'tag3' simple_value: 10000.0})",
error::OK);
}
TEST(ScalarSummaryOpTest, SimpleHalf) {
int vectorSize = 3;
Tensor tags(DT_STRING, {vectorSize});
Tensor values(DT_HALF, {vectorSize});
tags.vec<tstring>()(0) = "tag1";
tags.vec<tstring>()(1) = "tag2";
tags.vec<tstring>()(2) = "tag3";
values.vec<Eigen::half>()(0) = Eigen::half(1.0);
values.vec<Eigen::half>()(1) = Eigen::half(-2.0);
values.vec<Eigen::half>()(2) = Eigen::half(10000.0);
TestScalarSummaryOp(&tags, &values, R"(
value { tag: 'tag1' simple_value: 1.0 }
value { tag: 'tag2' simple_value: -2.0}
value { tag: 'tag3' simple_value: 10000.0})",
error::OK);
}
TEST(ScalarSummaryOpTest, Error_WrongDimsTags) {
Tensor tags(DT_STRING, {2, 1});
Tensor values(DT_FLOAT, {2});
tags.matrix<tstring>()(0, 0) = "tag1";
tags.matrix<tstring>()(1, 0) = "tag2";
values.vec<float>()(0) = 1.0f;
values.vec<float>()(1) = -2.0f;
TestScalarSummaryOp(&tags, &values, "tags and values are not the same shape",
error::INVALID_ARGUMENT);
}
TEST(ScalarSummaryOpTest, Error_WrongValuesTags) {
Tensor tags(DT_STRING, {2});
Tensor values(DT_FLOAT, {2, 1});
tags.vec<tstring>()(0) = "tag1";
tags.vec<tstring>()(1) = "tag2";
values.matrix<float>()(0, 0) = 1.0f;
values.matrix<float>()(1, 0) = -2.0f;
TestScalarSummaryOp(&tags, &values, "tags and values are not the same shape",
error::INVALID_ARGUMENT);
}
TEST(ScalarSummaryOpTest, Error_WrongWithSingleTag) {
Tensor tags(DT_STRING, {1});
Tensor values(DT_FLOAT, {2, 1});
tags.vec<tstring>()(0) = "tag1";
values.matrix<float>()(0, 0) = 1.0f;
values.matrix<float>()(1, 0) = -2.0f;
TestScalarSummaryOp(&tags, &values, "tags and values are not the same shape",
error::INVALID_ARGUMENT);
}
TEST(ScalarSummaryOpTest, IsRegistered) {
const OpRegistrationData* reg;
TF_CHECK_OK(OpRegistry::Global()->LookUp("ScalarSummary", &reg));
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/kernels/tensor_shape_utils.h"
#include <string>
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
namespace tensorflow {
std::string ShapeDebugString(TF_Tensor* tensor) {
// A TF_Tensor cannot have an unknown rank.
CHECK_GE(TF_NumDims(tensor), 0);
tensorflow::string s = "[";
for (int i = 0; i < TF_NumDims(tensor); ++i) {
if (i > 0) tensorflow::strings::StrAppend(&s, ",");
int64_t dim = TF_Dim(tensor, i);
// A TF_Tensor cannot have an unknown dimension.
CHECK_GE(dim, 0);
tensorflow::strings::StrAppend(&s, dim);
}
tensorflow::strings::StrAppend(&s, "]");
return s;
}
} // namespace tensorflow

View File

@ -13,25 +13,25 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ // This file contains shape utilities to be used by kernels and is not part of
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ // the C API. As such, it is subject to change at any time.
#include <vector> #ifndef TENSORFLOW_C_TENSOR_SHAPE_UTILS_H_
#define TENSORFLOW_C_TENSOR_SHAPE_UTILS_H_
#include "tensorflow/c/conversion_macros.h" #include <string>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
// Internal structures used by the SavedModel C API. These are likely to #include "tensorflow/c/tf_tensor.h"
// change and should not be depended on.
typedef struct TF_TensorHandleList TF_TensorHandleList;
namespace tensorflow { namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS( // The following are utils for the shape of a TF_Tensor type.
std::vector<tensorflow::ImmediateExecutionTensorHandle*>, // These functions may later be subsumed by the methods for a
TF_TensorHandleList) // TF_TensorShape type.
// Returns a string representation of the TF_Tensor shape.
std::string ShapeDebugString(TF_Tensor* tensor);
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ #endif // TENSORFLOW_C_TENSOR_SHAPE_UTILS_H_

View File

@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/kernels/tensor_shape_utils.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
// A wrapper that will automatically delete the allocated TF_Tensor
// once out of scope.
struct TF_TensorWrapper {
TF_Tensor* tf_tensor;
explicit TF_TensorWrapper(TF_Tensor* tensor) { tf_tensor = tensor; }
~TF_TensorWrapper() { TF_DeleteTensor(tf_tensor); }
};
void TestShapeMatch(TensorShape shape) {
Tensor tensor(DT_FLOAT, shape);
Status status;
TF_Tensor* tf_tensor = TF_TensorFromTensor(tensor, &status);
TF_TensorWrapper tensor_wrapper = TF_TensorWrapper(tf_tensor);
ASSERT_TRUE(status.ok()) << status.ToString();
ASSERT_EQ(tensor.shape().DebugString(), ShapeDebugString(tf_tensor));
}
TEST(ShapeDebugString, RegularShape) { TestShapeMatch(TensorShape({5, 4, 7})); }
TEST(ShapeDebugString, ScalarShape) { TestShapeMatch(TensorShape({})); }
} // namespace
} // namespace tensorflow

View File

@ -73,6 +73,12 @@ static void* MyCreateFunc(TF_OpKernelConstruction* ctx) {
EXPECT_EQ(TF_FLOAT, type); EXPECT_EQ(TF_FLOAT, type);
TF_DeleteStatus(status); TF_DeleteStatus(status);
// Exercise kernel NodeDef name read
TF_StringView name_string_view = TF_OpKernelConstruction_GetName(ctx);
std::string node_name = "SomeNodeName";
std::string candidate_node_name =
std::string(name_string_view.data, name_string_view.len);
EXPECT_EQ(node_name, candidate_node_name);
return s; return s;
} }
@ -96,9 +102,11 @@ namespace tensorflow {
static std::unique_ptr<OpKernel> GetFakeKernel(const char* device_name, static std::unique_ptr<OpKernel> GetFakeKernel(const char* device_name,
const char* op_name, const char* op_name,
const char* node_name,
Status* status) { Status* status) {
NodeDef def; NodeDef def;
def.set_op(op_name); def.set_op(op_name);
def.set_name(node_name);
def.set_device(device_name); def.set_device(device_name);
def.add_input("input1"); def.add_input("input1");
def.add_input("input2"); def.add_input("input2");
@ -114,7 +122,7 @@ static std::unique_ptr<OpKernel> GetFakeKernel(const char* device_name,
// Tests registration of a single C kernel and checks that calls through the // Tests registration of a single C kernel and checks that calls through the
// C/C++ boundary are being made. // C/C++ boundary are being made.
TEST(TestKernel, TestRegisterKernelBuilder) { TEST(TestKernel, TestRegisterKernelBuilder) {
const char* kernel_name = "SomeKernelName"; const char* node_name = "SomeNodeName";
const char* op_name = "FooOp"; const char* op_name = "FooOp";
const char* device_name = "FakeDeviceName1"; const char* device_name = "FakeDeviceName1";
@ -129,7 +137,7 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
{ {
TF_Status* status = TF_NewStatus(); TF_Status* status = TF_NewStatus();
TF_RegisterKernelBuilder(kernel_name, builder, status); TF_RegisterKernelBuilder(node_name, builder, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)); EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status); TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)); EXPECT_EQ(TF_OK, TF_GetCode(status));
@ -144,7 +152,7 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
{ {
Status status; Status status;
std::unique_ptr<OpKernel> kernel = std::unique_ptr<OpKernel> kernel =
GetFakeKernel(device_name, op_name, &status); GetFakeKernel(device_name, op_name, node_name, &status);
TF_EXPECT_OK(status); TF_EXPECT_OK(status);
ASSERT_NE(nullptr, kernel.get()); ASSERT_NE(nullptr, kernel.get());
kernel->Compute(nullptr); kernel->Compute(nullptr);
@ -162,7 +170,7 @@ class DummyDevice : public DeviceBase {
}; };
TEST(TestKernel, TestInputAndOutputCount) { TEST(TestKernel, TestInputAndOutputCount) {
const char* kernel_name = "InputOutputCounterKernel"; const char* node_name = "InputOutputCounterKernel";
const char* op_name = "BarOp"; const char* op_name = "BarOp";
const char* device_name = "FakeDeviceName2"; const char* device_name = "FakeDeviceName2";
@ -212,7 +220,7 @@ TEST(TestKernel, TestInputAndOutputCount) {
{ {
TF_Status* status = TF_NewStatus(); TF_Status* status = TF_NewStatus();
TF_RegisterKernelBuilder(kernel_name, builder, status); TF_RegisterKernelBuilder(node_name, builder, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)); EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_DeleteStatus(status); TF_DeleteStatus(status);
} }
@ -233,7 +241,7 @@ TEST(TestKernel, TestInputAndOutputCount) {
Status status; Status status;
std::unique_ptr<OpKernel> kernel = std::unique_ptr<OpKernel> kernel =
GetFakeKernel(device_name, op_name, &status); GetFakeKernel(device_name, op_name, node_name, &status);
TF_EXPECT_OK(status); TF_EXPECT_OK(status);
ASSERT_NE(nullptr, kernel.get()); ASSERT_NE(nullptr, kernel.get());
@ -252,7 +260,7 @@ TEST(TestKernel, DeleteKernelBuilderIsOkOnNull) {
} }
TEST(TestKernel, TestTypeConstraint) { TEST(TestKernel, TestTypeConstraint) {
const char* kernel_name = "SomeKernelName"; const char* node_name = "SomeNodeName";
const char* op_name = "TypeOp"; const char* op_name = "TypeOp";
const char* device_name = "FakeDeviceName1"; const char* device_name = "FakeDeviceName1";
@ -267,7 +275,7 @@ TEST(TestKernel, TestTypeConstraint) {
TF_Status* status = TF_NewStatus(); TF_Status* status = TF_NewStatus();
TF_KernelBuilder_TypeConstraint(builder, "T", TF_DataType::TF_INT32, status); TF_KernelBuilder_TypeConstraint(builder, "T", TF_DataType::TF_INT32, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)); EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_RegisterKernelBuilder(kernel_name, builder, status); TF_RegisterKernelBuilder(node_name, builder, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)); EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status); TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
@ -296,7 +304,7 @@ TEST(TestKernel, TestTypeConstraint) {
} }
TEST(TestKernel, TestHostMemory) { TEST(TestKernel, TestHostMemory) {
const char* kernel_name = "SomeKernelName"; const char* node_name = "SomeNodeName";
const char* op_name = "HostMemoryOp"; const char* op_name = "HostMemoryOp";
const char* device_name = "FakeDeviceName1"; const char* device_name = "FakeDeviceName1";
@ -311,7 +319,7 @@ TEST(TestKernel, TestHostMemory) {
TF_KernelBuilder_HostMemory(builder, "input2"); TF_KernelBuilder_HostMemory(builder, "input2");
TF_KernelBuilder_HostMemory(builder, "output1"); TF_KernelBuilder_HostMemory(builder, "output1");
TF_Status* status = TF_NewStatus(); TF_Status* status = TF_NewStatus();
TF_RegisterKernelBuilder(kernel_name, builder, status); TF_RegisterKernelBuilder(node_name, builder, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)); EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status); TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
@ -335,12 +343,12 @@ TEST(TestKernel, TestHostMemory) {
class DeviceKernelOpTest : public OpsTestBase { class DeviceKernelOpTest : public OpsTestBase {
protected: protected:
void SetupOp(const char* op_name, const char* kernel_name, void SetupOp(const char* op_name, const char* node_name,
void (*compute_func)(void*, TF_OpKernelContext*)) { void (*compute_func)(void*, TF_OpKernelContext*)) {
TF_KernelBuilder* builder = TF_NewKernelBuilder( TF_KernelBuilder* builder = TF_NewKernelBuilder(
op_name, device_name_, nullptr, compute_func, nullptr); op_name, device_name_, nullptr, compute_func, nullptr);
TF_Status* status = TF_NewStatus(); TF_Status* status = TF_NewStatus();
TF_RegisterKernelBuilder(kernel_name, builder, status); TF_RegisterKernelBuilder(node_name, builder, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)); EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_DeleteStatus(status); TF_DeleteStatus(status);

59
tensorflow/c/logging.cc Normal file
View File

@ -0,0 +1,59 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/logging.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stringprintf.h"
static ::tensorflow::string BuildMessage(const char* fmt, va_list args) {
::tensorflow::string message;
::tensorflow::strings::Appendv(&message, fmt, args);
return message;
}
void TF_Log(TF_LogLevel level, const char* fmt, ...) {
if (level < TF_INFO || level > TF_FATAL) return;
va_list args;
va_start(args, fmt);
auto message = BuildMessage(fmt, args);
switch (level) {
case TF_INFO:
LOG(INFO) << message;
break;
case TF_WARNING:
LOG(WARNING) << message;
break;
case TF_ERROR:
LOG(ERROR) << message;
break;
case TF_FATAL:
LOG(FATAL) << message;
break;
}
}
void TF_VLog(int level, const char* fmt, ...) {
va_list args;
va_start(args, fmt);
auto message = BuildMessage(fmt, args);
VLOG(level) << message;
}
void TF_DVLog(int level, const char* fmt, ...) {
va_list args;
va_start(args, fmt);
auto message = BuildMessage(fmt, args);
DVLOG(level) << message;
}

View File

@ -12,25 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_C_LOGGING_H_
#define TENSORFLOW_C_LOGGING_H_
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h" #include "tensorflow/c/c_api_macros.h"
#include <stddef.h> // --------------------------------------------------------------------------
// C API for tensorflow::Logging.
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
#ifdef __cplusplus
extern "C" { extern "C" {
#endif
size_t TF_TensorHandleListSize(const TF_TensorHandleList* list) { typedef enum TF_LogLevel {
return tensorflow::unwrap(list)->size(); TF_INFO = 0,
TF_WARNING = 1,
TF_ERROR = 2,
TF_FATAL = 3,
} TF_LogLevel;
TF_CAPI_EXPORT extern void TF_Log(TF_LogLevel level, const char* fmt, ...);
TF_CAPI_EXPORT extern void TF_VLog(int level, const char* fmt, ...);
TF_CAPI_EXPORT extern void TF_DVLog(int level, const char* fmt, ...);
#ifdef __cplusplus
} }
#endif
TFE_TensorHandle* TF_TensorHandleListGet(const TF_TensorHandleList* list, #endif // TENSORFLOW_C_LOGGING_H_
int i) {
return tensorflow::wrap((*tensorflow::unwrap(list))[i]);
}
} // end extern "C"

View File

@ -104,6 +104,12 @@ TF_ShapeHandle* TF_NewShapeHandle() {
return reinterpret_cast<TF_ShapeHandle*>(new ShapeHandle); return reinterpret_cast<TF_ShapeHandle*>(new ShapeHandle);
} }
TF_ShapeHandle* TF_ShapeInferenceContextScalar(TF_ShapeInferenceContext* ctx) {
auto* handle = new ShapeHandle;
*handle = reinterpret_cast<InferenceContext*>(ctx)->Scalar();
return reinterpret_cast<TF_ShapeHandle*>(handle);
}
TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize( TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
TF_ShapeInferenceContext* ctx, size_t size) { TF_ShapeInferenceContext* ctx, size_t size) {
auto* handle = new ShapeHandle; auto* handle = new ShapeHandle;

View File

@ -280,6 +280,11 @@ extern void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx,
int i, TF_ShapeHandle* handle, int i, TF_ShapeHandle* handle,
TF_Status* status); TF_Status* status);
// Returns a newly-allocated scalar shape handle. The returned handle should
// be freed with TF_DeleteShapeHandle.
TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextScalar(
TF_ShapeInferenceContext* ctx);
// Returns a newly-allocate shape handle representing a vector of the given // Returns a newly-allocate shape handle representing a vector of the given
// size. The returned handle should be freed with TF_DeleteShapeHandle. // size. The returned handle should be freed with TF_DeleteShapeHandle.
TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize( TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(

Some files were not shown because too many files have changed in this diff Show More