Merge remote-tracking branch 'upstream/master' into toupstream/16x8_batch_matmul

This commit is contained in:
Thibaut Goetghebuer-Planchon 2020-09-18 13:04:42 +01:00
commit f2b4f92407
4000 changed files with 171600 additions and 73168 deletions

View File

@ -5,6 +5,7 @@
# Android options:
# android:
# android_arm:
# android_arm64:
# android_x86:
# android_x86_64:
#
@ -18,8 +19,10 @@
#
# Compiler options:
# cuda_clang: Use clang when building CUDA code.
# c++17: Build with C++17 options
# c++1z: Build with C++17 options
# c++17: Build with C++17 options (links with libc++)
# c++1z: Build with C++17 options (links with libc++)
# c++17_gcc: Build with C++17 options (links with stdlibc++)
# c++1z_gcc: Build with C++17 options (links with stdlibc++)
# avx_linux: Build with avx instruction set on linux.
# avx2_linux: Build with avx2 instruction set on linux.
# native_arch_linux: Build with instruction sets available to the host machine on linux
@ -44,10 +47,6 @@
# using_cuda: CUDA is available to build system.
# cuda: Build with full cuda support.
# rocm: Build with AMD GPU support (rocm).
# sycl: Build with SYCL support.
# sycl_nodouble:
# sycl_asan:
# sycl_trisycl:
# mkl: Enable full mkl support.
# tensorrt: Enable Tensorrt support.
# ngraph: Enable ngraph support.
@ -87,6 +86,7 @@
# release_cpu_linux: Toolchain and CUDA options for Linux CPU builds.
# release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds.
# release_gpu_linux: Toolchain and CUDA options for Linux GPU builds.
# release_gpu_linux_cuda_10_1: Toolchain and CUDA options for CUDA 10.1 Linux GPU builds.
# release_cpu_windows: Toolchain and CUDA options for Windows CPU builds.
# release_gpu_windows: Toolchain and CUDA options for Windows GPU builds.
@ -159,13 +159,11 @@ build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain
# environment variable "TF_MKL_ROOT" every time before build.
build:mkl --define=build_with_mkl=true --define=enable_mkl=true
build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl --define=build_with_mkl_dnn_v1_only=true
build:mkl -c opt
# config to build OneDNN backend with a user specified threadpool.
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_threadpool --define=build_with_mkl_dnn_v1_only=true
build:mkl_threadpool --define=build_with_mkl_opensource=true
build:mkl_threadpool --define=build_with_mkldnn_threadpool=true
build:mkl_threadpool -c opt
@ -173,7 +171,6 @@ build:mkl_threadpool -c opt
# Config setting to build with oneDNN and without the binary blob
build:mkl_opensource_only --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_opensource_only --define=build_with_mkl_dnn_v1_only=true
build:mkl_opensource_only --define=build_with_mkl_opensource=true
build:mkl_opensource_only -c opt
@ -214,19 +211,6 @@ build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --action_env TF_NEED_ROCM=1
build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain
build:sycl --define=using_sycl=true
build:sycl --action_env TF_NEED_OPENCL_SYCL=1
build:sycl_nodouble --config=sycl
build:sycl_nodouble --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE
build:sycl_nodouble --config=sycl
build:sycl_asan --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address
build:sycl_nodouble --config=sycl
build:sycl_trisycl --define=using_trisycl=true
# Options extracted from configure script
build:ngraph --define=with_ngraph_support=true
build:numa --define=with_numa_support=true
@ -278,6 +262,8 @@ build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
build:c++17 --cxxopt=-std=c++1z
build:c++17 --cxxopt=-stdlib=libc++
build:c++1z --config=c++17
build:c++17_gcc --cxxopt=-std=c++1z
build:c++1z_gcc --config=c++17_gcc
# Enable using platform specific build settings, except when cross-compiling for
# mobile platforms.
@ -289,6 +275,7 @@ build:ios --noenable_platform_specific_config
build:android --copt=-w
build:ios --copt=-w
build:linux --copt=-w
build:linux --host_copt=-w
build:macos --copt=-w
build:windows --copt=/w
@ -330,6 +317,11 @@ build:windows --host_copt=-DWIN32_LEAN_AND_MEAN
build:windows --copt=-DNOGDI
build:windows --host_copt=-DNOGDI
# MSVC (Windows): Standards-conformant preprocessor mode
# See https://docs.microsoft.com/en-us/cpp/preprocessor/preprocessor-experimental-overview
build:windows --copt=/experimental:preprocessor
build:windows --host_copt=/experimental:preprocessor
# Misc build options we need for windows.
build:windows --linkopt=/DEBUG
build:windows --host_linkopt=/DEBUG
@ -354,6 +346,7 @@ build --config=short_logs
# TODO(gunan): Create a feature in toolchains for avx/avx2 to
# avoid having to define linux/win separately.
build:avx_linux --copt=-mavx
build:avx_linux --host_copt=-mavx
build:avx2_linux --copt=-mavx2
build:native_arch_linux --copt=-march=native
build:avx_win --copt=/arch=AVX
@ -368,7 +361,6 @@ build --config=v2
test --config=v2
# Enable XLA
build:xla --action_env=TF_ENABLE_XLA=1
build:xla --define=with_xla_support=true
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS
@ -408,9 +400,12 @@ build:rbe_linux --config=avx_linux
build:rbe_linux --config=short_logs
# TODO(gunan): Check why we need this specified in rbe, but not in other builds.
build:rbe_linux --linkopt=-lrt
build:rbe_linux --host_linkopt=-lrt
build:rbe_linux --linkopt=-lm
build:rbe_linux --host_linkopt=-lm
build:rbe_cpu_linux --config=rbe_linux
build:rbe_cpu_linux --host_crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
build:rbe_cpu_linux --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:cc-toolchain-k8"
build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
@ -428,6 +423,7 @@ test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/
build:rbe_linux_cuda10.1_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda10.1_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda10.1_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda10.1_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda10.1_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda10.1_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
@ -444,6 +440,7 @@ build:rbe_linux_cuda10.1_nvcc_py3.8 --config=rbe_linux_cuda10.1_nvcc_base --repo
build:rbe_linux_cuda11.0_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda11.0_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda11.0_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.0_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.0_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda11.0_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
@ -458,12 +455,12 @@ build:rbe_linux_cuda11.0_nvcc_py3.6 --config=rbe_linux_cuda11.0_nvcc_base --repo
build:rbe_linux_cuda11.0_nvcc_py3.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.7"
build:rbe_linux_cuda11.0_nvcc_py3.8 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.8"
# Map default to CUDA 10.1.
# Map default to CUDA 11 for PY35 and greater.
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda10.1_nvcc_py2.7
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda10.1_nvcc_py3.5
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda10.1_nvcc_py3.6
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda10.1_nvcc_py3.7
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda10.1_nvcc_py3.8
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda11.0_nvcc_py3.5
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda11.0_nvcc_py3.6
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda11.0_nvcc_py3.7
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda11.0_nvcc_py3.8
# Deprecated configs that people might still use.
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_nvcc_py36
@ -580,11 +577,11 @@ build:release_cpu_macos --config=avx_linux
build:release_gpu_common --config=release_common
build:release_gpu_common --config=cuda
build:release_gpu_common --config=tensorrt
build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1"
build:release_gpu_common --action_env=TF_CUDA_VERSION="10"
build:release_gpu_common --action_env=TF_CUDNN_VERSION="7"
build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-11.0"
build:release_gpu_common --action_env=TF_CUDA_VERSION="11"
build:release_gpu_common --action_env=TF_CUDNN_VERSION="8"
build:release_gpu_common --action_env=TF_NEED_TENSORRT="1"
build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_37,sm_52,sm_60,sm_61,compute_70"
build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80"
build:release_gpu_common --action_env=TENSORRT_INSTALL_PATH="/usr/local/tensorrt"
build:release_gpu_common --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5"
@ -592,8 +589,7 @@ build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5"
build:release_gpu_linux --config=release_gpu_common
build:release_gpu_linux --config=avx_linux
build:release_gpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain
build:release_gpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain
build:release_windows_common --config=release_common
build:release_windows_common --define=no_tensorflow_py_deps=true
build:release_windows_common --announce_rc
@ -601,3 +597,8 @@ build:release_windows_common --announce_rc
build:release_cpu_windows --config=release_windows_common
build:release_gpu_windows --config=release_windows_common
build:release_gpu_linux_cuda_10_1 --config=release_gpu_linux
build:release_gpu_linux_cuda_10_1 --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1"
build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDA_VERSION="10"
build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDNN_VERSION="7"

View File

@ -40,6 +40,22 @@ segfault_memory:
# assignees
filesystem_security_assignee:
- mihaimaruseac
tflite_micro_path:
- tensorflow/lite/micro
tflite_micro_comment: >
Thanks for contributing to TensorFlow Lite Micro.
To keep this process moving along, we'd like to make sure that you have completed the items on this list:
* Read the [contributing guidelines for TensorFlow Lite Micro](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/CONTRIBUTING.md)
* Created a [TF Lite Micro Github issue](https://github.com/tensorflow/tensorflow/issues/new?labels=comp%3Amicro&template=70-tflite-micro-issue.md)
* Linked to the issue from the PR description
We would like to have a discussion on the Github issue first to determine the best path forward, and then proceed to the PR review.
# Cuda Comment
cuda_comment: >
From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries:

View File

@ -1,10 +0,0 @@
# TensorFlow Adopters
This page contains a list of people and organizations who are using TensorFlow. If you'd like to be included
here, please send a pull request which modifies this file.
We intend to use this list to contact you for surveys, and to find good candidates for invite-only events.
We will also point to this list if we are asked who uses TensorFlow.
We will not use any of the information here for promotions or to send other regular communications. You
should subscribe to discuss@tensorflow.org for such announcements.

View File

@ -1,16 +1,15 @@
# Where component owners are known, add them here.
/tensorflow/c/eager @jaingaurav @alextp
/tensorflow/core/common_runtime/eager @jaingaurav @alextp
/tensorflow/c/eager @qqfish @kkimdev
/tensorflow/core/common_runtime/eager @qqfish @kkimdev
/tenosrflow/core/debug @caisq
/tensorflow/core/nccl/ @azaks2 @chsigg
/tensorflow/core/platform/windows/ @mrry
/tensorflow/core/platform/windows/ @gunan @mihaimaruseac
/tensorflow/lite/experimental/micro @petewarden @advaitjain
/tensorflow/python/autograph/ @mdanatg @kkimdev
/tensorflow/python/debug @caisq
/tensorflow/python/eager @jaingaurav @alextp
/tensorflow/python/eager @rohan100jain @kkimdev
/tensorflow/python/tools/api/generator/ @annarev
/tensorflow/tensorboard/ @jart
/tensorflow/tools/docs/ @markdaoust
/third_party/systemlibs/ @perfinion

View File

@ -157,7 +157,7 @@ Build Type
* [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
* [TensorFlow Twitter](https://twitter.com/tensorflow)
* [TensorFlow YouTube](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
* [TensorFlow Roadmap](https://www.tensorflow.org/model_optimization/guide/roadmap)
* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
* [TensorBoard Visualization Toolkit](https://github.com/tensorflow/tensorboard)

View File

@ -22,6 +22,7 @@
* Code that uses `tf.map_fn`/`tf.cond`/`tf.while_loop`/control flow as op layers and happens to work before TF 2.4. These will explicitly be unsupported now. Converting these ops to Functional API op layers was unreliable before TF 2.4, and prone to erroring incomprehensibly or being silently buggy.
* Code that directly asserts on a Keras symbolic value in cases where ops like `tf.rank` used to return a static or symbolic value depending on if the input had a fully static shape or not. Now these ops always return symbolic values.
* Code already susceptible to leaking tensors outside of graphs becomes slightly more likely to do so now.
* Code that tries directly getting gradients with respect to symbolic Keras inputs/outputs. Use GradientTape on the actual Tensors passed to the already-constructed model instead.
* Code that requires very tricky shape manipulation via converted op layers in order to work, where the Keras symbolic shape inference proves insufficient.
* Code that tries manually walking a `tf.keras.Model` layer by layer and assumes layers only ever have one positional argument. This assumption doesn't hold true before TF 2.4 either, but is more likely to cause issues know.
* Code that manually enters `keras.backend.get_graph()` before building a functional model. This is no longer needed.
@ -33,6 +34,18 @@
shape assumptions (note that you can pass shapes with `None` entries for axes
that are meant to be dynamic). You can also disable the input checking
entirely by setting `model.input_spec = None`.
* XLA:CPU and XLA:GPU devices are no longer registered by default. Use
`TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be
removed).
* `tf.raw_ops.Max` and `tf.raw_ops.Min` no longer accept inputs of type
`tf.complex64` or `tf.complex128`, because the behavior of these ops is not
well defined for complex types.
* `tf.data.experimental.service.DispatchServer` now takes a config tuple
instead of individual arguments. Usages should be updated to
`tf.data.experimental.service.DispatchServer(dispatcher_config)`.
* `tf.data.experimental.service.WorkerServer` now takes a config tuple
instead of individual arguments. Usages should be updated to
`tf.data.experimental.service.WorkerServer(worker_config)`.
## Known Caveats
@ -67,11 +80,24 @@
the same sparsity pattern, but with new provided values. It is similar to
the `with_values` function of `RaggedTensor`.
* Added `StatelessCase` op, and uses it if none of case branches has stateful ops.
* Added `tf.config.experimental.get_memory_usage` to return total memory usage
of the device.
* `tf.data`:
* Added new `tf.data.experimental.service.register_dataset` and
`tf.data.experimental.service.from_dataset_id` APIs to enable one process
to register a dataset with the tf.data service, and another process to
consume data from the dataset.
* Added support for tf.data service dispatcher fault tolerance. To enable
fault tolerance, configure a `work_dir` when running your dispatcher
server and set `dispatcher_fault_tolerance=True`. The dispatcher will
store its state to `work_dir`, so that on restart it can continue from its
previous state after restart.
* Added tf.data service support for sharing dataset graphs via shared
filesystem instead of over RPC. This reduces load on the dispatcher,
improving performance of distributing datasets. For this to work, the
dispatcher's `work_dir` must be accessible from workers. If the worker
fails to read from the `work_dir`, it falls back to using RPC for dataset
graph transfer.
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
the complement of `select_cols`; at most one of these should be specified.
* We have implemented an optimization which reorders data-discarding
@ -79,11 +105,19 @@
dataset when it is safe to do so. The optimization can be disabled via
the `experimental_optimization.reorder_data_discarding_ops` dataset
option.
* `tf.data.Options` were previously immutable and can now be overriden.
* `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors
with a new `output_signature` argument, which allows `from_generator` to
produce any type describable by a `tf.TypeSpec`.
* `tf.data.experimental.AUTOTUNE` is now available in the core API as
`tf.data.AUTOTUNE`.
* `tf.image`:
* Added deterministic `tf.image.stateless_random_*` functions for each
`tf.image.random_*` function. Given the same seed, the stateless functions
produce the same results independent of how many times the function is
called, and independent of global seed settings.
`tf.image.random_*` function. Added a new op
`stateless_sample_distorted_bounding_box` which is a determinstic
version of `sample_distorted_bounding_box` op. Given the same seed, these
stateless functions/ops produce the same results independent of how many
times the function is called, and independent of global seed settings.
* `tf.distribute`:
* <ADD RELEASE NOTES HERE>
* `tf.keras`:
@ -95,16 +129,49 @@
* Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
as an alternative to accepting a `callable` loss.
* Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
to match FTRL paper (https://research.google.com/pubs/archive/41159.pdf).
* Added `mobilenet_v3` to keras application model.
* `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
customization of how gradients are aggregated across devices, as well as
`gradients_transformers` to allow for custom gradient transformations
(such as gradient clipping).
* The `steps_per_execution` argument in `compile()` is no longer
experimental; if you were passing `experimental_steps_per_execution`,
rename it to `steps_per_execution` in your code. This argument controls
the number of batches to run during each `tf.function` call when calling
`fit()`. Running multiple batches inside a single `tf.function` call can
greatly improve performance on TPUs or small models with a large Python
overhead.
* `tf.function` / AutoGraph:
* Added `experimental_follow_type_hints` argument for `tf.function`. When
True, the function may use type annotations to optimize the tracing
performance.
* Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
* AutoGraph now allows creating new symbols inside a TensorFLow loop, if
the values of these symbols at an iteration does not depend on the previous
iteration. These types of loops must run at least one iteration, and will
raise a runtime error otherwise.
Example:
```
for batch in data:
outputs = train_step(batch)
tf.print('final outputs', outputs)
```
See tensorflow/python/autograph/g3doc/reference/limitations.md for more
info.
* `tf.lite`:
* `DynamicBuffer::AddJoinedString()` will now add a separator if the first
string to be joined is empty.
* `TFLiteConverter`:
* Support optional flags `inference_input_type` and `inference_output_type` for full integer quantized models. This allows users to modify the model input and output type to integer types (`tf.int8`, `tf.uint8`) instead of defaulting to float type (`tf.float32`).
* Deprecate `Interpreter::UseNNAPI(bool)` C++ API
* Prefer using `NnApiDelegate()` and related delegate configuration methods directly.
* Add NNAPI Delegation support for requantization use cases by converting the operation into a dequantize-quantize pair.
* TFLite Profiler for Android is available. See the detailed
[guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android).
* <ADD RELEASE NOTES HERE>
* `tf.random`:
* <ADD RELEASE NOTES HERE>
@ -116,14 +183,28 @@
behavior by adjusting the `l2` parameter.
* <ADD RELEASE NOTES HERE>
* XLA Support:
* xla.experimental.compile is deprecated, use
`tf.function(experimental_compile=True)` instead
* Added `tf.function.experimental_get_compiler_ir` which returns compiler IR
(currently 'hlo' and 'optimized_hlo') for given input for given function.
* <ADD RELEASE NOTES HERE>
* Tracing and Debugging:
* <ADD RELEASE NOTES HERE>
* `tf.train.Checkpoint`:
* Now accepts a `root` argument in the initialization, which generates a
checkpoint with a root object. This allows users to create a `Checkpoint`
object that is compatible with Keras `model.save_weights()` and
`model.load_weights`. The checkpoint is also compatible with the
checkpoint saved in the `variables/` folder in the SavedModel.
* When restoring, `save_path` can be a path to a SavedModel. The function
will automatically find the checkpoint in the SavedModel.
* `tf.nn`:
* `tf.nn.max_pool2d` now supports explicit padding.
* Other:
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
and "denylist" where possible. Please see
https://developers.google.com/style/word-list#blacklist for more context.
* <ADD RELEASE NOTES HERE>
<ADD RELEASE NOTES HERE>
## Thanks to our Contributors
@ -215,6 +296,7 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
* Mutable tables now restore checkpointed values when loaded from SavedModel.
* GPU
* TF 2.3 includes PTX kernels only for [compute capability](https://developer.nvidia.com/cuda-gpus) 7.0 to reduce the TF pip binary size. Earlier releases included PTX for a variety of older compute capabilities.
* Remove environmental variable `TF_USE_CUDNN`.
* Others
* Retain parent namescope for ops added inside `tf.while_loop`/`tf.cond`/`tf.switch_case`.
* Update `tf.vectorized_map` to support vectorizing `tf.while_loop` and TensorList operations.
@ -1546,6 +1628,7 @@ Yuan (Terry) Tang, Yuchen Ying, Yves-Noel Weweler, zhangyujing, zjjott, zyeric,
color palette of the frame. This has been fixed now
* image.resize now considers proper pixel centers and has new kernels
(incl. anti-aliasing).
* Added an isotonic regression solver (tf.nn.isotonic_regression).
* Performance
* Turn on MKL-DNN contraction kernels by default. MKL-DNN dynamically
dispatches the best kernel implementation based on CPU vector

View File

@ -16,5 +16,5 @@
set configure_dir=%~dp0
set configure_dir=%configure_dir:~0,-1%
python %configure_dir%\configure.py %* || ( exit /b )
python "%configure_dir%\configure.py" %* || ( exit /b )
echo Configuration finished

View File

@ -38,9 +38,6 @@ _DEFAULT_CUDNN_VERSION = '7'
_DEFAULT_TENSORRT_VERSION = '6'
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0'
_TF_OPENCL_VERSION = '1.2'
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16, 17, 18]
_DEFAULT_PROMPT_ASK_ATTEMPTS = 10
@ -1114,62 +1111,6 @@ def set_host_c_compiler(environ_cp):
write_action_env_to_bazelrc('HOST_C_COMPILER', host_c_compiler)
def set_computecpp_toolkit_path(environ_cp):
"""Set COMPUTECPP_TOOLKIT_PATH."""
def toolkit_exists(toolkit_path):
"""Check if a computecpp toolkit path is valid."""
if is_linux():
sycl_rt_lib_path = 'lib/libComputeCpp.so'
else:
sycl_rt_lib_path = ''
sycl_rt_lib_path_full = os.path.join(toolkit_path, sycl_rt_lib_path)
exists = os.path.exists(sycl_rt_lib_path_full)
if not exists:
print('Invalid SYCL %s library path. %s cannot be found' %
(_TF_OPENCL_VERSION, sycl_rt_lib_path_full))
return exists
computecpp_toolkit_path = prompt_loop_or_load_from_env(
environ_cp,
var_name='COMPUTECPP_TOOLKIT_PATH',
var_default=_DEFAULT_COMPUTECPP_TOOLKIT_PATH,
ask_for_var=(
'Please specify the location where ComputeCpp for SYCL %s is '
'installed.' % _TF_OPENCL_VERSION),
check_success=toolkit_exists,
error_msg='Invalid SYCL compiler path. %s cannot be found.',
suppress_default_error=True)
write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH',
computecpp_toolkit_path)
def set_trisycl_include_dir(environ_cp):
"""Set TRISYCL_INCLUDE_DIR."""
ask_trisycl_include_dir = ('Please specify the location of the triSYCL '
'include directory. (Use --config=sycl_trisycl '
'when building with Bazel) '
'[Default is %s]: ') % (
_DEFAULT_TRISYCL_INCLUDE_DIR)
while True:
trisycl_include_dir = get_from_env_or_user_or_default(
environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir,
_DEFAULT_TRISYCL_INCLUDE_DIR)
if os.path.exists(trisycl_include_dir):
break
print('Invalid triSYCL include directory, %s cannot be found' %
(trisycl_include_dir))
# Set TRISYCL_INCLUDE_DIR
environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir
write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)
def system_specific_test_config(environ_cp):
"""Add default build and test flags required for TF tests to bazelrc."""
write_to_bazelrc('test --flaky_test_attempts=3')
@ -1397,8 +1338,6 @@ def main():
setup_python(environ_cp)
if is_windows():
environ_cp['TF_NEED_OPENCL_SYCL'] = '0'
environ_cp['TF_NEED_COMPUTECPP'] = '0'
environ_cp['TF_NEED_OPENCL'] = '0'
environ_cp['TF_CUDA_CLANG'] = '0'
environ_cp['TF_NEED_TENSORRT'] = '0'
@ -1415,21 +1354,6 @@ def main():
if environ_cp.get('TF_ENABLE_XLA', '1') == '1':
write_to_bazelrc('build --config=xla')
set_action_env_var(
environ_cp,
'TF_NEED_OPENCL_SYCL',
'OpenCL SYCL',
False,
bazel_config_name='sycl')
if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1':
set_host_cxx_compiler(environ_cp)
set_host_c_compiler(environ_cp)
set_action_env_var(environ_cp, 'TF_NEED_COMPUTECPP', 'ComputeCPP', True)
if environ_cp.get('TF_NEED_COMPUTECPP') == '1':
set_computecpp_toolkit_path(environ_cp)
else:
set_trisycl_include_dir(environ_cp)
set_action_env_var(
environ_cp, 'TF_NEED_ROCM', 'ROCm', False, bazel_config_name='rocm')
if (environ_cp.get('TF_NEED_ROCM') == '1' and
@ -1442,6 +1366,11 @@ def main():
write_action_env_to_bazelrc('ROCM_PATH', environ_cp.get('ROCM_PATH'))
write_action_env_to_bazelrc('ROCM_ROOT', environ_cp.get('ROCM_PATH'))
if ((environ_cp.get('TF_NEED_ROCM') == '1') and
(environ_cp.get('TF_ENABLE_MLIR_GENERATED_GPU_KERNELS') == '1')):
write_to_bazelrc(
'build:rocm --define tensorflow_enable_mlir_generated_gpu_kernels=1')
environ_cp['TF_NEED_CUDA'] = str(
int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)))
if (environ_cp.get('TF_NEED_CUDA') == '1' and
@ -1523,17 +1452,15 @@ def main():
# use it for the CPU build.
set_tf_download_clang(environ_cp)
# SYCL / ROCm / CUDA are mutually exclusive.
# ROCm / CUDA are mutually exclusive.
# At most 1 GPU platform can be configured.
gpu_platform_count = 0
if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1':
gpu_platform_count += 1
if environ_cp.get('TF_NEED_ROCM') == '1':
gpu_platform_count += 1
if environ_cp.get('TF_NEED_CUDA') == '1':
gpu_platform_count += 1
if gpu_platform_count >= 2:
raise UserInputError('SYCL / CUDA / ROCm are mututally exclusive. '
raise UserInputError('CUDA / ROCm are mututally exclusive. '
'At most 1 GPU platform can be configured.')
set_cc_opt_flags(environ_cp)

View File

@ -562,6 +562,7 @@ selects.config_setting_group(
package_group(
name = "internal",
packages = [
"//learning/brain/distribute/...",
"//learning/brain/swift/x10/...",
"//perftools/accelerators/xprof/api/...",
"//tensorflow/...",
@ -578,11 +579,6 @@ package_group(
packages = ["//learning/pathways/..."],
)
# Packages that use composite tensors or dispatch.
# TODO(b/154762408) Remove this package group once it's no longer needed.
# If this is modified, then copy.bara.sky must also be modified.
package_group(name = "composite_tensor_whitelist")
# Packages that use private types symbols, until they are exported.
# TODO(b/154650521) Remove.
package_group(

View File

@ -137,7 +137,7 @@ if _running_from_pip_package():
# TODO(gunan): Add sanity checks to loaded modules here.
for _s in _site_packages_dirs:
# Load first party dynamic kernels.
_main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels')
_main_dir = _os.path.join(_s, 'tensorflow/core/kernels')
if _fi.file_exists(_main_dir):
_ll.load_library(_main_dir)

View File

@ -147,7 +147,7 @@ if _running_from_pip_package():
# TODO(gunan): Add sanity checks to loaded modules here.
for _s in _site_packages_dirs:
# Load first party dynamic kernels.
_main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels')
_main_dir = _os.path.join(_s, 'tensorflow/core/kernels')
if _fi.file_exists(_main_dir):
_ll.load_library(_main_dir)

View File

@ -23,6 +23,7 @@ filegroup(
srcs = [
"c_api.h",
"c_api_experimental.h",
"c_api_macros.h",
"tensor_interface.h",
"tf_attrtype.h",
"tf_datatype.h",
@ -57,10 +58,11 @@ filegroup(
visibility = ["//visibility:public"],
)
filegroup(
cc_library(
name = "pywrap_required_hdrs",
srcs = [
textual_hdrs = [
"c_api_internal.h",
"c_api_macros.h",
"conversion_macros.h",
"python_api.h",
"tensor_interface.h",
@ -79,6 +81,7 @@ tf_cuda_library(
hdrs = [
"c_api.h",
"c_api_internal.h",
"c_api_macros.h",
"tf_datatype.h",
"tf_tensor.h",
"tf_tstring.h",
@ -217,6 +220,7 @@ cc_library(
name = "logging",
srcs = ["logging.cc"],
hdrs = ["logging.h"],
visibility = ["//visibility:public"],
deps = [
":c_api_macros",
"//tensorflow/core/platform:logging",
@ -310,6 +314,7 @@ cc_library(
hdrs = ["tf_tensor.h"],
visibility = ["//visibility:public"],
deps = [
":c_api_macros",
":tensor_interface",
":tf_datatype",
":tf_status",
@ -336,6 +341,7 @@ tf_cuda_library(
],
visibility = ["//tensorflow:internal"],
deps = [
":c_api_macros",
":tensor_interface",
":tf_datatype",
":tf_status",
@ -371,6 +377,7 @@ tf_cuda_library(
"//tensorflow/c/eager:tfe_op_internal",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:get_compiler_ir",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -381,6 +388,7 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform",
"//tensorflow/core/platform:blocking_counter",
"@com_google_absl//absl/strings",
],
alwayslink = 1,

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/net.h"
@ -560,6 +561,21 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
collective_executor_handle->get()->StartAbort(status->status);
}
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
const char* task,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto collective_executor_handle = context->GetCollectiveExecutorHandle();
tensorflow::Notification done;
collective_executor_handle->get()->remote_access()->CheckPeerHealth(
task, [&done, status](const Status& s) {
status->status = s;
done.Notify();
});
done.WaitForNotification();
}
TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
result->num_items = num_items;

View File

@ -231,13 +231,20 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
TF_Status* status);
// Aborts all ongoing collectives with the specified status. After abortion,
// subsequent collectives will error with this status immediately.
// subsequent collectives will error with this status immediately. To reset the
// collectives, create a new EagerContext.
//
// This is intended to be used when a peer failure is detected. There's yet no
// way to reset the collectives other than restarting the program.
// This is intended to be used when a peer failure is detected.
TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
TF_Status* status);
// Checks the health of collective ops peers. Explicit health check is needed in
// multi worker collective ops to detect failures in the cluster. If a peer is
// down, collective ops may hang.
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
const char* task,
TF_Status* status);
// Information about the shape of a Tensor and its type.
struct TF_ShapeAndType {
// Number of dimensions. -1 indicates unknown rank.

View File

@ -1704,66 +1704,5 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) {
TF_DeleteFunction(func1);
}
// This test only works when the TF build includes XLA compiler. One way to set
// this up is via bazel build option "--define with_xla_support=true".
//
// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to
// something like TENSORFLOW_CAPI_USE_XLA.
#ifdef TENSORFLOW_EAGER_USE_XLA
TEST_F(CApiFunctionTest, StatelessIf_XLA) {
TF_Function* func;
const std::string funcName = "BranchFunc";
DefineFunction(funcName.c_str(), &func);
TF_GraphCopyFunction(host_graph_, func, nullptr, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Operation* feed = Placeholder(host_graph_, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Operation* true_cond = ScalarConst(true, host_graph_, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_OperationDescription* desc =
TF_NewOperation(host_graph_, "StatelessIf", "IfNode");
TF_AddInput(desc, {true_cond, 0});
TF_Output inputs[] = {{feed, 0}};
TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs));
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_SetAttrType(desc, "Tcond", TF_BOOL);
TF_DataType inputType = TF_INT32;
TF_SetAttrTypeList(desc, "Tin", &inputType, 1);
TF_SetAttrTypeList(desc, "Tout", &inputType, 1);
TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size());
TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size());
TF_SetDevice(desc, "/device:XLA_CPU:0");
auto op = TF_FinishOperation(desc, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
ASSERT_NE(op, nullptr);
// Create a session for this graph.
CSession csession(host_graph_, s_, /*use_XLA*/ true);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
// Run the graph.
csession.SetInputs({{feed, Int32Tensor(17)}});
csession.SetOutputs({op});
csession.Run(s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Tensor* out = csession.output_tensor(0);
ASSERT_TRUE(out != nullptr);
EXPECT_EQ(TF_INT32, TF_TensorType(out));
EXPECT_EQ(0, TF_NumDims(out)); // scalar
ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
int32* output_contents = static_cast<int32*>(TF_TensorData(out));
EXPECT_EQ(-17, *output_contents);
// Clean up
csession.CloseAndDelete(s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_DeleteFunction(func);
}
#endif // TENSORFLOW_EAGER_USE_XLA
} // namespace
} // namespace tensorflow

View File

@ -30,4 +30,17 @@ limitations under the License.
#endif // _WIN32
#endif // SWIG
// TF_Bool is the C API typedef for unsigned char, while TF_BOOL is
// the datatype for boolean tensors.
#ifndef TF_Bool
#define TF_Bool unsigned char
#endif // TF_Bool
// Macro used to calculate struct size for maintaining ABI stability across
// different struct implementations.
#ifndef TF_OFFSET_OF_END
#define TF_OFFSET_OF_END(TYPE, MEMBER) \
(offsetof(TYPE, MEMBER) + sizeof(((TYPE *)0)->MEMBER))
#endif // TF_OFFSET_OF_END
#endif // TENSORFLOW_C_C_API_MACROS_H_

View File

@ -6,7 +6,6 @@ load(
"tf_copts",
"tf_cuda_cc_test",
"tf_cuda_library",
"tfe_xla_copts",
)
load(
"//tensorflow/core/platform:build_config.bzl",
@ -31,7 +30,7 @@ tf_cuda_library(
"c_api_unified_experimental.h",
],
hdrs = ["c_api.h"],
copts = tf_copts() + tfe_xla_copts(),
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
@ -72,13 +71,6 @@ tf_cuda_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:traceme",
],
}) + select({
"//tensorflow:with_xla_support": [
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/jit",
"//tensorflow/compiler/jit:xla_device",
],
"//conditions:default": [],
}) + [
"@com_google_absl//absl/memory",
"//tensorflow/core/common_runtime/eager:eager_operation",
@ -109,11 +101,17 @@ filegroup(
"c_api_experimental.h",
"c_api_internal.h",
"c_api_unified_experimental.h",
"c_api_unified_experimental_internal.h",
"dlpack.h",
"gradients.h",
"gradients_internal.h",
"immediate_execution_context.h",
"immediate_execution_operation.h",
"immediate_execution_tensor_handle.h",
"mnist_gradients_testutil.h",
"tape.h",
"tfe_cancellation_manager_internal.h",
"tfe_context_internal.h",
"tfe_executor_internal.h",
"tfe_monitoring_internal.h",
"tfe_op_attrs_internal.h",
@ -171,31 +169,6 @@ cc_library(
],
)
cc_library(
name = "gradients",
srcs = [
"gradients.cc",
"gradients_internal.h",
],
hdrs = [
"gradients.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_context",
":abstract_operation",
":abstract_tensor_handle",
":c_api_unified_internal",
":tape",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "gradients_internal",
srcs = [
@ -228,7 +201,6 @@ tf_cuda_cc_test(
"gradients_test.cc",
],
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [
@ -240,6 +212,7 @@ tf_cuda_cc_test(
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:array_grad",
"//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/cc/profiler",
@ -249,6 +222,184 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "gradients_util",
srcs = [
"gradients_util.cc",
],
hdrs = [
"gradients_util.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_context",
":abstract_operation",
":abstract_tensor_handle",
":c_api",
":c_api_experimental",
":c_api_unified_internal",
":gradients_internal",
":tape",
"//tensorflow/c:c_api",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "mnist_gradients_testutil",
srcs = [
"mnist_gradients_testutil.cc",
],
hdrs = [
"mnist_gradients_testutil.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_tensor_handle",
":c_api_experimental",
":c_api_unified_internal",
":gradients_internal",
":gradients_util",
":tape",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:status",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "gradient_checker",
srcs = [
"gradient_checker.cc",
],
hdrs = [
"gradient_checker.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_tensor_handle",
":c_api_experimental",
":c_api_unified_internal",
":gradients_internal",
":gradients_util",
"//tensorflow/c:c_api",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/gradients:nn_grad",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
tf_cuda_cc_test(
name = "gradient_checker_test",
size = "small",
srcs = [
"gradient_checker_test.cc",
],
args = ["--heap_check=local"],
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [
":abstract_tensor_handle",
":c_api_experimental",
":c_api_test_util",
":c_api_unified_internal",
":gradient_checker",
":gradients_internal",
":gradients_util",
":mnist_gradients_testutil",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/gradients:nn_grad",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
tf_cuda_cc_test(
name = "mnist_gradients_test",
size = "small",
srcs = [
"mnist_gradients_test.cc",
],
args = ["--heap_check=local"],
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + [
"nomac",
],
deps = [
":abstract_tensor_handle",
":c_api_experimental",
":c_api_unified_internal",
":gradients_internal",
":gradients_util",
":mnist_gradients_testutil",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/gradients:nn_grad",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
@ -482,7 +633,6 @@ tf_cuda_cc_test(
"c_api_debug_test.cc",
"c_api_test.cc",
],
extra_copts = tfe_xla_copts(),
tags = [
"noguitar", # TODO(b/155445984): flaky
#"guitar",
@ -508,6 +658,27 @@ tf_cuda_cc_test(
],
)
tf_cuda_library(
name = "c_api_remote_test_util",
testonly = 1,
srcs = ["c_api_remote_test_util.cc"],
hdrs = ["c_api_remote_test_util.h"],
visibility = ["//tensorflow:__subpackages__"],
deps = [
":c_api",
":c_api_internal",
":c_api_test_util",
":tfe_tensorhandle_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"@com_google_absl//absl/strings",
],
)
tf_cuda_cc_test(
name = "c_api_remote_test",
size = "small",
@ -516,7 +687,6 @@ tf_cuda_cc_test(
],
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = [
"no_windows",
],
@ -524,6 +694,7 @@ tf_cuda_cc_test(
":c_api",
":c_api_experimental",
":c_api_internal",
":c_api_remote_test_util",
":c_api_test_util",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_test_util",
@ -540,6 +711,24 @@ tf_cuda_cc_test(
],
)
tf_cuda_cc_test(
name = "c_api_remote_function_test",
size = "small",
srcs = [
"c_api_remote_function_test.cc",
],
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
tags = [
"no_windows",
],
deps = [
":c_api_remote_test_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cuda_cc_test(
name = "c_api_distributed_test",
size = "small",
@ -548,7 +737,6 @@ tf_cuda_cc_test(
],
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = [
"no_windows",
"noasan", # leaks gRPC server instances
@ -582,7 +770,6 @@ tf_cuda_cc_test(
],
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = [
"no_windows",
],
@ -617,7 +804,7 @@ tf_cuda_library(
"c_api_experimental.h",
"c_api_unified_experimental.h",
],
copts = tf_copts() + tfe_xla_copts(),
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
@ -689,7 +876,6 @@ tf_cuda_cc_test(
"c_api_experimental_test.cc",
],
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [
@ -702,6 +888,7 @@ tf_cuda_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:status",
"@com_google_absl//absl/strings",
],
)
@ -713,7 +900,6 @@ tf_cuda_cc_test(
"c_api_unified_experimental_test.cc",
],
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [
@ -722,6 +908,7 @@ tf_cuda_cc_test(
":c_api_test_util",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
@ -831,7 +1018,11 @@ filegroup(
"c_api_unified_experimental_eager.cc",
"c_api_unified_experimental_graph.cc",
"c_api_unified_experimental_internal.h",
"gradient_checker.cc",
"gradient_checker.h",
"gradients.cc", # Uses RTTI.
"gradients_util.cc",
"gradients_util.h",
"*test*",
"*dlpack*",
],

View File

@ -51,9 +51,6 @@ limitations under the License.
#include "tensorflow/core/protobuf/device_filters.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#endif // TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@ -629,21 +626,30 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
"targets will fail.";
}
} else {
// The master's context_view_id will be incremented by one
// the UpdateRemoteMaster call later. We want all new workers and
// existing workers to also have the updated context_view_id, so
// we must set their context_view_id to the existing master's
// context_view_id + 1.
sg.Update(CreateRemoteContexts(
ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
server_def, remote_eager_workers.get(), context->Executor().Async(),
context->LazyCopyFunctionRemoteInputs(), base_request));
if (sg.ok()) {
// Create remote contexts on the newly added workers only if the master
// has collected all device information from them (i.e., the
// GetAllRemoteDevices call returns succussfully). Note that in rare cases
// GetAllRemoteDevices can still fail even with RPCs configured to wait
// until the remote workers to become alive. If the master creates remote
// contexts on the workers whose devices are still not collected, those
// workers will be treated as existing workers subsequently, so the master
// will never get devices from them even with retrying UpdateServerDef.
sg.Update(CreateRemoteContexts(
ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
server_def, remote_eager_workers.get(), context->Executor().Async(),
context->LazyCopyFunctionRemoteInputs(), base_request));
}
if (!existing_workers.empty()) {
if (VLOG_IS_ON(1)) {
for (const string& w : existing_workers) {
VLOG(1) << "Updating cluster with existing worker " << w;
}
}
// The master's context_view_id will be incremented by one in the
// UpdateRemoteMaster call later. We want existing workers to also have
// the updated context_view_id, so we must set their context_view_id to
// the master's current context_view_id + 1.
sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers,
removed_workers, context_id,
context_view_id + 1, server_def,
@ -724,7 +730,7 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
if (opts->use_tfrt) {
#ifdef PLATFORM_GOOGLE
return tensorflow::wrap(new tfrt::ContextInterface(opts->async));
return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
#else
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
return nullptr;
@ -745,7 +751,6 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
/*device_mgr_owned*/ true, r,
tensorflow::GetDefaultCustomKernelCreator()));
@ -851,20 +856,9 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::GrpcServer* grpc_server =
static_cast<tensorflow::GrpcServer*>(context->GetServer());
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
status->status = grpc_server->master_env()->worker_cache->GetEagerClientCache(
&remote_eager_workers);
if (!status->status.ok()) {
LOG(ERROR) << "Failed to get client cache for remote workers.";
return false;
}
// TODO(yuefengz): support partially specified `worker_name`.
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
status->status = remote_eager_workers->GetClient(worker_name, &eager_client);
status->status = context->GetClient(worker_name, &eager_client);
if (!status->status.ok()) {
return false;
}
@ -1149,26 +1143,23 @@ void TFE_DeleteOp(TFE_Op* op) {
tensorflow::unwrap(op)->Release();
}
const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) {
return tensorflow::unwrap(op)->Name().c_str();
}
TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) {
return tensorflow::wrap(
&(OperationFromInterface(tensorflow::unwrap(op))->EagerContext()));
}
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
}
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status) {
return tensorflow::unwrap(op)->DeviceName().c_str();
}
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Status s = tensorflow::unwrap(op)->SetUseXla(enable);
if (!s.ok()) {
LOG(ERROR) << "Could not enable XLA compilation for op: " << s;
}
#else
LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
"built with XLA support.";
#endif // TENSORFLOW_EAGER_USE_XLA
}
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
}
@ -1181,6 +1172,15 @@ void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
static_cast<size_t>(num_inputs)});
}
extern int TFE_OpGetFlatInputCount(const TFE_Op* op, TF_Status* status) {
return tensorflow::unwrap(op)->GetInputs().size();
}
extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, int index,
TF_Status* status) {
return tensorflow::wrap(tensorflow::unwrap(op)->GetInputs()[index]);
}
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
unsigned char* is_list, TF_Status* status) {
TF_AttrType ret = TF_ATTR_INT;
@ -1486,7 +1486,7 @@ void TFE_ContextEndStep(TFE_Context* ctx) {
tensorflow::unwrap(ctx)->EndStep();
}
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op) {
return tensorflow::wrap(
&OperationFromInterface(tensorflow::unwrap(op))->Attrs());
}
@ -1551,8 +1551,67 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
TFE_OpSetAttrFunction(op, attr_name, func_op);
TFE_DeleteOp(func_op);
} break;
case tensorflow::AttrValue::kList:
TF_FALLTHROUGH_INTENDED;
case tensorflow::AttrValue::kList: {
// String
if (const int s_size = default_value.list().s_size()) {
absl::InlinedVector<const void*, 4> values_vector;
absl::InlinedVector<size_t, 4> lengths_vector;
for (int i = 0; i < s_size; ++i) {
const string& v = default_value.list().s(i);
values_vector.push_back(v.data());
lengths_vector.push_back(v.size());
}
TFE_OpSetAttrStringList(op, attr_name, values_vector.data(),
lengths_vector.data(), s_size);
}
// Int
if (const int i_size = default_value.list().i_size()) {
absl::InlinedVector<int64_t, 4> i_vector;
for (int i = 0; i < i_size; ++i) {
i_vector.push_back(default_value.list().i(i));
}
TFE_OpSetAttrIntList(op, attr_name, i_vector.data(), i_size);
}
// Float
if (const int f_size = default_value.list().f_size()) {
absl::InlinedVector<float, 4> f_vector;
for (int i = 0; i < f_size; ++i) {
f_vector.push_back(default_value.list().f(i));
}
TFE_OpSetAttrFloatList(op, attr_name, f_vector.data(), f_size);
}
// Bool
if (const int b_size = default_value.list().b_size()) {
absl::InlinedVector<unsigned char, 4> b_vector;
for (int i = 0; i < b_size; i++) {
b_vector.push_back(default_value.list().b(i));
}
TFE_OpSetAttrBoolList(op, attr_name, b_vector.data(), b_size);
}
// Type
if (const int type_size = default_value.list().type_size()) {
absl::InlinedVector<unsigned int, 4> type_vector;
for (int i = 0; i < type_size; ++i) {
type_vector.push_back(default_value.list().type(i));
}
TFE_OpSetAttrTypeList(
op, attr_name,
reinterpret_cast<const TF_DataType*>(type_vector.data()),
type_size);
}
// Rest are not supported.
if (default_value.list().shape_size() > 0 ||
default_value.list().func_size() > 0 ||
default_value.list().tensor_size() > 0) {
TF_SetStatus(
status, TF_UNIMPLEMENTED,
tensorflow::strings::StrCat("Unable to get setfor default value: ",
default_value.DebugString())
.data());
}
} break;
case tensorflow::AttrValue::kTensor:
TF_FALLTHROUGH_INTENDED;
case tensorflow::AttrValue::kPlaceholder:
@ -1612,19 +1671,12 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
return status.status;
}
tensorflow::Status Execute(tensorflow::EagerOperation* op,
tensorflow::Status Execute(const tensorflow::EagerOperation* op,
tensorflow::TensorHandle** retvals,
int* num_retvals) override {
std::vector<TFE_TensorHandle*> inputs;
inputs.reserve(op->Inputs().size());
for (int i = 0; i < op->Inputs().size(); ++i) {
op->Inputs()[i]->Ref();
inputs.push_back(tensorflow::wrap(op->Inputs()[i]));
}
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
TF_Status status;
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
wrap(&op->Attrs()), num_retvals, outputs.data(), &status,
device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
@ -1634,10 +1686,6 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
TFE_DeleteTensorHandle(outputs[i]);
}
}
for (auto inp : inputs) {
TFE_DeleteTensorHandle(inp);
}
return status.status;
}

View File

@ -248,22 +248,22 @@ typedef struct TFE_Op TFE_Op;
TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx,
const char* op_or_function_name,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op);
// Returns the op or function name `op` will execute.
//
// The returned string remains valid throughout the lifetime of 'op'.
TF_CAPI_EXPORT extern const char* TFE_OpGetName(const TFE_Op* op,
TF_Status* status);
TF_CAPI_EXPORT extern TFE_Context* TFE_OpGetContext(const TFE_Op* op,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name,
TF_Status* status);
// The returned string remains valid throughout the lifetime of 'op'.
TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(TFE_Op* op,
TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(const TFE_Op* op,
TF_Status* status);
// When 'enable' is set to 1, and if TensorFlow library is built with XLA
// support, a subsequent TFE_Execute() call on `op` will run the op via XLA.
//
// If the library is not built with XLA support, this call would be a no-op.
TF_CAPI_EXPORT extern void TFE_OpSetXLACompilation(TFE_Op* op,
unsigned char enable);
TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input,
TF_Status* status);
@ -272,6 +272,23 @@ TF_CAPI_EXPORT extern void TFE_OpAddInputList(TFE_Op* op,
int num_inputs,
TF_Status* status);
// Fetches the current number of inputs attached to `op`.
//
// Does not use the operation's definition to determine how many inputs should
// be attached. It is intended for use with TFE_OpGetFlatInput to inspect an
// already-finalized operation.
//
// Note that TFE_OpGetFlatInputCount and TFE_OpGetFlatInput operate on a flat
// sequence of inputs, unlike TFE_OpGetInputLength (for getting the length of a
// particular named input list, which may only be part of the op's inputs).
TF_CAPI_EXPORT extern int TFE_OpGetFlatInputCount(const TFE_Op* op,
TF_Status* status);
// Returns a borrowed reference to one of `op`'s inputs. Use
// `TFE_TensorHandleCopySharingTensor` to make a new reference.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op,
int index,
TF_Status* status);
TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op,
const char* attr_name,
unsigned char* is_list,

View File

@ -22,9 +22,6 @@ limitations under the License.
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/platform/status.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/jit/xla_device.h"
#endif // TENSORFLOW_EAGER_USE_XLA
using tensorflow::string;
@ -64,87 +61,6 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
return nullptr;
}
#ifdef TENSORFLOW_EAGER_USE_XLA
auto* device = absl::get<tensorflow::Device*>(handle->device());
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
auto* xla_device = dynamic_cast<tensorflow::XlaDevice*>(device);
if (xla_device != nullptr) {
tensorflow::XlaDevice::PaddedShapeFn shape_fn =
xla_device->metadata().padded_shape_fn();
xla::Shape padded_shape;
status->status = shape_fn(*tensor, &padded_shape);
if (!status->status.ok()) {
return nullptr;
}
if (VLOG_IS_ON(3)) {
std::vector<tensorflow::int64> shape_to_log =
TensorShapeAsVector(*handle, &status->status);
if (!status->status.ok()) {
// Ignore the status here as we are simply logging.
status->status = tensorflow::Status::OK();
} else {
VLOG(3) << "Fully padded shape of ["
<< absl::StrJoin(shape_to_log, ", ") << "] is "
<< padded_shape.DebugString();
}
}
if (padded_shape.IsTuple()) {
if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) {
// Currently, the only case of XlaTensor containing a tuple shape is to
// represent 64 bit ints, doubles, and complex numbers (we don't support
// 64bit complex numbers).
status->status = tensorflow::errors::InvalidArgument(
"XlaTensors should only contain tuples of size 2. Shape: ",
padded_shape.DebugString());
return nullptr;
}
// shape0 is not a const& because we will assign it to padded_shape below.
// It is illegal to assign a part of a message to itself.
xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0);
const xla::Shape& shape1 =
xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
if (shape0.IsTuple() || shape1.IsTuple()) {
status->status = tensorflow::errors::InvalidArgument(
"XlaTensors should not contain nested tuples. Shape: ",
padded_shape.DebugString());
return nullptr;
}
if (!xla::ShapeUtil::Equal(shape0, shape1)) {
status->status = tensorflow::errors::InvalidArgument(
"Subshapes of XlaTensors should be the same. Shape: ",
padded_shape.DebugString());
return nullptr;
}
// Since the only case we handle here are two equal subshapes, we
// simply return one of them. The caller will interpret it as this
// shape directly storing the 64bit types. This approximation is good
// enough for this API's debugging use case.
padded_shape = shape0;
}
int rank = padded_shape.dimensions_size();
std::vector<tensorflow::int64> dev_dims;
dev_dims.reserve(rank);
if (rank == 1) {
// Rank 1 tensors might not have padded_shape.layout.minor_to_major set,
dev_dims.push_back(padded_shape.dimensions(0));
} else {
for (int i = rank - 1; i >= 0; --i) {
tensorflow::int64 dim_index = padded_shape.layout().minor_to_major(i);
dev_dims.push_back(padded_shape.dimensions(dim_index));
}
}
status->status = tensorflow::Status::OK();
return new TFE_TensorDebugInfo(dev_dims);
}
#endif // TENSORFLOW_EAGER_USE_XLA
// If the tensor is not an XLA tensor, the device shape is
// the same as regular tensor shape.
std::vector<tensorflow::int64> dev_dims =
TensorShapeAsVector(*handle, &status->status);
if (!status->status.ok()) {

View File

@ -121,25 +121,6 @@ string AddVariablesFunction() {
return def.SerializeAsString();
}
void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(op, var_handle, status);
TFE_TensorHandle* is_initialized[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(op, &is_initialized[0], &num_retvals, status);
CHECK_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
bool initialized = false;
memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t));
EXPECT_EQ(initialized, true);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(is_initialized[0]);
TFE_DeleteOp(op);
delete status;
}
void TestFunctionWithPackedInput(const bool remote) {
tensorflow::ServerDef server_def = GetServerDef(3);
@ -182,9 +163,8 @@ void TestFunctionWithPackedInput(const bool remote) {
// Add a sync point in order to make sure that variables have been initialized
// before the function execution starts.
// TODO(b/155789951): Remove once b/155789951 is fixed.
VarIsInitialized(ctx, h1);
VarIsInitialized(ctx, h2);
TFE_ContextAsyncWait(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Pack 3 variable handles into one TFE_TensorHandle.
// When remote is false, function device is placed on task0. Handle types are
@ -396,6 +376,8 @@ TEST(CAPI, DistributedFunctionGraphPassOnlyOnce) {
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle, nullptr);
TFE_ContextAsyncWait(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const string function_def = VariableAddFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
@ -517,8 +499,11 @@ void TestDistributedFunctionCancellation(bool inject_error) {
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle, nullptr);
TFE_ContextAsyncWait(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const string function_def = VariableAddFunctionWithGraphError();
const string function_def = inject_error ? VariableAddFunctionWithGraphError()
: VariableAddFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

View File

@ -486,29 +486,6 @@ TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
static_cast<void*>(sampler->sampler->GetCell(label1, label2)));
}
void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
TFE_ContextMirroringPolicy policy) {
options->mirroring_policy = policy;
}
void TFE_ContextSetThreadLocalMirroringPolicy(
TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetThreadLocalMirroringPolicy(
static_cast<tensorflow::ContextMirroringPolicy>(policy));
}
// Note: this function looks up a thread local policy. So it should be called in
// the appropriate client thread. In particular, in async mode, it may not be
// safe to call this function from the async EagerExecutor threads.
extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return static_cast<TFE_ContextMirroringPolicy>(context->GetMirroringPolicy());
}
void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options,
bool lazy_copy) {
options->lazy_remote_inputs_copy = lazy_copy;

View File

@ -265,33 +265,6 @@ TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2(
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
TFE_MonitoringSampler2* sampler, const char* label1, const char* label2);
// LINT.IfChange
// Note: Keep in sync with internal copy of enum in eager/context.h.
typedef enum TFE_ContextMirroringPolicy {
// Do not maintain mirrors in a TensorHandle, instead make new TensorHandle
// copies with their own lifetime.
TFE_MIRRORING_NONE = 0,
// Mirroring any remote tensor handles, associating them with the lifetime of
// the local TensorHandle.
TFE_MIRRORING_ALL = 1,
} TFE_ContextMirroringPolicy;
// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h)
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetMirroringPolicy(
TFE_ContextOptions*, TFE_ContextMirroringPolicy);
// Sets a thread-local mirroring policy. After this call, other calls to
// TFE_Execute in the same thread will use the mirroring policy specified here
// instead of the mirroring policy used to construct the context. This has no
// effect on the mirroring policy used by other program threads.
TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalMirroringPolicy(
TFE_Context*, TFE_ContextMirroringPolicy);
// Returns the mirroring policy to be used by this context in the current
// thread.
TF_CAPI_EXPORT extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
TFE_Context*);
// Sets whether to copy the remote inputs of a function lazily.
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy(
TFE_ContextOptions*, bool lazy_copy);
@ -441,7 +414,7 @@ typedef struct TFE_OpAttrs TFE_OpAttrs;
// Fetch a reference to `op`'s attributes. The returned reference is only valid
// while `op` is alive.
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op);
TF_CAPI_EXPORT extern const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op);
// Add attributes in `attrs` to `op`.
//
// Does not overwrite or update existing attributes, but adds new ones.
@ -462,7 +435,11 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
size_t proto_len,
TF_Status* status);
#define TFE_CUSTOM_DEVICE_VERSION 2
// TODO(b/166642410): It would be nice, for custom devices and for other users,
// to have a non-string representation of devices (TF_Device) extracted from
// tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc.
#define TFE_CUSTOM_DEVICE_VERSION 3
// Struct to be filled in
typedef struct TFE_CustomDevice {
@ -481,9 +458,16 @@ typedef struct TFE_CustomDevice {
void* device_info);
// Method to execute an operation.
void (*execute)(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
//
// Arguments provide enough information to reconstruct the original `TFE_Op`,
// or construct a transformed version, by inspecting the passed `op`.
//
// TFE_OpGetDevice(op) records the original placement of the operation. It may
// be an empty string if no device was explicitly requested, but will
// otherwise be the name of this custom device. Ops are placed onto a custom
// device if any of their inputs are on that custom device, but custom devices
// are free to set a bad status in order to require explicit placement.
void (*execute)(const TFE_Op* op, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
// Method to delete a device.

View File

@ -316,86 +316,6 @@ TEST(CAPI, Function_ident_CPU) {
TF_DeleteStatus(status);
}
#ifdef TENSORFLOW_EAGER_USE_XLA
TEST(CAPI, Function_ident_XLA_CPU) {
// First create a simple identity function.
TF_Graph* function_graph = TF_NewGraph();
TF_OperationDescription* arg_descr =
TF_NewOperation(function_graph, "Placeholder", "arg");
TF_SetAttrType(arg_descr, "dtype", TF_INT32);
TF_Status* status = TF_NewStatus();
TF_Operation* arg = TF_FinishOperation(arg_descr, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_OperationDescription* id_descr =
TF_NewOperation(function_graph, "Identity", "id");
TF_SetAttrType(id_descr, "T", TF_INT32);
TF_AddInput(id_descr, {arg, 0});
TF_Operation* id = TF_FinishOperation(id_descr, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_Output input{arg, 0};
TF_Output output{id, 0};
TF_Function* fn =
TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
&output, nullptr, nullptr, "test", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteGraph(function_graph);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextAddFunction(ctx, fn, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteFunction(fn);
for (bool async : {false, true, false}) {
TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
TFE_Executor* executor = TFE_NewExecutor(async);
TFE_ContextSetExecutorForThread(ctx, executor);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK);
TF_Tensor* t =
TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
*reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteTensor(t);
TFE_Op* op = TFE_NewOp(ctx, "ident", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_OpAddInput(op, h, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
// Now run it via XLA.
TFE_OpSetXLACompilation(op, true);
std::vector<TFE_TensorHandle*> result;
result.push_back(nullptr);
int num_retvals = 1;
TFE_Execute(op, result.data(), &num_retvals, status);
TFE_DeleteOp(op);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
ASSERT_EQ(num_retvals, 1);
TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
TFE_ContextSetExecutorForThread(ctx, old_executor);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
TFE_DeleteExecutor(old_executor);
TFE_DeleteTensorHandle(h);
TF_DeleteTensor(r);
TFE_DeleteTensorHandle(result[0]);
}
TFE_ContextRemoveFunction(ctx, "ident", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContext(ctx);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteStatus(status);
}
#endif // TENSORFLOW_EAGER_USE_XLA
void Executor_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();

View File

@ -32,7 +32,6 @@ struct TFE_ContextOptions {
bool async = false;
TFE_ContextDevicePlacementPolicy device_placement_policy{
TFE_DEVICE_PLACEMENT_SILENT};
TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE};
// If true, lazily copy the remote inputs of a function to the target devices.
bool lazy_remote_inputs_copy = true;
// If true, use TFRT backend

View File

@ -0,0 +1,64 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_remote_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace {
void TestRemoteExecuteSilentCopiesFunc(bool async, bool remote,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs = false) {
return TestRemoteExecuteSilentCopies(async, remote, /*func=*/true,
heavy_load_on_streaming_rpc,
remote_func_outputs);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
// A remote input may be not ready when we start running a function. Test that
// the function execution should wait until the remote input is ready.
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/true);
}
} // namespace

View File

@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_remote_test_util.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
@ -115,225 +117,24 @@ void TestRemoteExecute(bool async) {
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
string MatMulFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'MatMulFunction'"
" input_arg {"
" name: 'a'"
" type: DT_FLOAT"
" }"
" input_arg {"
" name: 'b'"
" type: DT_FLOAT"
" }"
" output_arg {"
" name: 'm'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'matmul'"
" op: 'MatMul'"
" input: 'a'"
" input: 'b'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'm'"
" value: 'matmul:product'"
" }",
&def));
return def.SerializeAsString();
}
// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one
// which creates a remote remote input, to simulate a scenario that the remote
// input is not ready when we start running an op or a function.
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
std::vector<TFE_TensorHandle*> handles_task0;
if (heavy_load_on_streaming_rpc) {
// Send 50 tensor copy requests to simulate that there have been some RPC
// requests been enqueued.
for (int i = 0; i < 50; ++i) {
handles_task0.push_back(TestMatrixTensorHandle(ctx));
}
}
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
std::vector<TFE_TensorHandle*> handles_task2;
for (auto* h_task0 : handles_task0) {
handles_task2.push_back(
TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status));
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
auto* h1_task2 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* matmul = nullptr;
if (func) {
string function_def = MatMulFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h0_task0, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h1_task2, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else {
// Handles are on task0 (local), and task2, but op is on task1.
matmul = MatMulOp(ctx, h0_task0, h1_task2);
}
if (remote) {
TFE_OpSetDevice(matmul, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else if (!async) {
// Set the local device to CPU to easily validate mirroring
string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
}
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// TODO(gjn): Add support for waiting on async local mirrors
if (!remote && !async) {
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
}
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retval_task0);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h1_task2);
TFE_DeleteTensorHandle(retvals[0]);
for (auto* h : handles_task0) {
TFE_DeleteTensorHandle(h);
}
for (auto* h : handles_task2) {
TFE_DeleteTensorHandle(h);
}
TFE_DeleteOp(matmul);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
if (func) {
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
}
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
void TestRemoteExecuteSilentCopiesOp(bool async, bool remote,
bool remote_func_outputs = false) {
return TestRemoteExecuteSilentCopies(async, remote, /*func=*/false,
/*heavy_load_on_streaming_rpc=*/false,
remote_func_outputs);
}
TEST(CAPI, RemoteExecuteSilentCopies) {
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/true,
/*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
TestRemoteExecuteSilentCopiesOp(/*async=*/false, /*remote=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/false);
TestRemoteExecuteSilentCopiesOp(/*async=*/true, /*remote=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/false,
/*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
TestRemoteExecuteSilentCopiesOp(/*async=*/false, /*remote=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false,
/*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
// A remote input may be not ready when we start running a function. Test that
// the function execution should wait until the remote input is ready.
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/true);
TestRemoteExecuteSilentCopiesOp(/*async=*/true, /*remote=*/false);
}
} // namespace

View File

@ -0,0 +1,222 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_remote_test_util.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
using ::tensorflow::string;
string MatMulFunction(const string& matmul_device) {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
absl::StrCat(" signature {"
" name: 'MatMulFunction'"
" input_arg {"
" name: 'a'"
" type: DT_FLOAT"
" }"
" input_arg {"
" name: 'b'"
" type: DT_FLOAT"
" }"
" output_arg {"
" name: 'm'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'matmul'"
" op: 'MatMul'"
" input: 'a'"
" input: 'b'"
" device: '",
matmul_device, "'",
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'm'"
" value: 'matmul:product'"
" }"),
&def));
return def.SerializeAsString();
}
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
std::vector<TFE_TensorHandle*> handles_task0;
if (heavy_load_on_streaming_rpc) {
// Send 50 tensor copy requests to simulate that there have been some RPC
// requests been enqueued.
for (int i = 0; i < 50; ++i) {
handles_task0.push_back(TestMatrixTensorHandle(ctx));
}
}
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
std::vector<TFE_TensorHandle*> handles_task2;
for (auto* h_task0 : handles_task0) {
handles_task2.push_back(
TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status));
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
auto* h1_task2 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* matmul = nullptr;
if (func) {
const string matmul_device = remote_func_outputs ? task2_name : "";
string function_def = MatMulFunction(matmul_device);
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h0_task0, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h1_task2, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else {
// Handles are on task0 (local), and task2, but op is on task1.
matmul = MatMulOp(ctx, h0_task0, h1_task2);
}
if (remote) {
TFE_OpSetDevice(matmul, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else if (!async) {
// Set the local device to CPU to easily validate mirroring
string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
}
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// TODO(gjn): Add support for waiting on async local mirrors
if (!remote && !async && !remote_func_outputs) {
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
}
if (remote_func_outputs) {
const string backing_device =
TFE_TensorHandleBackingDeviceName(retvals[0], status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(backing_device, task2_name);
}
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retval_task0);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h1_task2);
TFE_DeleteTensorHandle(retvals[0]);
for (auto* h : handles_task0) {
TFE_DeleteTensorHandle(h);
}
for (auto* h : handles_task2) {
TFE_DeleteTensorHandle(h);
}
TFE_DeleteOp(matmul);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
if (func) {
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
}
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}

View File

@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_
#define TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_
// Run a function containing a MatMul op and check its output.
// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one
// which creates a remote remote input, to simulate a scenario that the remote
// input is not ready when we start running an op or a function.
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs = false);
#endif // TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <string>
// clang-format off
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/platform/platform.h"
// clang-format on
@ -876,89 +877,6 @@ TEST(CAPI, Execute_Min_CPU) {
TF_DeleteStatus(status);
}
#ifdef TENSORFLOW_EAGER_USE_XLA
void Execute_MatMul_XLA_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_OpSetXLACompilation(matmul, true);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
// Running a primitive TF operator via XLA is not yet supported.
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
TFE_DeleteTensorHandle(retvals[0]);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); }
TEST(CAPI, Execute_MatMul_XLA_CPUAsync) { Execute_MatMul_XLA_CPU(true); }
void Execute_Min_XLA_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
TFE_Op* minOp = MinOp(ctx, input, axis);
TFE_OpSetXLACompilation(minOp, true);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(minOp, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(minOp);
TFE_DeleteTensorHandle(input);
TFE_DeleteTensorHandle(axis);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
TFE_DeleteTensorHandle(retvals[0]);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float output[2] = {0};
EXPECT_EQ(sizeof(output), TF_TensorByteSize(t));
memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(1, output[0]);
EXPECT_EQ(3, output[1]);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); }
TEST(CAPI, Execute_Min_XLA_CPUAsync) { Execute_Min_XLA_CPU(true); }
#endif // TENSORFLOW_EAGER_USE_XLA
void ExecuteWithTracing(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -1274,6 +1192,68 @@ TEST(CAPI, StringAttributes) {
TF_DeleteStatus(status);
}
// Same test as above, expect use SetOpAttrValueScalar to set attrs.
TEST(CAPI, TestTFE_SetOpAttrs) {
// Test that TFE_OpSetAttrString doesn't hold on to the value after it
// returns.
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
std::vector<int64_t> dims(4, 1);
TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Tensor* tensor =
TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float));
float tensor_data[] = {1};
memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor));
TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, tensor_handle, status);
TF_DeleteTensor(tensor);
TFE_DeleteTensorHandle(tensor_handle);
tensorflow::AttrValue i_list_values;
for (int i = 0; i < 4; ++i) {
i_list_values.mutable_list()->add_i(1);
}
SetOpAttrValueScalar(ctx, op, i_list_values, "ksize", status);
SetOpAttrValueScalar(ctx, op, i_list_values, "strides", status);
tensorflow::AttrValue padding_value;
*padding_value.mutable_s() = "VALID";
tensorflow::SetOpAttrValueScalar(ctx, op, padding_value, "padding", status);
tensorflow::AttrValue data_format_value;
*data_format_value.mutable_s() = "NHWC";
tensorflow::SetOpAttrValueScalar(ctx, op, data_format_value, "data_format",
status);
TFE_OpSetAttrType(op, "T", TF_FLOAT);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(op, &retvals[0], &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
tensor = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(4, TF_TensorByteSize(tensor));
TF_DeleteTensor(tensor);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteOp(op);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -1620,4 +1600,91 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
TFE_DeleteContext(ctx);
}
// Needs to work with a const TFE_Op since custom devices should not modify the
// op they are called with.
TFE_Op* CloneOp(const TFE_Op* other) {
TF_Status* status = TF_NewStatus();
TFE_Context* context = TFE_OpGetContext(other, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char* op_name = TFE_OpGetName(other, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* ret = TFE_NewOp(context, op_name, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char* device = TFE_OpGetDevice(other, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetDevice(ret, device, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddAttrs(ret, TFE_OpGetAttrs(other));
int num_inputs = TFE_OpGetFlatInputCount(other, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
for (int input_index = 0; input_index < num_inputs; ++input_index) {
TFE_TensorHandle* input = TFE_OpGetFlatInput(other, input_index, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(ret, input, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
}
TF_DeleteStatus(status);
return ret;
}
TEST(CAPI, TestTFE_OpRecreation) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
// Clone an op with attributes and a device set.
TFE_Op* original_var_op = TFE_NewOp(ctx, "VarHandleOp", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(original_var_op, "dtype", TF_INT64);
TFE_OpSetAttrShape(original_var_op, "shape", {}, 0, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ("", std::string(TFE_OpGetDevice(original_var_op, status)));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetDevice(original_var_op,
"/job:localhost/replica:0/task:0/device:CPU:0", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* cloned = CloneOp(original_var_op);
EXPECT_EQ("/job:localhost/replica:0/task:0/device:CPU:0",
std::string(TFE_OpGetDevice(cloned, status)));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ("VarHandleOp", std::string(TFE_OpGetName(cloned, status)));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
int num_retvals = 1;
TFE_TensorHandle* ret;
TFE_Execute(cloned, &ret, &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(ret);
// Clone an op with inputs and no device set.
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
TFE_Op* original_identity = TFE_NewOp(ctx, "IdentityN", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* inputs[] = {input1, input2};
TFE_OpAddInputList(original_identity, inputs, 2, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* cloned_identity = CloneOp(original_identity);
EXPECT_EQ("", std::string(TFE_OpGetDevice(cloned_identity, status)));
TFE_TensorHandle* identity_ret[] = {nullptr, nullptr};
num_retvals = 2;
TFE_Execute(cloned_identity, identity_ret, &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(input1);
TFE_DeleteTensorHandle(input2);
TFE_DeleteTensorHandle(identity_ret[0]);
TFE_DeleteTensorHandle(identity_ret[1]);
TFE_DeleteOp(cloned_identity);
TFE_DeleteOp(original_identity);
TFE_DeleteOp(original_var_op);
TFE_DeleteOp(cloned);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
}
} // namespace

View File

@ -102,6 +102,32 @@ TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx,
return th;
}
TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[],
int64_t dims[], int num_dims) {
TF_Status* status = TF_NewStatus();
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[],
int64_t dims[], int num_dims) {
TF_Status* status = TF_NewStatus();
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx) {
constexpr int64_t dims[] = {100, 100};
constexpr int num_elements = dims[0] * dims[1];

View File

@ -40,6 +40,14 @@ TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx,
float data[], int64_t dims[],
int num_dims);
// Get a Matrix TensorHandle with given float values and dimensions
TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[],
int64_t dims[], int num_dims);
// Get a Matrix TensorHandle with given int values and dimensions
TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[],
int64_t dims[], int num_dims);
// Return a tensor handle containing a 100x100 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx);

View File

@ -39,7 +39,7 @@ static FactoriesMap& GetFactories() {
return *factories;
}
static const char* default_factory = "<unset>";
static tracing::FactoryFunction default_factory;
void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
assert((!GetFactories().count(name)) ||
@ -48,15 +48,15 @@ void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
GetFactories()[name] = factory;
}
void SetDefaultTracingEngine(const char* name) { default_factory = name; }
static TracingContext* CreateTracingExecutionContext(const char* fn_name,
TF_Status* s) {
auto entry = GetFactories().find(default_factory);
if (entry != GetFactories().end()) return entry->second(fn_name, s);
Status SetDefaultTracingEngine(const char* name) {
auto entry = GetFactories().find(name);
if (entry != GetFactories().end()) {
default_factory = GetFactories().find(name)->second;
return Status::OK();
}
string msg = absl::StrCat(
"No tracing engine factory has been registered with the key '",
default_factory, "' (available: ");
"No tracing engine factory has been registered with the key '", name,
"' (available: ");
// Ensure deterministic (sorted) order in the error message
std::set<string> factories_sorted;
for (const auto& factory : GetFactories())
@ -68,7 +68,16 @@ static TracingContext* CreateTracingExecutionContext(const char* fn_name,
}
msg += ")";
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return errors::InvalidArgument(msg.c_str());
}
static TracingContext* CreateTracingExecutionContext(const char* fn_name,
TF_Status* s) {
if (default_factory) {
return default_factory(fn_name, s);
}
Set_TF_Status_from_Status(
s, errors::FailedPrecondition("default_factory is nullptr"));
return nullptr;
}
@ -99,8 +108,8 @@ using tensorflow::tracing::TracingContext;
using tensorflow::tracing::TracingOperation;
using tensorflow::tracing::TracingTensorHandle;
void TF_SetTracingImplementation(const char* name) {
SetDefaultTracingEngine(name);
void TF_SetTracingImplementation(const char* name, TF_Status* s) {
Set_TF_Status_from_Status(s, SetDefaultTracingEngine(name));
}
// Creates a new TensorFlow function, it is an execution context attached to a

View File

@ -52,7 +52,7 @@ typedef struct TF_AbstractFunction TF_AbstractFunction;
// This allows the client to swap the implementation of the tracing engine.
// Any future call to TF_CreateFunction will use the implementation defined
// here.
void TF_SetTracingImplementation(const char* name);
void TF_SetTracingImplementation(const char* name, TF_Status*);
// Creates a new TensorFlow function. A Function is an execution context, and as
// such it can trace operations through TF_ExecuteOperation. After completing

View File

@ -85,7 +85,11 @@ class GraphOperation : public TracingOperation {
return errors::FailedPrecondition(
"GraphOperation::Reset must be called before calling SetOpName.");
}
op_.reset(TF_NewOperation(g_, op_type_.c_str(), op_name));
// TODO(b/145674566): We use Graph::NewName to get a unique name here but
// this may not be consistent with python's naming policy.
mutex_lock l(g_->mu);
op_.reset(new TF_OperationDescription(g_, op_type_.c_str(),
g_->graph.NewName(op_name).c_str()));
return Status::OK();
}
const string& Name() const override { return op_type_; }
@ -361,9 +365,10 @@ class GraphContext : public TracingContext {
}
auto s = TF_NewStatus();
func->func = TF_GraphToFunction(
graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
func->func = TF_GraphToFunction(graph_.get(), name_.data(), 0, -1, nullptr,
inputs_.size(), inputs_.data(),
graph_outputs.size(), graph_outputs.data(),
nullptr, nullptr, name_.data(), s);
TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
TF_DeleteStatus(s);
*f = func.release();
@ -387,7 +392,7 @@ class GraphContext : public TracingContext {
private:
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
std::vector<TF_Output> inputs_;
const char* name_;
string name_;
};
static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
@ -397,7 +402,7 @@ static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
// Register the tracing implemented in this file as the default tracing engine.
static bool register_tracing = [] {
RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
SetDefaultTracingEngine("graphdef");
SetDefaultTracingEngine("graphdef").IgnoreError();
return true;
}();

View File

@ -120,7 +120,7 @@ class TracingContext : public AbstractContext {
};
typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
void SetDefaultTracingEngine(const char* name);
Status SetDefaultTracingEngine(const char* name);
void RegisterTracingEngineFactory(const ::tensorflow::string& name,
FactoryFunction factory);
} // namespace tracing

View File

@ -22,19 +22,30 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
using tensorflow::Status;
using tensorflow::string;
using tensorflow::TF_StatusPtr;
namespace tensorflow {
namespace {
// The tests are parameterized on:
// - a string representing the tracing implementation: "mlir" or "graphdef".
// - a boolean that when true enables TFRT as the execution engine.
class UnifiedCAPI
: public ::testing::TestWithParam<std::tuple<const char*, bool>> {
protected:
void SetUp() override {
TF_SetTracingImplementation(std::get<0>(GetParam()));
TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
}
};
@ -554,7 +565,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(add_op, "my_add1", s);
TF_AbstractOpSetOpName(add_op, "my_add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg0, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
@ -576,7 +587,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(add_op, "my_add2", s);
TF_AbstractOpSetOpName(add_op, "my_add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg1, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
@ -983,6 +994,10 @@ TEST_P(UnifiedCAPI, TF_ExecutionContextGetTFEContextFromFunctionContextRaises) {
TF_DeleteExecutionContext(graph_ctx);
}
// The above tests are run for a combination of:
// - graphdef and MLIR tracing engine
// - Using TFRT as an execution runtime (true == enable TFRT)
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
::testing::Combine(::testing::Values("graphdef",

View File

@ -36,7 +36,8 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
RegisterLoggingDevice(context, name, /*strict_scope_placement=*/true,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(context);
ASSERT_FALSE(arrived);
@ -73,7 +74,8 @@ TEST(CUSTOM_DEVICE, ResetOperation) {
bool executed = false;
const char* custom_device_name =
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed,
RegisterLoggingDevice(context.get(), custom_device_name,
/*strict_scope_placement=*/true, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
@ -103,7 +105,8 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle placed on the custom device.
@ -187,7 +190,8 @@ TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) {
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/false,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle placed on the custom device.
@ -264,10 +268,12 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) {
const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
bool arrived = false;
bool executed = false;
RegisterLoggingDevice(context.get(), custom0, &arrived, &executed,
RegisterLoggingDevice(context.get(), custom0,
/*strict_scope_placement=*/false, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
RegisterLoggingDevice(context.get(), custom1, &arrived, &executed,
RegisterLoggingDevice(context.get(), custom1,
/*strict_scope_placement=*/true, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
@ -314,14 +320,34 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) {
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));
// Custom device: mix of custom/physical fails.
// Custom device: mix of custom/physical places the op on the custom device.
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
num_retvals = 1;
executed = false;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
ASSERT_TRUE(
absl::StrContains(TF_Message(status.get()), "[]")); // kVariantDeviceNull
EXPECT_TRUE(executed);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_DeleteTensorHandle(retval);
// Explicit placement still forces the op onto the requested device
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
TFE_OpSetDevice(matmul.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
num_retvals = 1;
executed = false;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
EXPECT_FALSE(executed);
ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
// Custom devices can refuse to do type-based dispatch (as hcustom1 is
// configured to do)
matmul.reset(MatMulOp(context.get(), hcustom1.get(), hcpu.get()));
num_retvals = 1;
executed = false;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
EXPECT_FALSE(executed);
ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
}
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
@ -334,21 +360,24 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", &arrived, &executed,
RegisterLoggingDevice(context.get(), "/device:CUSTOM:0",
/*strict_scope_placement=*/true, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get());
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
RegisterLoggingDevice(context.get(),
"/job:localhost/replica:0/task:0/device:CPU:0",
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
RegisterLoggingDevice(
context.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
/*strict_scope_placement=*/true, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
}

View File

@ -33,6 +33,9 @@ struct LoggingDevice {
bool* arrived_flag;
// Set to true whenever an operation is executed
bool* executed_flag;
// If true, only explicit op placements are accepted. If false, uses
// type-based dispatch.
bool strict_scope_placement;
};
struct LoggedTensor {
@ -84,18 +87,35 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
return nullptr;
}
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
void LoggingDeviceExecute(const TFE_Op* original_op, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
const char* requested_placement = TFE_OpGetDevice(original_op, s);
if (TF_GetCode(s) != TF_OK) return;
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
if (dev->strict_scope_placement && *requested_placement == '\0') {
TF_SetStatus(s, TF_INTERNAL,
"Ops must be placed on the device explicitly, or their inputs "
"first copied to other devices.");
return;
}
TFE_Context* context = TFE_OpGetContext(original_op, s);
if (TF_GetCode(s) != TF_OK) return;
const char* operation_name = TFE_OpGetName(original_op, s);
if (TF_GetCode(s) != TF_OK) return;
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
TFE_Op* op(TFE_NewOp(context, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
if (TF_GetCode(s) != TF_OK) return;
int num_inputs = TFE_OpGetFlatInputCount(original_op, s);
if (TF_GetCode(s) != TF_OK) return;
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, j, s);
if (TF_GetCode(s) != TF_OK) return;
const char* input_device = TFE_TensorHandleDeviceName(input, s);
if (TF_GetCode(s) != TF_OK) return;
if (dev->device_name == input_device) {
@ -131,8 +151,8 @@ void DeleteLoggingDevice(void* device_info) {
} // namespace
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status) {
bool strict_scope_placement, bool* arrived_flag,
bool* executed_flag, TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
@ -143,6 +163,7 @@ void RegisterLoggingDevice(TFE_Context* context, const char* name,
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
device->strict_scope_placement = strict_scope_placement;
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}
@ -168,5 +189,6 @@ void AllocateLoggingDevice(const char* name, bool* arrived_flag,
logging_device->device_name = name;
logging_device->underlying_device =
"/job:localhost/replica:0/task:0/device:CPU:0";
logging_device->strict_scope_placement = true;
*device_info = reinterpret_cast<void*>(logging_device);
}

View File

@ -25,8 +25,8 @@ limitations under the License.
#include "tensorflow/c/tf_status.h"
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status);
bool strict_scope_placement, bool* arrived_flag,
bool* executed_flag, TF_Status* status);
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
bool* executed_flag, TFE_CustomDevice** device,
void** device_info);

View File

@ -109,7 +109,8 @@ DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
// Gets DLPack's DLContext from eager tensor handle.
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
DLContext ctx;
const char* device_name = tensorflow::unwrap(h)->DeviceName(&status->status);
const char* device_name =
tensorflow::unwrap(h)->BackingDeviceName(&status->status);
DeviceNameUtils::ParsedName parsed_name;
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
std::string device_type = parsed_name.type;

View File

@ -0,0 +1,201 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradient_checker.h"
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace gradients {
using namespace std;
// ================== Helper functions =================
// Fills data with values [start,end) with given step size.
void Range(vector<int>* data, int start, int end, int step = 1) {
for (int i = start; i < end; i += step) {
(*data)[i] = i;
}
}
// Returns AbstractTensorHandlePtr containing [0, ..., n-1].
AbstractTensorHandlePtr GetRangeTensorHandleUtil(AbstractContext* ctx, int n) {
vector<int> vals(n);
int64_t vals_shape[] = {n};
Range(&vals, 0, n);
AbstractTensorHandlePtr r =
GetTensorHandleUtilInt(ctx, vals.data(), vals_shape, 1);
return r;
}
// Fills out_dims with the dimensions of the given tensor.
void GetDims(const TF_Tensor* t, int64_t* out_dims) {
int num_dims = TF_NumDims(t);
for (int i = 0; i < num_dims; i++) {
out_dims[i] = TF_Dim(t, i);
}
}
// Runs model as is if output is a scalar,
// else sums the output tensor before returning.
Status RunAndMaybeSum(AbstractContext* ctx, Model forward,
absl::Span<AbstractTensorHandle*> inputs,
absl::Span<AbstractTensorHandle*> outputs,
bool use_function) {
GradientRegistry registry;
std::vector<AbstractTensorHandle*> model_outputs(1);
// Run the model.
TF_RETURN_IF_ERROR(RunModel(forward, ctx, inputs,
absl::MakeSpan(model_outputs), use_function,
registry));
AbstractTensorHandle* model_out = model_outputs[0];
TF_Tensor* model_out_tensor;
TF_RETURN_IF_ERROR(GetValue(model_out, &model_out_tensor));
int num_dims_out = TF_NumDims(model_out_tensor);
// If the output is a scalar, then return the scalar output
if (num_dims_out == 0) {
outputs[0] = model_out;
return Status::OK();
}
// Else, reduce sum the output to get a scalar
// Will sum all dimensions, so get a Tensor containing [0,...,num_dims_out-1].
AbstractTensorHandlePtr sum_dims =
GetRangeTensorHandleUtil(ctx, num_dims_out);
// Reduce sum the output on all dimensions.
std::vector<AbstractTensorHandle*> sum_inputs(2);
sum_inputs[0] = model_out;
sum_inputs[1] = sum_dims.get();
TF_RETURN_IF_ERROR(ops::Sum(ctx, absl::MakeSpan(sum_inputs),
absl::MakeSpan(model_outputs), "sum_output"));
outputs[0] = model_outputs[0];
return Status::OK();
}
// ========================= End Helper Functions==============================
Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
absl::Span<AbstractTensorHandle*> inputs,
int input_index, bool use_function,
AbstractTensorHandle** numerical_grad) {
AbstractTensorHandle* theta =
inputs[input_index]; // parameter we are grad checking
// Convert from AbstractTensor to TF_Tensor.
TF_Tensor* theta_tensor;
TF_RETURN_IF_ERROR(GetValue(theta, &theta_tensor));
// Get number of elements and fill data.
int num_elems = TF_TensorElementCount(theta_tensor);
vector<float> theta_data(num_elems);
memcpy(theta_data.data(), TF_TensorData(theta_tensor),
TF_TensorByteSize(theta_tensor));
// Initialize space for the numerical gradient.
vector<float> dtheta_approx(num_elems);
// Get theta shape and store in theta_dims.
int num_dims = TF_NumDims(theta_tensor);
vector<int64_t> theta_dims(num_dims);
GetDims(theta_tensor, theta_dims.data());
// Initialize auxilary data structures.
vector<float> thetaPlus_data(num_elems);
vector<float> thetaMinus_data(num_elems);
std::vector<AbstractTensorHandle*> f_outputs(1);
// Numerical Grad Check
for (int i = 0; i < num_elems; i++) {
// Get relative epsilon value
float epsilon =
std::abs(theta_data[i] * 1e-4 + 1e-4); // add 1e-4 to prevent div by 0
AbstractTensorHandlePtr two_eps =
GetScalarTensorHandleUtil(ctx, 2 * epsilon);
// Initialize theta[i] + epsilon.
memcpy(thetaPlus_data.data(), TF_TensorData(theta_tensor),
TF_TensorByteSize(theta_tensor));
thetaPlus_data[i] += epsilon;
AbstractTensorHandlePtr thetaPlus = GetTensorHandleUtilFloat(
ctx, thetaPlus_data.data(), theta_dims.data(), num_dims);
// Initialize theta[i] - epsilon.
memcpy(&thetaMinus_data[0], TF_TensorData(theta_tensor),
TF_TensorByteSize(theta_tensor));
thetaMinus_data[i] -= epsilon;
AbstractTensorHandlePtr thetaMinus = GetTensorHandleUtilFloat(
ctx, thetaMinus_data.data(), theta_dims.data(), num_dims);
// Get f(theta + eps):
inputs[input_index] = thetaPlus.get();
TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, inputs,
absl::MakeSpan(f_outputs), use_function));
AbstractTensorHandle* fPlus = f_outputs[0];
// Get f(theta - eps):
inputs[input_index] = thetaMinus.get();
TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, inputs,
absl::MakeSpan(f_outputs), use_function));
AbstractTensorHandle* fMinus = f_outputs[0];
// Take Difference of both estimates: (f(theta + eps) - f(theta - eps)).
TF_RETURN_IF_ERROR(
ops::Sub(ctx, {fPlus, fMinus}, absl::MakeSpan(f_outputs), "sub_top"));
AbstractTensorHandle* fDiff = f_outputs[0];
// Calculate using the difference quotient definition:
// (f(theta + eps) - f(theta - eps)) / (2 * eps).
TF_RETURN_IF_ERROR(ops::DivNoNan(ctx, {fDiff, two_eps.get()},
absl::MakeSpan(f_outputs),
"diff_quotient"));
AbstractTensorHandle* diff_quotient = f_outputs[0];
TF_Tensor* grad_tensor;
TF_RETURN_IF_ERROR(GetValue(diff_quotient, &grad_tensor));
float grad_data[1];
memcpy(&grad_data[0], TF_TensorData(grad_tensor),
TF_TensorByteSize(grad_tensor));
dtheta_approx[i] = grad_data[0];
}
// Populate *numerical_grad with the data from dtheta_approx.
TF_RETURN_IF_ERROR(TensorHandleWithDimsFloat(
ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad));
return Status::OK();
}
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace gradients {
/* Returns numerical grad inside `dtheta_approx` given `forward` model and
* parameter specified by `input_index`.
*
* I.e. if y = <output of the forward model> and w = inputs[input_index],
* this will calculate dy/dw numerically.
*
* `use_function` indicates whether to use graph mode(true) or eager(false).
*
* `numerical_grad` is the pointer to the AbstractTensorHandle* which will
* hold the numerical gradient data at the end of the function.
*/
Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
absl::Span<AbstractTensorHandle*> inputs,
int input_index, bool use_function,
AbstractTensorHandle** numerical_grad);
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,265 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradient_checker.h"
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
class GradientCheckerTest
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
void SetUp() override {
TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
}
};
Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
TF_RETURN_IF_ERROR(
registry->Register("SparseSoftmaxCrossEntropyWithLogits",
SparseSoftmaxCrossEntropyWithLogitsRegisterer));
return Status::OK();
}
TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t A_dims[] = {2, 2};
float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f};
int64_t B_dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr A =
GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
AbstractTensorHandlePtr B =
GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);
std::vector<AbstractTensorHandle*> inputs;
inputs.push_back(A.get());
inputs.push_back(B.get());
AbstractTensorHandle* grad_approx;
Status s = CalcNumericalGrad(
ctx.get(), MatMulModel, absl::MakeSpan(inputs), /*input_index=*/0,
/*use_function=*/!std::get<2>(GetParam()), &grad_approx);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* gt;
s = GetValue(grad_approx, &gt);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt));
float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f};
float tolerance = 1e-2;
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(expected_dA[j], result_data[j], tolerance);
}
TF_DeleteTensor(gt);
}
TEST_P(GradientCheckerTest, TestGradCheckMul) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = ScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
Status s = ScalarTensorHandle(ctx.get(), 7.0f, &y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
y.reset(y_raw);
}
// Will perform z = x*y.
// dz/dx = y
std::vector<AbstractTensorHandle*> inputs;
inputs.push_back(x.get());
inputs.push_back(y.get());
AbstractTensorHandle* g;
Status s = CalcNumericalGrad(ctx.get(), MulModel, absl::MakeSpan(inputs),
/*input_index=*/0,
/*use_function=*/!std::get<2>(GetParam()), &g);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* gt;
s = GetValue(g, &gt);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[1] = {0};
memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt));
ASSERT_NEAR(result_data[0], 7.0f, /*abs_error=*/1e-2);
TF_DeleteTensor(gt);
}
TEST_P(GradientCheckerTest, TestGradCheckSoftmax) {
bool use_function = !std::get<2>(GetParam());
if (use_function) {
// TODO(b/168850692): Enable this.
GTEST_SKIP() << "Can't take gradient of "
"SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
}
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
/** Test to show how to use this API with analytical gradients:
*
* We have `SoftmaxLossGradModel`, which is a wrapper for the
* Softmax analytical gradient found in c/experimental/nn_grads.
*
* We will use the GradientChecker by applying finite differences
* to the forward pass wrapped in `SoftmaxModel` and verify that
* both the analytical and numerical gradients are relatively
* close.
*
*/
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = scores
float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, 1.0f};
int64_t X_dims[] = {3, 3};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
// y = labels
int y_vals[] = {1, 0, 1};
int64_t y_dims[] = {3};
num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
AbstractTensorHandlePtr y =
GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
std::vector<AbstractTensorHandle*> inputs;
inputs.push_back(X.get());
inputs.push_back(y.get());
// Run analytical gradient and get its data.
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(SoftmaxLossGradModel, ctx.get(), absl::MakeSpan(inputs),
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* dX_tensor;
s = GetValue(outputs[0], &dX_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float danalytical[9] = {0}; // Contains data from analytical gradient.
memcpy(&danalytical[0], TF_TensorData(dX_tensor),
TF_TensorByteSize(dX_tensor));
// Run numerical gradient approximation using the GradientChecker API.
AbstractTensorHandle* g; // Will contain numerical approximation data.
s = CalcNumericalGrad(ctx.get(), SoftmaxModel, absl::MakeSpan(inputs),
/*input_index=*/0,
/*use_function=*/!std::get<2>(GetParam()), &g);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* gt;
s = GetValue(g, &gt);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float dnumerical[9] = {0};
memcpy(&dnumerical[0], TF_TensorData(gt), TF_TensorByteSize(gt));
// Now compare the two implementations:
for (int j = 0; j < 9; j++) {
ASSERT_NEAR(dnumerical[j], danalytical[j], /*abs_error=*/1e-2);
}
// Only Unref() first output as 2nd is nullptr grad for labels
outputs[0]->Unref();
TF_DeleteTensor(dX_tensor);
TF_DeleteTensor(gt);
}
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, GradientCheckerTest,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, GradientCheckerTest,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#endif
} // namespace
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/c/eager/gradients.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
@ -23,25 +24,97 @@ limitations under the License.
namespace tensorflow {
namespace gradients {
Status GradientRegistry::Register(const string& op_name,
GradientFunctionFactory factory) {
namespace {
Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t,
AbstractTensorHandle** result) {
AbstractOperationPtr op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(op->Reset("ZerosLike", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
absl::StrCat("ZerosLike", ToId(t)).c_str()));
}
TF_RETURN_IF_ERROR(op->AddInput(t));
int num_outputs = 1;
std::vector<AbstractTensorHandle*> outputs(num_outputs);
TF_RETURN_IF_ERROR(
op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
*result = outputs[0];
return Status::OK();
}
} // namespace
class IncomingGradientsImpl : public IncomingGradients {
public:
explicit IncomingGradientsImpl(
absl::Span<AbstractTensorHandle* const> grad_inputs, Context* ctx,
DefaultGradientFunction* default_gradients)
: grad_inputs_(grad_inputs),
ctx_(ctx),
default_gradients_(default_gradients) {}
AbstractTensorHandle* operator[](int i) const override {
return default_gradients_->get(ctx_, grad_inputs_, i);
}
size_t size() const override { return grad_inputs_.size(); }
private:
absl::Span<AbstractTensorHandle* const> grad_inputs_;
Context* ctx_;
DefaultGradientFunction* default_gradients_;
};
AllZerosDefaultGradients::AllZerosDefaultGradients(const ForwardOperation& op)
: outputs_(op.outputs) {
for (auto output : outputs_) {
output->Ref();
}
}
AbstractTensorHandle* AllZerosDefaultGradients::get(
Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs, int i) {
if (grad_inputs[i]) {
return grad_inputs[i];
}
if (cached_default_grads_[i]) {
return cached_default_grads_[i].get();
}
AbstractTensorHandle* result = nullptr;
Status s = ZerosLike(ctx->ctx, outputs_[i], &result);
if (!s.ok()) {
if (result) {
result->Unref();
}
VLOG(1) << "Failed to create ZerosLike for index " << i;
return nullptr;
}
cached_default_grads_[i].reset(result);
return result;
}
PassThroughDefaultGradients::PassThroughDefaultGradients(
const ForwardOperation& op) {}
AbstractTensorHandle* PassThroughDefaultGradients::get(
Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs, int i) {
return grad_inputs[i];
}
Status GradientRegistry::Register(
const string& op_name, BackwardFunctionFactory backward_function_factory) {
auto iter = registry_.find(op_name);
if (iter != registry_.end()) {
const string error_msg = "Gradient already exists for op: " + op_name + ".";
return errors::AlreadyExists(error_msg);
}
registry_.insert({op_name, factory});
registry_.insert({op_name, backward_function_factory});
return Status::OK();
}
Status GradientRegistry::Lookup(
const ForwardOperation& op,
std::unique_ptr<GradientFunction>* grad_fn) const {
std::unique_ptr<BackwardFunction>* backward_function) const {
auto iter = registry_.find(op.op_name);
if (iter == registry_.end()) {
const string error_msg = "No gradient defined for op: " + op.op_name + ".";
return errors::NotFound(error_msg);
}
grad_fn->reset(iter->second(op));
backward_function->reset(iter->second(op));
return Status::OK();
}
@ -92,33 +165,8 @@ AbstractTensorHandle* TapeTensor::OnesLike() const {
}
return outputs[0];
}
AbstractTensorHandle* TapeTensor::ZerosLike() const {
AbstractOperationPtr op(ctx_->CreateOperation());
// TODO(srbs): Consider adding a TF_RETURN_NULLPTR_IF_ERROR.
Status s = op->Reset("ZerosLike", /*raw_device_name=*/nullptr);
if (!s.ok()) {
return nullptr;
}
if (isa<tracing::TracingOperation>(op.get())) {
s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
absl::StrCat("ZerosLike", ToId(handle_)).c_str());
if (!s.ok()) {
return nullptr;
}
}
s = op->AddInput(handle_);
if (!s.ok()) {
return nullptr;
}
int num_outputs = 1;
// TODO(srbs): Figure out who is in charge of releasing this.
std::vector<AbstractTensorHandle*> outputs(num_outputs);
s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
if (!s.ok()) {
return nullptr;
}
return outputs[0];
}
AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; }
// Returns the number of elements in the gradient tensor.
int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const {
@ -159,13 +207,16 @@ AbstractTensorHandle* TapeVSpace::AggregateGradients(
// Calls the passed-in backward function.
Status TapeVSpace::CallBackwardFunction(
GradientFunction* backward_function,
BackwardFunction* backward_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
std::vector<AbstractTensorHandle*>* result) const {
if (backward_function == nullptr) return Status::OK();
Context ctx = {ctx_};
return backward_function->Compute(&ctx, output_gradients, result);
IncomingGradientsImpl incoming_gradients(
output_gradients, &ctx, backward_function->GetDefaultGradientFunction());
return backward_function->GetGradientFunction()->Compute(
&ctx, incoming_gradients, result);
}
// Looks up the ID of a Gradient.
@ -191,6 +242,7 @@ namespace internal {
Status Reset(AbstractOperation* op_, const char* op,
const char* raw_device_name, ForwardOperation* forward_op_) {
forward_op_->op_name = op;
forward_op_->attrs.Reset(op);
return op_->Reset(op, raw_device_name);
}
Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input,
@ -363,21 +415,30 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx,
input_ids[i] = ToId(forward_op_->inputs[i]);
input_dtypes[i] = forward_op_->inputs[i]->DataType();
}
for (int i = 0; i < *num_retvals; i++) {
// TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
forward_op_->outputs.push_back(retvals[i]);
}
// TODO(b/166669239): This is needed to support AttrBuilder::Get for string
// attributes. Number type attrs and DataType attrs work fine without this.
// Consider getting rid of this and making the behavior between number types
// and string consistent.
forward_op_->attrs.BuildNodeDef();
std::vector<TapeTensor> tape_tensors;
for (auto t : retvals) {
tape_tensors.push_back(TapeTensor(t, ctx));
}
tape->RecordOperation(
op_->Name(), tape_tensors, input_ids, input_dtypes,
[registry, forward_op_]() -> GradientFunction* {
std::unique_ptr<GradientFunction> grad_fn;
Status s = registry.Lookup(*forward_op_, &grad_fn);
[registry, forward_op_]() -> BackwardFunction* {
std::unique_ptr<BackwardFunction> backward_fn;
Status s = registry.Lookup(*forward_op_, &backward_fn);
if (!s.ok()) {
return nullptr;
}
return grad_fn.release();
return backward_fn.release();
},
[](GradientFunction* ptr) {
[](BackwardFunction* ptr) {
if (ptr) {
delete ptr;
}

View File

@ -55,18 +55,25 @@ struct Context {
public:
AbstractContext* ctx;
};
class IncomingGradients {
public:
virtual AbstractTensorHandle* operator[](int i) const = 0;
virtual size_t size() const = 0;
virtual ~IncomingGradients() {}
};
class GradientFunction {
public:
// TODO(srbs): How we support CompositeTensors e.g. IndexedSlices in
// `grad_inputs`.
virtual Status Compute(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
virtual Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
std::vector<AbstractTensorHandle*>* grad_outputs) = 0;
virtual ~GradientFunction() {}
};
// Metadata from the forward operation that is made available to the
// gradient registerer to instantiate a GradientFunction.
// gradient registerer to instantiate a BackwardFunction.
struct ForwardOperation {
public:
string op_name;
@ -76,18 +83,86 @@ struct ForwardOperation {
AbstractContext* ctx;
};
using GradientFunctionFactory =
std::function<GradientFunction*(const ForwardOperation& op)>;
// Map from op name to a `GradientFunctionFactory`.
class GradientRegistry {
// Interface for building default zeros gradients for op outputs which are
// missing incoming gradients. Custom implementations of this can be used to
// control which of the forward op's output tensors/their metadata needs to
// be kept around in memory to build the default zeros grad.
//
// Some common helper implementations are provided below.
class DefaultGradientFunction {
public:
Status Register(const string& op, GradientFunctionFactory factory);
Status Lookup(const ForwardOperation& op,
std::unique_ptr<GradientFunction>* grad_fn) const;
virtual AbstractTensorHandle* get(
Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs,
int i) = 0;
virtual ~DefaultGradientFunction() {}
};
// Returns zeros for any `nullptr` in `grad_inputs`.
//
// This may require keeping track of all of forward op's output
// tensors and hence may incur a higher memory footprint. Use sparingly.
//
// Multiple calls to `AllZerosDefaultGradients::get` return the same tensor
// handle.
//
// The destructor of this class `Unref`'s any cached tensor handles so users of
// those tensor handles should `Ref` them in order to keep them alive if needed.
class AllZerosDefaultGradients : public DefaultGradientFunction {
public:
explicit AllZerosDefaultGradients(const ForwardOperation& op);
AbstractTensorHandle* get(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
int i) override;
private:
absl::flat_hash_map<string, GradientFunctionFactory> registry_;
// TODO(srbs): We do not always need to keep the tensors around. In immediate
// execution mode we just need to store the shape and dtype. During tracing
// we may need to keep the tensor around if the shape is not full defined.
std::vector<AbstractTensorHandle*> outputs_;
std::vector<AbstractTensorHandlePtr> cached_default_grads_;
};
// Passes through `grad_inputs` as-is. The `GradientFunction`
// will be expected to deal with nullptr in `grad_inputs` if any.
class PassThroughDefaultGradients : public DefaultGradientFunction {
public:
explicit PassThroughDefaultGradients(const ForwardOperation& op);
AbstractTensorHandle* get(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
int i) override;
};
// A `BackwardFunction` wraps a `GradientFunction` and a
// `DefaultGradientFunction`. Both are owned by this class' instance.
class BackwardFunction {
public:
BackwardFunction(GradientFunction* gradient_function,
DefaultGradientFunction* default_gradients)
: gradient_function_(gradient_function),
default_gradients_(default_gradients) {}
GradientFunction* GetGradientFunction() { return gradient_function_.get(); }
DefaultGradientFunction* GetDefaultGradientFunction() {
return default_gradients_.get();
}
private:
std::unique_ptr<GradientFunction> gradient_function_;
std::unique_ptr<DefaultGradientFunction> default_gradients_;
};
using BackwardFunctionFactory =
std::function<BackwardFunction*(const ForwardOperation& op)>;
// Map from op name to a `BackwardFunctionFactory`.
class GradientRegistry {
public:
Status Register(const string& op,
BackwardFunctionFactory backward_function_factory);
Status Lookup(const ForwardOperation& op,
std::unique_ptr<BackwardFunction>* backward_function) const;
private:
absl::flat_hash_map<string, BackwardFunctionFactory> registry_;
};
// Returns a unique id for the tensor which is used by the tape to build
@ -106,9 +181,16 @@ int64 ToId(AbstractTensorHandle* t);
// allow us to trace the data dependencies between operations and hence compute
// gradients.
//
// This also implements `ZerosLike` and `OnesLike` to create the default
// This also implements `OnesLike` to create the default
// incoming gradients for tensors which do not already have an incoming
// gradient.
//
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
// of default zeros grads is handled by the `DefaultGradientFunction` registered
// for each op.
// TODO(srbs): We need to define `ZerosLike` here to keep the compiler happy.
// Figure out a way to avoid this.
// TODO(srbs): Should ZerosLike check-fail instead of returning nullptr?
class TapeTensor {
public:
TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx);
@ -123,7 +205,7 @@ class TapeTensor {
private:
AbstractTensorHandle* handle_;
// The context where OnesLike and ZerosLike ops are to be created.
// The context where OnesLike ops are to be created.
AbstractContext* ctx_;
};
@ -132,7 +214,7 @@ class TapeTensor {
// gradient and for performing gradient aggregation.
// See `tensorflow::eager::VSpace` for more details.
class TapeVSpace
: public eager::VSpace<AbstractTensorHandle, GradientFunction, TapeTensor> {
: public eager::VSpace<AbstractTensorHandle, BackwardFunction, TapeTensor> {
public:
explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
~TapeVSpace() override {}
@ -147,7 +229,7 @@ class TapeVSpace
// Calls the passed-in backward function.
Status CallBackwardFunction(
GradientFunction* backward_function,
BackwardFunction* backward_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
std::vector<AbstractTensorHandle*>* result) const override;
@ -168,8 +250,14 @@ class TapeVSpace
};
// A tracing/immediate-execution agnostic tape.
//
// Gradient functions defined for this library support handling null incoming
// gradients. `Tape::ComputeGradient` should be called with
// `build_default_zeros_grads=false`. Calling with
// `build_default_zeros_grads=true` (the default) is equivalent but just results
// in extra work because `TapeTensor::ZerosLike` returns a `nullptr` anyway.
using Tape = tensorflow::eager::GradientTape<AbstractTensorHandle,
GradientFunction, TapeTensor>;
BackwardFunction, TapeTensor>;
} // namespace gradients
} // namespace tensorflow

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
@ -23,6 +24,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/gradients/array_grad.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
@ -35,17 +37,26 @@ namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
using std::vector;
using tensorflow::TF_StatusPtr;
using tracing::TracingOperation;
class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
void SetUp() override {
TF_SetTracingImplementation(std::get<0>(GetParam()));
TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
}
};
Status RegisterGradients(GradientRegistry* registry) {
return registry->Register("Add", AddRegisterer);
TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
return Status::OK();
}
// Computes `inputs[0] + inputs[1]` and records it on the tape.
@ -58,9 +69,9 @@ Status Add(AbstractContext* ctx, Tape* tape,
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
if (isa<tracing::TracingOperation>(add_op.get())) {
if (isa<TracingOperation>(add_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(add_op.get())->SetOpName("my_add"));
dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
}
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
@ -69,6 +80,46 @@ Status Add(AbstractContext* ctx, Tape* tape,
registry);
}
// Computes `exp(inputs[0])` and records it on the tape.
Status Exp(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractOperationPtr exp_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(exp_op.get(), "Exp", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(exp_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(exp_op.get())->SetOpName("my_exp"));
}
TF_RETURN_IF_ERROR(AddInput(exp_op.get(), inputs[0], &forward_op));
int num_retvals = 1;
return Execute(exp_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes `IdentityN(inputs)` and records it on the tape.
Status IdentityN(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractOperationPtr identity_n_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(Reset(identity_n_op.get(), "IdentityN",
/*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(identity_n_op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<TracingOperation>(identity_n_op.get())
->SetOpName("my_identity_n"));
}
TF_RETURN_IF_ERROR(AddInputList(identity_n_op.get(), inputs, &forward_op));
int num_retvals = outputs.size();
return Execute(identity_n_op.get(), ctx, outputs, &num_retvals, &forward_op,
tape, registry);
}
// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
@ -91,7 +142,8 @@ Status AddGradModel(AbstractContext* ctx,
vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads));
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto add_output : add_outputs) {
add_output->Unref();
}
@ -101,6 +153,71 @@ Status AddGradModel(AbstractContext* ctx,
return Status::OK();
}
// Computes
// y = exp(inputs[0])
// return grad(y, {inputs[0]})
Status ExpGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
std::vector<AbstractTensorHandle*> exp_outputs(1);
TF_RETURN_IF_ERROR(Exp(ctx, tape, inputs, absl::MakeSpan(exp_outputs),
registry)); // Compute x+y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(exp_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto exp_output : exp_outputs) {
exp_output->Unref();
}
outputs[0] = out_grads[0];
delete tape;
return Status::OK();
}
// Computes
// ignored, y = IdentityN(inputs[0], inputs[1])
// return grad(y, {inputs[0], inputs[1]})
// This should return [nullptr, 1].
Status IdentityNGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0]));
tape->Watch(ToId(inputs[1]));
vector<AbstractTensorHandle*> identity_n_outputs(2);
TF_RETURN_IF_ERROR(IdentityN(ctx, tape, inputs,
absl::MakeSpan(identity_n_outputs), registry));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(identity_n_outputs[1])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto identity_n_output : identity_n_outputs) {
identity_n_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
AbstractContext* BuildFunction(const char* fn_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -132,26 +249,42 @@ Status RunModel(Model model, AbstractContext* ctx,
if (use_function) {
const char* fn_name = "test_fn";
std::unique_ptr<AbstractFunction> scoped_func;
// Returning null tensors from a tf.function is not supported, so we keep
// track of indices in the model's outputs are nullptr in this set.
// The FunctionDef only outputs the non-null tensors. We later pad the
// function op outputs to have nullptrs at the `null_indices`.
absl::flat_hash_set<int> null_indices;
{
AbstractContextPtr func_ctx(BuildFunction(fn_name));
std::vector<AbstractTensorHandle*> func_inputs;
func_inputs.reserve(inputs.size());
TF_RETURN_IF_ERROR(
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
OutputList output_list;
output_list.expected_num_outputs = outputs.size();
output_list.outputs.resize(outputs.size());
vector<AbstractTensorHandle*> model_outputs;
model_outputs.resize(outputs.size());
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
absl::MakeSpan(output_list.outputs), registry));
absl::MakeSpan(model_outputs), registry));
for (auto func_input : func_inputs) {
func_input->Unref();
}
AbstractFunction* func = nullptr;
OutputList output_list;
output_list.expected_num_outputs = 0;
output_list.outputs.reserve(outputs.size());
for (int i = 0; i < model_outputs.size(); i++) {
if (model_outputs[i]) {
output_list.outputs.emplace_back(model_outputs[i]);
output_list.expected_num_outputs += 1;
} else {
null_indices.insert(i);
}
}
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
->Finalize(&output_list, &func));
scoped_func.reset(func);
output_list.outputs[0]->Unref();
output_list.outputs[1]->Unref();
for (auto output : output_list.outputs) {
output->Unref();
}
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
}
@ -160,8 +293,19 @@ Status RunModel(Model model, AbstractContext* ctx,
for (auto input : inputs) {
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
}
int retvals = outputs.size();
TF_RETURN_IF_ERROR(fn_op->Execute(outputs, &retvals));
int retvals = outputs.size() - null_indices.size();
vector<AbstractTensorHandle*> fn_outputs(retvals);
TF_RETURN_IF_ERROR(fn_op->Execute(
absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
&retvals));
int skipped_indices = 0;
for (int i = 0; i < outputs.size(); i++) {
if (!null_indices.contains(i)) {
outputs[i] = fn_outputs[i - skipped_indices];
} else {
skipped_indices += 1;
}
}
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
return Status::OK();
} else {
@ -264,18 +408,172 @@ TEST_P(CppGradients, TestAddGrad) {
TF_DeleteTensor(result_tensor);
}
// TODO(b/160888630): Enable this test with mlir after AddInputList is
// supported. It is needed for AddN op which is used for gradient aggregation.
TEST_P(CppGradients, TestExpGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// y = exp(x)
// outputs = tape.gradient(y, x)
std::vector<AbstractTensorHandle*> outputs(1);
s = RunModel(ExpGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_NEAR(*result_value, 2.718, 0.001);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
TEST_P(CppGradients, TestIdentityNGrad) {
// Pseudo-code:
//
// tape.watch(x1)
// tape.watch(x2)
// unused, y = IdentityN([x1, x2])
// outputs = tape.gradient(y, [x1, x2])
// Expected: [nullptr, 1]
//
// This test is interesting because the current implementation of GradientTape
// would return [0, 1] whereas we use build_default_zeros_grads=false here
// so we get back [nullptr, 1].
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x1;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x1.reset(x_raw);
}
AbstractTensorHandlePtr x2;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x2.reset(x_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(IdentityNGradModel, ctx.get(), {x1.get(), x2.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
EXPECT_EQ(outputs[0], nullptr);
TF_Tensor* result_tensor;
s = getValue(outputs[1], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[1]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
TEST_P(CppGradients, TestSetAttrString) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr t;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
t.reset(x_raw);
}
AbstractOperationPtr check_numerics_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx.get();
Status s = Reset(check_numerics_op.get(), "CheckNumerics",
/*raw_device_name=*/nullptr, &forward_op);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
if (isa<TracingOperation>(check_numerics_op.get())) {
s = dyn_cast<TracingOperation>(check_numerics_op.get())
->SetOpName("check_numerics");
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}
s = AddInput(check_numerics_op.get(), t.get(), &forward_op);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
string message = "This is the way!";
s = SetAttrString(check_numerics_op.get(), "message", message.data(),
message.length(), &forward_op);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
int num_retvals = 1;
std::vector<AbstractTensorHandle*> outputs(1);
GradientRegistry registry;
std::unique_ptr<Tape> tape(new Tape(/*persistent=*/false));
s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs),
&num_retvals, &forward_op, tape.get(), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
string read_message;
s = forward_op.attrs.Get("message", &read_message);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ASSERT_EQ(read_message, message);
}
// TODO(b/164171226): Enable this test with tfrt after AddInputList is
// supported. It is needed for IdentityN.
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(true, false),
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"),
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#endif

View File

@ -0,0 +1,317 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradients_util.h"
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace gradients {
using namespace std;
Status ScalarTensorHandleHelper(TFE_Context* ctx, float value,
TFE_TensorHandle** result) {
float data[] = {value};
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status.get());
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
*result = th;
TF_DeleteTensor(t);
return StatusFromTF_Status(status.get());
}
Status TensorHandleWithDimsFloatHelper(TFE_Context* ctx, float data[],
int64_t dims[], int num_dims,
TFE_TensorHandle** result) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status.get());
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
*result = th;
TF_DeleteTensor(t);
return StatusFromTF_Status(status.get());
}
Status TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[],
int64_t dims[], int num_dims,
TFE_TensorHandle** result) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status.get());
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
*result = th;
TF_DeleteTensor(t);
return StatusFromTF_Status(status.get());
}
// Get a scalar TensorHandle with given value
Status ScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager;
TF_RETURN_IF_ERROR(ScalarTensorHandleHelper(eager_ctx, value, &input_eager));
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return StatusFromTF_Status(status.get());
}
// Get a TensorHandle with given float values and dimensions
Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager;
TF_RETURN_IF_ERROR(TensorHandleWithDimsFloatHelper(eager_ctx, data, dims,
num_dims, &input_eager));
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return StatusFromTF_Status(status.get());
}
// Get a TensorHandle with given int values and dimensions
Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[],
int num_dims, AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager;
TF_RETURN_IF_ERROR(TensorHandleWithDimsIntHelper(eager_ctx, data, dims,
num_dims, &input_eager));
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return StatusFromTF_Status(status.get());
}
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
return StatusFromTF_Status(status.get());
}
AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
float vals[], int64_t dims[],
int num_dims) {
AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr;
Status s = TensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
if (s.ok()) {
A.reset(a_raw);
}
return A;
}
AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
int64_t dims[], int num_dims) {
AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr;
Status s = TensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
if (s.ok()) {
A.reset(a_raw);
}
return A;
}
AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
float val) {
AbstractTensorHandlePtr y;
AbstractTensorHandle* y_raw = nullptr;
Status s = ScalarTensorHandle(ctx, val, &y_raw);
if (s.ok()) {
y.reset(y_raw);
}
return y;
}
Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads,
vector<AbstractTensorHandle*>& weights,
AbstractTensorHandle* learning_rate) {
/* Update weights one by one using gradient update rule:
*
* w -= lr*grad[w]
*
* NOTE: assuming learning rate is positive
*/
int num_grads = grads.size();
vector<AbstractTensorHandle*> temp_outputs(1);
std::string update_str;
// Negate learning rate for gradient descent
TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate},
absl::MakeSpan(temp_outputs),
"neg_lr")); // Compute -lr
learning_rate = temp_outputs[0];
for (int i = 0; i < num_grads; i++) {
// Compute dW = -lr * grad(w[i])
update_str = "update_mul_" + std::to_string(i);
TF_RETURN_IF_ERROR(ops::Mul(ctx, {learning_rate, grads[i]},
absl::MakeSpan(temp_outputs),
update_str.c_str()));
AbstractTensorHandle* dW = temp_outputs[0];
// Compute temp = weights[i] + dW
update_str = "update_add_" + std::to_string(i);
TF_RETURN_IF_ERROR(ops::Add(ctx, {weights[i], dW},
absl::MakeSpan(temp_outputs),
update_str.c_str()));
// Update the weights
weights[i] = temp_outputs[0];
}
return Status::OK();
}
AbstractContext* BuildFunction(const char* fn_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
return unwrap(graph_ctx);
}
Status CreateParamsForInputs(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
vector<AbstractTensorHandle*>* params) {
tracing::TracingTensorHandle* handle = nullptr;
for (auto input : inputs) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
input->DataType(), &handle));
params->emplace_back(handle);
}
return Status::OK();
}
Status RunModel(Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
const GradientRegistry& registry) {
if (use_function) {
const char* fn_name = "test_fn";
std::unique_ptr<AbstractFunction> scoped_func;
// Returning null tensors from a tf.function is not supported, so we keep
// track of indices in the model's outputs are nullptr in this set.
// The FunctionDef only outputs the non-null tensors. We later pad the
// function op outputs to have nullptrs at the `null_indices`.
absl::flat_hash_set<int> null_indices;
{
AbstractContextPtr func_ctx(BuildFunction(fn_name));
vector<AbstractTensorHandle*> func_inputs;
func_inputs.reserve(inputs.size());
TF_RETURN_IF_ERROR(
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
vector<AbstractTensorHandle*> model_outputs;
model_outputs.resize(outputs.size());
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
absl::MakeSpan(model_outputs), registry));
for (auto func_input : func_inputs) {
func_input->Unref();
}
AbstractFunction* func = nullptr;
OutputList output_list;
output_list.expected_num_outputs = 0;
output_list.outputs.reserve(outputs.size());
for (int i = 0; i < model_outputs.size(); i++) {
if (model_outputs[i]) {
output_list.outputs.emplace_back(model_outputs[i]);
output_list.expected_num_outputs += 1;
} else {
null_indices.insert(i);
}
}
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
->Finalize(&output_list, &func));
scoped_func.reset(func);
for (auto output : output_list.outputs) {
output->Unref();
}
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
}
AbstractOperationPtr fn_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
for (auto input : inputs) {
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
}
int retvals = outputs.size() - null_indices.size();
vector<AbstractTensorHandle*> fn_outputs(retvals);
TF_RETURN_IF_ERROR(fn_op->Execute(
absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
&retvals));
int skipped_indices = 0;
for (int i = 0; i < outputs.size(); i++) {
if (!null_indices.contains(i)) {
outputs[i] = fn_outputs[i - skipped_indices];
} else {
skipped_indices += 1;
}
}
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
return Status::OK();
} else {
return model(ctx, inputs, outputs, registry);
}
}
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
*ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_DeleteContextOptions(opts);
return Status::OK();
}
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,88 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace gradients {
// Get a scalar TensorHandle with given value
Status ScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor);
// Get a TensorHandle with given float values and dimensions
Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor);
// Get a TensorHandle with given int values and dimensions
Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[],
int num_dims, AbstractTensorHandle** tensor);
// Places data from `t` into *result_tensor.
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor);
// Util function that wraps an AbstractTensorHandle* with given data and dims.
AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
float vals[], int64_t dims[],
int num_dims);
// Util function that wraps an AbstractTensorHandle* with given data and dims.
AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
int64_t dims[], int num_dims);
// Util function that wraps an AbstractTensorHandle* with given data.
AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
float val);
// Performs gradient update for each weight using given learning rate.
Status UpdateWeights(AbstractContext* ctx,
std::vector<AbstractTensorHandle*>& grads,
std::vector<AbstractTensorHandle*>& weights,
AbstractTensorHandle* learning_rate);
using Model = std::function<Status(
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
// Runs given model in either graph or eager mode depending on value of
// use_function.
Status RunModel(Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
const GradientRegistry& registry);
// Builds context and returns inside *ctx.
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
} // namespace gradients
} // namespace tensorflow

View File

@ -57,15 +57,10 @@ class ImmediateExecutionContext : public AbstractContext {
// Create a tensor instance from the given data buffer and description.
// `memory_releaser` will be called on destruction, and it's responsible for
// cleaning up the underlying buffer. `convert_string` indicates whether it
// has to handle tstring conversion. Expected to be removed once tstring
// migration is done.
virtual AbstractTensorInterface* CreateTensor(DataType dtype,
const int64_t* dims,
int num_dims, void* data,
size_t len, bool convert_string,
MemoryReleaser memory_releaser,
void* memory_releaser_arg) = 0;
// cleaning up the underlying buffer.
virtual AbstractTensorInterface* CreateTensor(
DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len,
MemoryReleaser memory_releaser, void* memory_releaser_arg) = 0;
// Create a handle to wrap and manage a Tensor
virtual ImmediateExecutionTensorHandle* CreateLocalHandle(

View File

@ -47,9 +47,6 @@ class ImmediateExecutionOperation : public AbstractOperation {
virtual Status InputLength(const char* input_name, int* length) = 0;
virtual Status OutputLength(const char* output_name, int* length) = 0;
// Experimental
virtual Status SetUseXla(bool enable) = 0;
// Set stack trace to be used for potential async error reporting.
virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0;

View File

@ -0,0 +1,723 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
using tensorflow::TF_StatusPtr;
class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
void SetUp() override {
TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
}
};
Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
TF_RETURN_IF_ERROR(
registry->Register("SparseSoftmaxCrossEntropyWithLogits",
SparseSoftmaxCrossEntropyWithLogitsRegisterer));
return Status::OK();
}
TEST_P(CppGradients, TestMatMulGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t A_dims[] = {2, 2};
float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f};
int64_t B_dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr A =
GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
AbstractTensorHandlePtr B =
GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
/* Pseudo-code:
*
* tape.watch(A)
* tape.watch(B)
* Y = AB
* outputs = tape.gradient(Y, [A, B])
*/
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(MatMulGradModel, ctx.get(), {A.get(), B.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* dA_tensor;
s = GetValue(outputs[0], &dA_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(dA_tensor),
TF_TensorByteSize(dA_tensor));
float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f};
float tolerance = 1e-3;
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], expected_dA[j], tolerance);
}
TF_Tensor* dB_tensor;
s = GetValue(outputs[1], &dB_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
memcpy(&result_data[0], TF_TensorData(dB_tensor),
TF_TensorByteSize(dB_tensor));
float expected_dB[4] = {4.0f, 4.0f, 6.0f, 6.0f};
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], expected_dB[j], tolerance);
}
outputs[0]->Unref();
outputs[1]->Unref();
TF_DeleteTensor(dA_tensor);
TF_DeleteTensor(dB_tensor);
}
TEST_P(CppGradients, TestMNISTForward) {
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = data
float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, dims, num_dims);
// W1 = first weights
float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
AbstractTensorHandlePtr W1 =
GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
// W2 = second weights
float W2_vals[] = {.1f, .2f, .3f, -.5f};
AbstractTensorHandlePtr W2 =
GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);
// y = labels
int y_vals[] = {1, 1};
int64_t dims_y[] = {2};
num_dims = sizeof(dims_y) / sizeof(dims_y[0]);
AbstractTensorHandlePtr y =
GetTensorHandleUtilInt(ctx.get(), y_vals, dims, num_dims);
GradientRegistry registry;
// Run the Forward Pass
std::vector<AbstractTensorHandle*> outputs(2);
Status s =
RunModel(MNISTForwardModel, ctx.get(),
{X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Verify the Results
TF_Tensor* scores_tensor;
s = GetValue(outputs[0], &scores_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(scores_tensor),
TF_TensorByteSize(scores_tensor));
float expected_scores[4] = {3.6f, -6.0f, 10.2f, -17.0f};
float tolerance = 1e-3;
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
}
TF_Tensor* loss_vals_tensor;
s = GetValue(outputs[1], &loss_vals_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
memcpy(&result_data[0], TF_TensorData(loss_vals_tensor),
TF_TensorByteSize(loss_vals_tensor));
float expected_losses[2] = {9.6f, 27.2f};
for (int j = 0; j < 2; j++) {
ASSERT_NEAR(result_data[j], expected_losses[j], tolerance);
}
outputs[0]->Unref();
outputs[1]->Unref();
TF_DeleteTensor(scores_tensor);
TF_DeleteTensor(loss_vals_tensor);
}
TEST_P(CppGradients, TestMNISTForward2) {
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = data
float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
int64_t X_dims[] = {3, 2};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
// W1 = first weights
float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
int64_t dims[] = {2, 2};
AbstractTensorHandlePtr W1 =
GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
// W2 = second weights
float W2_vals[] = {.1f, .2f, .3f, -.5f};
AbstractTensorHandlePtr W2 =
GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);
// y = labels
int y_vals[] = {1, 1, 1};
int64_t y_dims[] = {3};
num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
AbstractTensorHandlePtr y =
GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
GradientRegistry registry;
// Run the Forward Pass
std::vector<AbstractTensorHandle*> outputs(2);
Status s =
RunModel(MNISTForwardModel, ctx.get(),
{X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Verify the Results
TF_Tensor* scores_tensor;
s = GetValue(outputs[0], &scores_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[6] = {0};
memcpy(&result_data[0], TF_TensorData(scores_tensor),
TF_TensorByteSize(scores_tensor));
float expected_scores[6] = {3.6f, -6.0f, 10.2f, -17.0f, 16.8f, -28.0f};
float tolerance = 1e-3;
for (int j = 0; j < 6; j++) {
ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
}
TF_Tensor* loss_vals_tensor;
s = GetValue(outputs[1], &loss_vals_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
memcpy(&result_data[0], TF_TensorData(loss_vals_tensor),
TF_TensorByteSize(loss_vals_tensor));
float expected_losses[3] = {9.6f, 27.2f, 44.8f};
for (int j = 0; j < 3; j++) {
ASSERT_NEAR(result_data[j], expected_losses[j], tolerance);
}
outputs[0]->Unref();
outputs[1]->Unref();
TF_DeleteTensor(scores_tensor);
TF_DeleteTensor(loss_vals_tensor);
}
TEST_P(CppGradients, TestMatMulTranspose) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = data
float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
int64_t X_dims[] = {2, 3};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
// W1 = first weights
float W1_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t dims[] = {2, 2};
AbstractTensorHandlePtr W1 =
GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
GradientRegistry registry;
// Run the MatMul Op
std::vector<AbstractTensorHandle*> outputs(1);
Status s = RunModel(MatMulTransposeModel, ctx.get(), {X.get(), W1.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Verify the Results
TF_Tensor* scores_tensor;
s = GetValue(outputs[0], &scores_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[6] = {0};
memcpy(&result_data[0], TF_TensorData(scores_tensor),
TF_TensorByteSize(scores_tensor));
float expected_scores[6] = {13.0f, 18.0f, 17.0f, 24.0f, 21.0f, 30.0f};
float tolerance = 1e-3;
for (int j = 0; j < 6; j++) {
ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
}
}
TEST_P(CppGradients, TestReluGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = data
float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f};
int64_t X_dims[] = {3, 3};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
/* Pseudo-code:
*
* tape.watch(X)
* Y = Relu(X)
* outputs = tape.gradient(Y, [X])
*/
std::vector<AbstractTensorHandle*> outputs(1);
s = RunModel(ReluGradModel, ctx.get(), {X.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* dX_tensor;
s = GetValue(outputs[0], &dX_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[9] = {0};
memcpy(&result_data[0], TF_TensorData(dX_tensor),
TF_TensorByteSize(dX_tensor));
float expected_dX[9] = {1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f};
float tolerance = 1e-3;
for (int j = 0; j < 9; j++) {
ASSERT_NEAR(result_data[j], expected_dX[j], tolerance);
}
outputs[0]->Unref();
TF_DeleteTensor(dX_tensor);
}
TEST_P(CppGradients, TestSoftmaxLossGrad) {
bool use_function = !std::get<2>(GetParam());
if (use_function) {
// TODO(b/168850692): Enable this.
GTEST_SKIP() << "Can't take gradient of "
"SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
}
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = scores
float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f};
int64_t X_dims[] = {3, 3};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
// y = labels
int y_vals[] = {1, 0, 1};
int64_t y_dims[] = {3};
num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
AbstractTensorHandlePtr y =
GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
/* Pseudo-code:
*
* tape.watch(X)
* tape.watch(labels)
* loss = SoftmaxLoss(X, labels)
* outputs = tape.gradient(loss, [X, labels])
*
*
*/
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(SoftmaxLossGradModel, ctx.get(), {X.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* dX_tensor;
s = GetValue(outputs[0], &dX_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[9] = {0};
memcpy(&result_data[0], TF_TensorData(dX_tensor),
TF_TensorByteSize(dX_tensor));
float expected_dX[9] = {0.090f, -0.7553f, 0.6652f, -0.9099f, 0.2447f,
0.6652f, 0.8437f, -0.8858f, 0.0420f};
float tolerance = 1e-3;
for (int j = 0; j < 9; j++) {
ASSERT_NEAR(result_data[j], expected_dX[j], tolerance);
}
// Only Unref() first output as 2nd is nullptr grad for labels
outputs[0]->Unref();
TF_DeleteTensor(dX_tensor);
}
TEST_P(CppGradients, TestMNISTGrad) {
bool use_function = !std::get<2>(GetParam());
if (use_function) {
// TODO(b/168850692): Enable this.
GTEST_SKIP() << "Can't take gradient of "
"SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
}
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = data
float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t X_dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
// W1 = first weights
float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
int64_t dims[] = {2, 2};
AbstractTensorHandlePtr W1 =
GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
// W2 = second weights
float W2_vals[] = {.1f, .2f, .3f, -.5f};
AbstractTensorHandlePtr W2 =
GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);
// y = labels
int y_vals[] = {1, 1};
int64_t y_dims[] = {2};
num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
AbstractTensorHandlePtr y =
GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
// Register Grads
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
/* Pseudo-code:
*
*
* tape.watch(W1)
* tape.watch(W2)
* mm = X*W1
* hidden = Relu(mm)
* scores = W2*hidden
* loss = SoftmaxLoss(scores, y)
* outputs = tape.gradient(loss, [A, B])
*
*/
std::vector<AbstractTensorHandle*> outputs(3);
s = RunModel(MNISTGradModel, ctx.get(),
{X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float tolerance = 1e-3;
TF_Tensor* dW1_tensor;
s = GetValue(outputs[0], &dW1_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(dW1_tensor),
TF_TensorByteSize(dW1_tensor));
float expected_dW1[4] = {0.0f, 3.2f, 0.0f, 4.8f};
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], expected_dW1[j], tolerance);
}
TF_Tensor* dW2_tensor;
s = GetValue(outputs[1], &dW2_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
memcpy(&result_data[0], TF_TensorData(dW2_tensor),
TF_TensorByteSize(dW2_tensor));
float expected_dW2[4] = {0.0f, 0.0f, 46.0f, -46.0f}; // dLoss
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], expected_dW2[j], tolerance);
}
outputs[0]->Unref();
outputs[1]->Unref();
outputs[2]->Unref();
TF_DeleteTensor(dW1_tensor);
TF_DeleteTensor(dW2_tensor);
}
TEST_P(CppGradients, TestScalarMul) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr eta;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = ScalarTensorHandle(ctx.get(), 1.5f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
eta.reset(x_raw);
}
float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t A_dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr A =
GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
GradientRegistry registry;
std::vector<AbstractTensorHandle*> outputs(1);
Status s = RunModel(ScalarMulModel, ctx.get(), {eta.get(), A.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* dA_tensor;
s = GetValue(outputs[0], &dA_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(dA_tensor),
TF_TensorByteSize(dA_tensor));
float tolerance = 1e-3;
float eta_val = 1.5f;
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], eta_val * A_vals[j], tolerance);
}
outputs[0]->Unref();
TF_DeleteTensor(dA_tensor);
}
TEST_P(CppGradients, TestMNIST_Training) {
bool use_function = !std::get<2>(GetParam());
if (use_function) {
// TODO(b/168850692): Enable this.
GTEST_SKIP() << "Can't take gradient of "
"SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
}
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = data
float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t X_dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
// TODO(amturati): use random initializer for weights instead of
// constant values.
// W1 = first weights
float W1_vals[] = {-.01f, 0.4f, 0.5f, -.2f};
int64_t dims[] = {2, 2};
AbstractTensorHandlePtr W1 =
GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
// W2 = second weights
float W2_vals[] = {.1f, .2f, .3f, -.5f};
AbstractTensorHandlePtr W2 =
GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);
// y = labels
int y_vals[] = {1, 1};
int64_t y_dims[] = {2};
num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
AbstractTensorHandlePtr y =
GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
// Register Grads
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Prepare for training
std::vector<AbstractTensorHandle*> weights;
weights.push_back(W1.get());
weights.push_back(W2.get());
// Set learning rate to be 1e-1
AbstractTensorHandle* learning_rate = nullptr;
s = ScalarTensorHandle(ctx.get(), 1e-1, &learning_rate);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Train
int num_iters = 10;
std::vector<AbstractTensorHandle*> mnist_outputs(3);
std::vector<AbstractTensorHandle*> grads(2);
for (int i = 0; i < num_iters; i++) {
// Run Forward Pass
s = RunModel(MNISTGradModel, ctx.get(),
{X.get(), weights[0], weights[1], y.get()},
absl::MakeSpan(mnist_outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Fill grads
grads[0] = mnist_outputs[0];
grads[1] = mnist_outputs[1];
// Gradient Update
s = UpdateWeights(ctx.get(), grads, weights, learning_rate);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}
grads[0]->Unref(); // release W1_grad
grads[1]->Unref(); // release W2_grad
mnist_outputs[2]->Unref(); // release loss
}
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#endif
} // namespace
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,518 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include <memory>
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
// ========================== Tape Ops ==============================
namespace tensorflow {
namespace gradients {
namespace internal {
using std::vector;
using tensorflow::tracing::TracingOperation;
// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractOperationPtr add_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(add_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
}
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
int num_retvals = 1;
return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
Status MatMul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
bool transpose_a, bool transpose_b,
const GradientRegistry& registry) {
AbstractOperationPtr matmul_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul",
/*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(matmul_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(matmul_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op));
TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
matmul_op.get(), "transpose_a", transpose_a, &forward_op));
TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
matmul_op.get(), "transpose_b", transpose_b, &forward_op));
int num_retvals = 1;
return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
Status Mul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry) {
AbstractOperationPtr mul_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(mul_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op));
int num_retvals = 1;
return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes `Relu(inputs[0])` and records it on the tape.
Status Relu(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry) {
AbstractOperationPtr relu_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(relu_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(relu_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op));
int num_retvals = 1;
return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes `SoftmaxLoss(scores, labels)` where labels are categorical (not
// one-hot) and records it on the tape.
Status SparseSoftmaxCrossEntropyWithLogits(
AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry) {
AbstractTensorHandle* scores = inputs[0];
AbstractTensorHandle* labels = inputs[1];
AbstractOperationPtr sm_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits",
/*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(sm_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(sm_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op));
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op));
int num_retvals = 2; // returns loss values and backprop
return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
//===================== Test Models to run =========================
// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status AddGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
std::vector<AbstractTensorHandle*> add_outputs(1);
TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
registry)); // Compute x+y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto add_output : add_outputs) {
add_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
// Computes
// y = inputs[0] * inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status MatMulGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
vector<AbstractTensorHandle*> mm_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs),
"matmul0", /*transpose_a=*/false,
/*transpose_b=*/false, registry)); // Compute x*y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(mm_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto mm_output : mm_outputs) {
mm_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
// Model to run 2-layer net
Status MNISTForwardModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
/**
* We will trace a 2-layer fully connected network for an MNIST model:
*
* def mnist_forward(X, W1, W2, y_labels):
* mm_out_1 = tf.matmul(X,W1)
* hidden_layer = tf.nn.relu(mm_out_1)
* scores = tf.matmul(hidden_layer,W2)
* softmax =
* tf.nn.sparse_softmax_cross_entropy_with_logits(scores,
* y_labels)
* return scores, softmax
*
* Use this convention for inputs:
*
* inputs = [X, W1, W2, y_labels]
*
*/
AbstractTensorHandle* X = inputs[0];
AbstractTensorHandle* W1 = inputs[1];
AbstractTensorHandle* W2 = inputs[2];
AbstractTensorHandle* y_labels = inputs[3];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(W1)); // Watch W1.
tape->Watch(ToId(W2)); // Watch W2.
vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
"matmul0", /*transpose_a=*/false,
/*transpose_b=*/false, registry)); // Compute X*W1
TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]},
absl::MakeSpan(temp_outputs), "relu",
registry)); // Compute Relu(X*W1)
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2},
absl::MakeSpan(temp_outputs), "matmul1",
/*transpose_a=*/false, /*transpose_b=*/false,
registry)); // Compute W2*Relu(X*W1)
AbstractTensorHandle* scores = temp_outputs[0];
temp_outputs.resize(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmax_loss", registry)); // Compute Softmax(Scores,labels)
AbstractTensorHandle* loss_vals = temp_outputs[0];
outputs[0] = scores;
outputs[1] = loss_vals;
delete tape;
return Status::OK();
}
Status MatMulTransposeModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractTensorHandle* X = inputs[0];
AbstractTensorHandle* W1 = inputs[1];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(X));
tape->Watch(ToId(W1));
vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
"matmul0", /*transpose_a=*/true,
/*transpose_b=*/false, registry)); // Compute X*W1
outputs[0] = temp_outputs[0];
delete tape;
return Status::OK();
}
Status ReluGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch X
vector<AbstractTensorHandle*> relu_outputs(1);
TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs),
"relu0", registry)); // Relu(X)
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(relu_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto relu_output : relu_outputs) {
relu_output->Unref();
}
outputs[0] = out_grads[0];
delete tape;
return Status::OK();
}
Status SoftmaxLossGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch scores.
tape->Watch(ToId(inputs[1])); // Watch labels.
vector<AbstractTensorHandle*> sm_outputs(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(sm_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
Status MNISTGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractTensorHandle* X = inputs[0];
AbstractTensorHandle* W1 = inputs[1];
AbstractTensorHandle* W2 = inputs[2];
AbstractTensorHandle* y_labels = inputs[3];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/true);
tape->Watch(ToId(X)); // Watch X.
tape->Watch(ToId(W1)); // Watch W1.
tape->Watch(ToId(W2)); // Watch W1.
vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
"matmul0", /*transpose_a=*/false,
/*transpose_b=*/false, registry)); // Compute X*W1
AbstractTensorHandle* mm = temp_outputs[0];
TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm},
absl::MakeSpan(temp_outputs), // Relu(X*W1)
"relu0", registry));
AbstractTensorHandle* hidden = temp_outputs[0];
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2},
absl::MakeSpan(temp_outputs), "matmul1",
/*transpose_a=*/false, /*transpose_b=*/false,
registry)); // W2*Relu(X*W1)
AbstractTensorHandle* scores = temp_outputs[0];
temp_outputs.resize(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmaxloss", registry)); // W2*Relu(X*W1)
AbstractTensorHandle* loss = temp_outputs[0];
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(
tape->ComputeGradient(vspace, /*target_tensor_ids=*/{ToId(loss)},
/*source_tensor_ids=*/{ToId(W1), ToId(W2)},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
// Only release 2nd temp output as first holds loss values.
temp_outputs[1]->Unref();
outputs[0] = out_grads[0]; // dW1
outputs[1] = out_grads[1]; // dW2
outputs[2] = loss;
delete tape;
return Status::OK();
}
Status ScalarMulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractTensorHandle* eta = inputs[0];
AbstractTensorHandle* A = inputs[1];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs),
"scalarMul0", registry)); // Compute eta*A
outputs[0] = temp_outputs[0];
delete tape;
return Status::OK();
}
Status MatMulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractTensorHandle* X = inputs[0];
AbstractTensorHandle* W1 = inputs[1];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
std::vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
"matmul0", /*transpose_a=*/false,
/*transpose_b=*/false, registry)); // Compute X*W1
outputs[0] = temp_outputs[0];
delete tape;
return Status::OK();
}
Status MulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractTensorHandle* x = inputs[0];
AbstractTensorHandle* y = inputs[1];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
std::vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(Mul(ctx, tape, {x, y}, absl::MakeSpan(temp_outputs),
"mul0", registry)); // Compute x*y
outputs[0] = temp_outputs[0];
delete tape;
return Status::OK();
}
Status SoftmaxModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractTensorHandle* x = inputs[0];
AbstractTensorHandle* labels = inputs[1];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
std::vector<AbstractTensorHandle*> temp_outputs(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
ctx, tape, {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss",
registry));
outputs[0] = temp_outputs[0]; // loss values
delete tape;
return Status::OK();
}
// ============================= End Models ================================
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,143 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
#define TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/status.h"
// ========================== Tape Ops ==============================
namespace tensorflow {
namespace gradients {
namespace internal {
// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
Status MatMul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
bool transpose_a, bool transpose_b,
const GradientRegistry& registry);
// Computes `inputs[0] * inputs[1]` and records it on the tape.
Status Mul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry);
// Computes `Relu(inputs[0])` and records it on the tape.
Status Relu(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry);
// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
// tape.
Status SparseSoftmaxCrossEntropyWithLogits(
AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry);
// ====================== End Tape Ops ============================
// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status AddGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Computes
// y = inputs[0] * inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status MatMulGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Computes 2-layer Neural Network with Softmax Loss.
Status MNISTForwardModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Computes MatMul with first matrix tranposed.
Status MatMulTransposeModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Test Model to verify ReluGrad functionality
Status ReluGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Test Model to verify SoftmaxGrad functionality
Status SoftmaxLossGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Test Model to verify Multi-grad functionality for MNIST
Status MNISTGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Test Model to verify scalar-tensor multiplication Op
Status ScalarMulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
Status MatMulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
Status MulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
Status SoftmaxModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
} // namespace internal
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_

View File

@ -76,10 +76,26 @@ cc_library(
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:lib",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
],
)
tf_cc_test(
name = "parallel_device_lib_test",
srcs = ["parallel_device_lib_test.cc"],
deps = [
":parallel_device_lib",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "parallel_device_testlib",
testonly = 1,
@ -87,7 +103,6 @@ cc_library(
hdrs = ["parallel_device_testlib.h"],
deps = [
":parallel_device",
":parallel_device_ops",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
@ -102,7 +117,6 @@ tf_cc_test(
srcs = ["parallel_device_test.cc"],
deps = [
":parallel_device",
":parallel_device_ops",
":parallel_device_testlib",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
@ -122,7 +136,6 @@ tf_cc_test(
args = ["--heap_check=local"],
deps = [
":parallel_device",
":parallel_device_ops",
":parallel_device_testlib",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
@ -134,19 +147,3 @@ tf_cc_test(
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
],
)
# Note: ParallelDevice-specific ops are experimental and not currently linked in
# to TensorFlow by default, just used in a few tests.
filegroup(
name = "parallel_device_ops_srcs",
srcs = ["parallel_device_ops.cc"],
visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
)
cc_library(
name = "parallel_device_ops",
srcs = [":parallel_device_ops_srcs"],
visibility = ["//tensorflow:internal"],
deps = ["//tensorflow/core:framework"],
alwayslink = 1,
)

View File

@ -136,13 +136,6 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
}
result.emplace(std::move(outputs));
return result;
} else if (operation_name == std::string("DeviceID")) {
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
result_content.push_back(parallel_device.DeviceIDs(context, status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
}
std::vector<ParallelTensor*> parallel_inputs;
std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
@ -255,28 +248,44 @@ TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs,
const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
void ParallelDeviceExecute(const TFE_Op* original_op, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* status,
void* device_info) {
const char* requested_placement = TFE_OpGetDevice(original_op, status);
if (*requested_placement == '\0') {
TF_SetStatus(
status, TF_INTERNAL,
"Ops must be placed on the parallel device explicitly, or their inputs "
"first un-packed. Got an un-placed op with an input placed on the "
"parallel device.");
return;
}
TFE_Context* context = TFE_OpGetContext(original_op, status);
if (TF_GetCode(status) != TF_OK) return;
const char* operation_name = TFE_OpGetName(original_op, status);
if (TF_GetCode(status) != TF_OK) return;
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
NamedParallelDevice* named_device =
reinterpret_cast<NamedParallelDevice*>(device_info);
std::vector<MaybeParallelTensorUnowned> typed_inputs;
int num_inputs = TFE_OpGetFlatInputCount(original_op, status);
if (TF_GetCode(status) != TF_OK) return;
typed_inputs.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, i, status);
if (TF_GetCode(status) != TF_OK) return;
const char* tensor_handle_device =
TFE_TensorHandleDeviceName(inputs[i], status);
TFE_TensorHandleDeviceName(input, status);
if (TF_GetCode(status) != TF_OK) return;
if (named_device->name() == tensor_handle_device) {
// We assume that any tensors already placed on this device are
// ParallelTensors.
typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
TFE_TensorHandleDevicePointer(inputs[i], status)));
TFE_TensorHandleDevicePointer(input, status)));
if (TF_GetCode(status) != TF_OK) return;
} else {
typed_inputs.emplace_back(inputs[i]);
typed_inputs.emplace_back(input);
}
}

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
@ -118,6 +119,9 @@ class DeviceThread {
int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_);
// Outputs
std::vector<TensorHandlePtr> op_outputs_ TF_GUARDED_BY(execution_mutex_);
// TF_Status is an incomplete type and so can't be stack allocated. To avoid
// unnecessary allocations each Execute call, we keep one heap-allocated
// version for the thread.
StatusPtr status_ TF_GUARDED_BY(execution_mutex_);
const std::string device_;
@ -188,6 +192,9 @@ std::vector<TensorHandlePtr> DeviceThread::Join(TF_Status* status) {
if (TF_GetCode(status_.get()) != TF_OK) {
TF_SetStatus(status, TF_GetCode(status_.get()),
TF_Message(status_.get()));
// Reset the member `status_` so future op executions (after recovery from
// the bad `status`) start with an OK status.
TF_SetStatus(status_.get(), TF_OK, "");
}
execution_state_ = ExecutionState::kIdle;
result = std::move(op_outputs_);
@ -255,18 +262,27 @@ std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
status);
}
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
TFE_Context* context, TF_Status* status) const {
std::unique_ptr<ParallelTensor> ParallelDevice::Vector(
TFE_Context* context, TF_Status* status,
absl::Span<const int32_t> values) const {
// TODO(allenl): We could cache DeviceIDs (keyed by context).
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
if (values.size() != num_underlying_devices()) {
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
"Number of values did not match number of underlying devices.");
return nullptr;
}
for (int device_index = 0; device_index < num_underlying_devices();
++device_index) {
int32_t* device_id = new int32_t;
*device_id = device_index;
int32_t* device_value = new int32_t;
*device_value = values[device_index];
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_id,
TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_value,
sizeof(int32_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int32_t*>(data);
@ -295,6 +311,16 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
status);
}
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
TFE_Context* context, TF_Status* status) const {
std::vector<int32_t> ids;
ids.reserve(num_underlying_devices());
for (int i = 0; i < num_underlying_devices(); ++i) {
ids.push_back(i);
}
return Vector(context, status, ids);
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(TFE_Context* context,
const std::vector<ParallelTensor*>& inputs,
@ -319,21 +345,36 @@ ParallelDevice::Execute(TFE_Context* context,
std::move(device_inputs), attributes,
expected_max_outputs);
}
StatusPtr first_bad_status(nullptr);
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
DeviceThread* device_thread = device_threads_[device_index].get();
per_device_output_tensors.push_back(device_thread->Join(status));
if (TF_GetCode(status) != TF_OK) return result;
// We will run every Join even if there are bad statuses in case the user
// wants to recover and continue running ops on the parallel device (which
// would otherwise deadlock).
if (TF_GetCode(status) != TF_OK && first_bad_status == nullptr) {
first_bad_status.reset(TF_NewStatus());
TF_SetStatus(first_bad_status.get(), TF_GetCode(status),
TF_Message(status));
}
if (device_index == 0) {
first_op_output_count = per_device_output_tensors.rbegin()->size();
} else {
if (per_device_output_tensors.rbegin()->size() != first_op_output_count) {
TF_SetStatus(status, TF_INTERNAL,
if (first_bad_status == nullptr &&
per_device_output_tensors.rbegin()->size() != first_op_output_count) {
first_bad_status.reset(TF_NewStatus());
TF_SetStatus(first_bad_status.get(), TF_INTERNAL,
"Parallel ops produced different numbers of tensors.");
return result;
}
}
}
if (first_bad_status != nullptr) {
TF_SetStatus(status, TF_GetCode(first_bad_status.get()),
TF_Message(first_bad_status.get()));
return result;
}
// For each output of the original operation, pack the per-device
// TensorHandles we've computed into a single parallel TensorHandle.
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
@ -61,6 +62,11 @@ class ParallelDevice {
TFE_TensorHandle* tensor,
TF_Status* status) const;
// Construct a parallel tensor consisting of the scalar values from `values`.
std::unique_ptr<ParallelTensor> Vector(
TFE_Context* context, TF_Status* status,
absl::Span<const int32_t> values) const;
// A parallel tensor with scalar integers numbering component devices.
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
TF_Status* status) const;

View File

@ -0,0 +1,84 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace parallel_device {
TEST(PARALLEL_DEVICE_LIB, TestOpWithError) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::vector<std::string> devices{
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1"};
ParallelDevice parallel_device(std::move(devices));
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> handle_op(
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT);
TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
auto outputs =
parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(),
"VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
/*expected_max_outputs=*/1, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs;
std::vector<ParallelTensor*> handle_inputs;
handle_inputs.reserve(handles.size());
for (auto& handle : handles) {
handle_inputs.push_back(handle.get());
}
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> read_op(
TFE_NewOp(context.get(), "ReadVariableOp", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(read_op.get(), "dtype", TF_FLOAT);
parallel_device.Execute(context.get(), handle_inputs, "ReadVariableOp",
TFE_OpGetAttrs(read_op.get()),
/*expected_max_outputs=*/1, status.get());
ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
TF_SetStatus(status.get(), TF_OK, "");
// Check that ops still run successfully on the device.
parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(),
"VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
/*expected_max_outputs=*/1, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
} // namespace parallel_device
} // namespace tensorflow

View File

@ -279,30 +279,4 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Compute the device ID twice and verify the result
for (int i = 0; i < 2; ++i) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetDevice(op.get(), device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, result_handle, &components, status.get());
TFE_DeleteTensorHandle(result_handle);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<int32_t>(components[0].get(), 0);
ExpectScalarEq<int32_t>(components[1].get(), 1);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
}

View File

@ -146,13 +146,16 @@ class GradientTape {
// once) and produces the gradient of the target tensors with respect to the
// source tensors. The output gradients are used if not empty and not
// null. The result is populated with one tensor per target element.
// When running backward functions, builds zeros-like tensors for
// incoming grads which are nullptrs, unless `build_default_zeros_grads`
// is set to false.
Status ComputeGradient(
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
const gtl::ArraySlice<int64> target_tensor_ids,
const gtl::ArraySlice<int64> source_tensor_ids,
const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result);
std::vector<Gradient*>* result, bool build_default_zeros_grads = true);
bool IsPersistent() const { return persistent_; }
@ -655,8 +658,8 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
const gtl::ArraySlice<int64> target_tensor_ids,
const gtl::ArraySlice<int64> source_tensor_ids,
const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) {
gtl::ArraySlice<Gradient*> output_gradients, std::vector<Gradient*>* result,
bool build_default_zeros_grads) {
std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
source_tensor_ids.end());
BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
@ -717,14 +720,14 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
const int64 id = trace.output_tensor_info[i].GetID();
auto grad_it = gradients.find(id);
if (grad_it == gradients.end()) {
auto func_name_it =
FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() &&
func_name_it->second.find(i) != func_name_it->second.end()) {
out_gradients.push_back(nullptr);
} else {
out_gradients.push_back(nullptr);
zero_indices.push_back(i);
out_gradients.push_back(nullptr);
if (build_default_zeros_grads) {
auto func_name_it =
FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
if (func_name_it == FunctionsAcceptingNoneForIndicesMap()->end() ||
func_name_it->second.find(i) == func_name_it->second.end()) {
zero_indices.push_back(i);
}
}
} else {
any_gradient_nonzero = true;
@ -745,6 +748,7 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
}
}
std::vector<Gradient*> in_gradients;
DCHECK(build_default_zeros_grads || zero_indices.empty());
if (any_gradient_nonzero) {
for (const auto i : zero_indices) {
out_gradients[i] = trace.output_tensor_info[i].ZerosLike();

View File

@ -35,8 +35,8 @@ using UniquePtrTo_TF_Status =
::std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;
Status ModularFileSystem::NewRandomAccessFile(
const std::string& fname,
std::unique_ptr<RandomAccessFile>* result /*, TransactionToken* token */) {
const std::string& fname, TransactionToken* token,
std::unique_ptr<RandomAccessFile>* result) {
if (ops_->new_random_access_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support NewRandomAccessFile()"));
@ -55,8 +55,8 @@ Status ModularFileSystem::NewRandomAccessFile(
}
Status ModularFileSystem::NewWritableFile(
const std::string& fname,
std::unique_ptr<WritableFile>* result /*, TransactionToken* token */) {
const std::string& fname, TransactionToken* token,
std::unique_ptr<WritableFile>* result) {
if (ops_->new_writable_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support NewWritableFile()"));
@ -75,8 +75,8 @@ Status ModularFileSystem::NewWritableFile(
}
Status ModularFileSystem::NewAppendableFile(
const std::string& fname,
std::unique_ptr<WritableFile>* result /*, TransactionToken* token */) {
const std::string& fname, TransactionToken* token,
std::unique_ptr<WritableFile>* result) {
if (ops_->new_appendable_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support NewAppendableFile()"));
@ -95,8 +95,8 @@ Status ModularFileSystem::NewAppendableFile(
}
Status ModularFileSystem::NewReadOnlyMemoryRegionFromFile(
const std::string& fname, std::unique_ptr<ReadOnlyMemoryRegion>*
result /*, TransactionToken* token */) {
const std::string& fname, TransactionToken* token,
std::unique_ptr<ReadOnlyMemoryRegion>* result) {
if (ops_->new_read_only_memory_region_from_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname,
@ -116,8 +116,8 @@ Status ModularFileSystem::NewReadOnlyMemoryRegionFromFile(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::FileExists(
const std::string& fname /*, TransactionToken* token */) {
Status ModularFileSystem::FileExists(const std::string& fname,
TransactionToken* token) {
if (ops_->path_exists == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support FileExists()"));
@ -129,9 +129,9 @@ Status ModularFileSystem::FileExists(
return StatusFromTF_Status(plugin_status.get());
}
bool ModularFileSystem::FilesExist(
const std::vector<std::string>& files,
std::vector<Status>* status /*, TransactionToken* token */) {
bool ModularFileSystem::FilesExist(const std::vector<std::string>& files,
TransactionToken* token,
std::vector<Status>* status) {
if (ops_->paths_exist == nullptr)
return FileSystem::FilesExist(files, status);
@ -162,9 +162,9 @@ bool ModularFileSystem::FilesExist(
return result;
}
Status ModularFileSystem::GetChildren(
const std::string& dir,
std::vector<std::string>* result /*, TransactionToken* token */) {
Status ModularFileSystem::GetChildren(const std::string& dir,
TransactionToken* token,
std::vector<std::string>* result) {
if (ops_->get_children == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", dir, " does not support GetChildren()"));
@ -188,9 +188,9 @@ Status ModularFileSystem::GetChildren(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::GetMatchingPaths(
const std::string& pattern,
std::vector<std::string>* result /*, TransactionToken* token */) {
Status ModularFileSystem::GetMatchingPaths(const std::string& pattern,
TransactionToken* token,
std::vector<std::string>* result) {
if (ops_->get_matching_paths == nullptr)
return internal::GetMatchingPaths(this, Env::Default(), pattern, result);
@ -211,8 +211,8 @@ Status ModularFileSystem::GetMatchingPaths(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::DeleteFile(
const std::string& fname /*, TransactionToken* token */) {
Status ModularFileSystem::DeleteFile(const std::string& fname,
TransactionToken* token) {
if (ops_->delete_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support DeleteFile()"));
@ -224,9 +224,10 @@ Status ModularFileSystem::DeleteFile(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::DeleteRecursively(
const std::string& dirname, int64* undeleted_files,
int64* undeleted_dirs /*, TransactionToken* token */) {
Status ModularFileSystem::DeleteRecursively(const std::string& dirname,
TransactionToken* token,
int64* undeleted_files,
int64* undeleted_dirs) {
if (undeleted_files == nullptr || undeleted_dirs == nullptr)
return errors::FailedPrecondition(
"DeleteRecursively must not be called with `undeleted_files` or "
@ -247,8 +248,8 @@ Status ModularFileSystem::DeleteRecursively(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::DeleteDir(
const std::string& dirname /*, TransactionToken* token */) {
Status ModularFileSystem::DeleteDir(const std::string& dirname,
TransactionToken* token) {
if (ops_->delete_dir == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", dirname, " does not support DeleteDir()"));
@ -260,8 +261,8 @@ Status ModularFileSystem::DeleteDir(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::RecursivelyCreateDir(
const std::string& dirname /*, TransactionToken* token */) {
Status ModularFileSystem::RecursivelyCreateDir(const std::string& dirname,
TransactionToken* token) {
if (ops_->recursively_create_dir == nullptr)
return FileSystem::RecursivelyCreateDir(dirname);
@ -272,8 +273,8 @@ Status ModularFileSystem::RecursivelyCreateDir(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::CreateDir(
const std::string& dirname /*, TransactionToken* token */) {
Status ModularFileSystem::CreateDir(const std::string& dirname,
TransactionToken* token) {
if (ops_->create_dir == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", dirname, " does not support CreateDir()"));
@ -285,9 +286,8 @@ Status ModularFileSystem::CreateDir(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::Stat(
const std::string& fname,
FileStatistics* stat /*, TransactionToken* token */) {
Status ModularFileSystem::Stat(const std::string& fname,
TransactionToken* token, FileStatistics* stat) {
if (ops_->stat == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support Stat()"));
@ -310,8 +310,8 @@ Status ModularFileSystem::Stat(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::IsDirectory(
const std::string& name /*, TransactionToken* token */) {
Status ModularFileSystem::IsDirectory(const std::string& name,
TransactionToken* token) {
if (ops_->is_directory == nullptr) return FileSystem::IsDirectory(name);
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
@ -321,9 +321,9 @@ Status ModularFileSystem::IsDirectory(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::GetFileSize(
const std::string& fname,
uint64* file_size /*, TransactionToken* token */) {
Status ModularFileSystem::GetFileSize(const std::string& fname,
TransactionToken* token,
uint64* file_size) {
if (ops_->get_file_size == nullptr) {
FileStatistics stat;
Status status = Stat(fname, &stat);
@ -342,9 +342,9 @@ Status ModularFileSystem::GetFileSize(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::RenameFile(
const std::string& src,
const std::string& target /*, TransactionToken* token */) {
Status ModularFileSystem::RenameFile(const std::string& src,
const std::string& target,
TransactionToken* token) {
if (ops_->rename_file == nullptr) {
Status status = CopyFile(src, target);
if (status.ok()) status = DeleteFile(src);
@ -359,9 +359,9 @@ Status ModularFileSystem::RenameFile(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::CopyFile(
const std::string& src,
const std::string& target /*, TransactionToken* token */) {
Status ModularFileSystem::CopyFile(const std::string& src,
const std::string& target,
TransactionToken* token) {
if (ops_->copy_file == nullptr) return FileSystem::CopyFile(src, target);
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
@ -372,8 +372,7 @@ Status ModularFileSystem::CopyFile(
return StatusFromTF_Status(plugin_status.get());
}
std::string ModularFileSystem::TranslateName(
const std::string& name /*, TransactionToken* token */) const {
std::string ModularFileSystem::TranslateName(const std::string& name) const {
if (ops_->translate_name == nullptr) return FileSystem::TranslateName(name);
char* p = ops_->translate_name(filesystem_.get(), name.c_str());
@ -385,7 +384,7 @@ std::string ModularFileSystem::TranslateName(
return ret;
}
void ModularFileSystem::FlushCaches(/*TransactionToken* token*/) {
void ModularFileSystem::FlushCaches(TransactionToken* token) {
if (ops_->flush_caches != nullptr) ops_->flush_caches(filesystem_.get());
}

View File

@ -59,71 +59,48 @@ class ModularFileSystem final : public FileSystem {
~ModularFileSystem() override { ops_->cleanup(filesystem_.get()); }
TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT;
Status NewRandomAccessFile(
const std::string& fname,
std::unique_ptr<RandomAccessFile>*
result /*, TransactionToken* token = nullptr */) override;
Status NewWritableFile(
const std::string& fname,
std::unique_ptr<WritableFile>*
result /*, TransactionToken* token = nullptr */) override;
Status NewAppendableFile(
const std::string& fname,
std::unique_ptr<WritableFile>*
result /*, TransactionToken* token = nullptr */) override;
const std::string& fname, TransactionToken* token,
std::unique_ptr<RandomAccessFile>* result) override;
Status NewWritableFile(const std::string& fname, TransactionToken* token,
std::unique_ptr<WritableFile>* result) override;
Status NewAppendableFile(const std::string& fname, TransactionToken* token,
std::unique_ptr<WritableFile>* result) override;
Status NewReadOnlyMemoryRegionFromFile(
const std::string& fname,
std::unique_ptr<ReadOnlyMemoryRegion>*
result /*, TransactionToken* token = nullptr */) override;
Status FileExists(
const std::string& fname /*, TransactionToken* token = nullptr */)
override;
const std::string& fname, TransactionToken* token,
std::unique_ptr<ReadOnlyMemoryRegion>* result) override;
Status FileExists(const std::string& fname, TransactionToken* token) override;
bool FilesExist(const std::vector<std::string>& files,
std::vector<Status>*
status /*, TransactionToken* token = nullptr */) override;
Status GetChildren(
const std::string& dir,
std::vector<std::string>* result /*, TransactionToken* token = nullptr */)
override;
Status GetMatchingPaths(
const std::string& pattern,
std::vector<std::string>*
results /*, TransactionToken* token = nullptr */) override;
Status DeleteFile(
const std::string& fname /*, TransactionToken* token = nullptr */)
override;
Status DeleteRecursively(
const std::string& dirname, int64* undeleted_files,
int64* undeleted_dirs /*, TransactionToken* token = nullptr */) override;
Status DeleteDir(
const std::string& dirname /*, TransactionToken* token = nullptr */)
override;
Status RecursivelyCreateDir(
const std::string& dirname /*, TransactionToken* token = nullptr */)
override;
Status CreateDir(
const std::string& dirname /*, TransactionToken* token = nullptr */)
override;
Status Stat(
const std::string& fname,
FileStatistics* stat /*, TransactionToken* token = nullptr */) override;
Status IsDirectory(
const std::string& fname /*, TransactionToken* token = nullptr */)
override;
Status GetFileSize(
const std::string& fname,
uint64* file_size /*, TransactionToken* token = nullptr */) override;
Status RenameFile(
const std::string& src,
const std::string& target /*, TransactionToken* token = nullptr */)
override;
Status CopyFile(const std::string& src,
const std::string&
target /*, TransactionToken* token = nullptr */) override;
std::string TranslateName(
const std::string& name /*, TransactionToken* token = nullptr */)
const override;
void FlushCaches(/* TransactionToken* token=nullptr */) override;
TransactionToken* token,
std::vector<Status>* status) override;
Status GetChildren(const std::string& dir, TransactionToken* token,
std::vector<std::string>* result) override;
Status GetMatchingPaths(const std::string& pattern, TransactionToken* token,
std::vector<std::string>* results) override;
Status DeleteFile(const std::string& fname, TransactionToken* token) override;
Status DeleteRecursively(const std::string& dirname, TransactionToken* token,
int64* undeleted_files,
int64* undeleted_dirs) override;
Status DeleteDir(const std::string& dirname,
TransactionToken* token) override;
Status RecursivelyCreateDir(const std::string& dirname,
TransactionToken* token) override;
Status CreateDir(const std::string& dirname,
TransactionToken* token) override;
Status Stat(const std::string& fname, TransactionToken* token,
FileStatistics* stat) override;
Status IsDirectory(const std::string& fname,
TransactionToken* token) override;
Status GetFileSize(const std::string& fname, TransactionToken* token,
uint64* file_size) override;
Status RenameFile(const std::string& src, const std::string& target,
TransactionToken* token) override;
Status CopyFile(const std::string& src, const std::string& target,
TransactionToken* token) override;
std::string TranslateName(const std::string& name) const override;
void FlushCaches(TransactionToken* token) override;
private:
std::unique_ptr<TF_Filesystem> filesystem_;

View File

@ -29,10 +29,12 @@ cc_library(
":gcs_helper",
":ram_file_block_cache",
"//tensorflow/c:env",
"//tensorflow/c:logging",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:variant",
],
)
@ -58,6 +60,7 @@ cc_library(
deps = [
":cleanup",
"//tensorflow/c:env",
"//tensorflow/c:logging",
"//tensorflow/c:tf_status",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization",

View File

@ -19,9 +19,11 @@ limitations under the License.
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/types/variant.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
#include "tensorflow/c/logging.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for GCS environments.
@ -119,20 +121,20 @@ static int64_t LoadBufferFromGCS(const std::string& path, size_t offset,
return -1;
}
int64_t read;
if (!absl::SimpleAtoi(stream.headers().find("content-length")->second,
&read)) {
auto content_length = stream.headers().find("content-length");
if (content_length == stream.headers().end()) {
// When we read a file with offset that is bigger than the actual file size.
// GCS will return an empty header (e.g no `content-length` header). In this
// case, we will set read to `0` and continue.
if (TF_GetCode(status) == TF_OUT_OF_RANGE) {
read = 0;
} else {
TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header");
return -1;
}
read = 0;
} else if (!absl::SimpleAtoi(content_length->second, &read)) {
TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header");
return -1;
}
// `TF_OUT_OF_RANGE` isn't considered as an error. So we clear it here.
TF_SetStatus(status, TF_OK, "");
TF_VLog(1, "Successful read of %s @ %u of size: %u", path.c_str(), offset,
read);
stream.read(buffer, read);
read = stream.gcount();
if (read < buffer_size) {
@ -145,6 +147,8 @@ static int64_t LoadBufferFromGCS(const std::string& path, size_t offset,
path, " @ ", offset)
.c_str());
}
TF_VLog(2, "Successful integrity check for: %s @ %u", path.c_str(),
offset);
}
}
return read;
@ -258,7 +262,8 @@ static void SyncImpl(const std::string& bucket, const std::string& object,
if (*offset == -1 || *offset == 0) {
// UploadFile will automatically switch to resumable upload based on Client
// configuration.
auto metadata = gcs_client->UploadFile(outfile->getName(), bucket, object);
auto metadata = gcs_client->UploadFile(outfile->getName(), bucket, object,
gcs::Fields("size"));
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
@ -277,15 +282,18 @@ static void SyncImpl(const std::string& bucket, const std::string& object,
} else {
std::string temporary_object =
gcs::CreateRandomPrefixName("tf_writable_file_gcs");
auto metadata =
gcs_client->UploadFile(outfile->getName(), bucket, temporary_object);
auto metadata = gcs_client->UploadFile(outfile->getName(), bucket,
temporary_object, gcs::Fields(""));
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
}
TF_VLog(3, "AppendObject: gs://%s/%s to gs://%s/%s", bucket.c_str(),
temporary_object.c_str(), bucket.c_str(), object.c_str());
const std::vector<gcs::ComposeSourceObject> source_objects = {
{object, {}, {}}, {temporary_object, {}, {}}};
metadata = gcs_client->ComposeObject(bucket, source_objects, object);
metadata = gcs_client->ComposeObject(bucket, source_objects, object,
gcs::Fields("size"));
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
@ -320,6 +328,8 @@ void Append(const TF_WritableFile* file, const char* buffer, size_t n,
"The internal temporary file is not writable.");
return;
}
TF_VLog(3, "Append: gs://%s/%s size %u", gcs_file->bucket.c_str(),
gcs_file->object.c_str(), n);
gcs_file->sync_need = true;
gcs_file->outfile.write(buffer, n);
if (!gcs_file->outfile)
@ -345,6 +355,8 @@ int64_t Tell(const TF_WritableFile* file, TF_Status* status) {
void Flush(const TF_WritableFile* file, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
if (gcs_file->sync_need) {
TF_VLog(3, "Flush started: gs://%s/%s", gcs_file->bucket.c_str(),
gcs_file->object.c_str());
if (!gcs_file->outfile) {
TF_SetStatus(status, TF_INTERNAL,
"Could not append to the internal temporary file.");
@ -352,6 +364,8 @@ void Flush(const TF_WritableFile* file, TF_Status* status) {
}
SyncImpl(gcs_file->bucket, gcs_file->object, &gcs_file->offset,
&gcs_file->outfile, gcs_file->gcs_client, status);
TF_VLog(3, "Flush finished: gs://%s/%s", gcs_file->bucket.c_str(),
gcs_file->object.c_str());
if (TF_GetCode(status) != TF_OK) return;
gcs_file->sync_need = false;
} else {
@ -360,11 +374,16 @@ void Flush(const TF_WritableFile* file, TF_Status* status) {
}
void Sync(const TF_WritableFile* file, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
TF_VLog(3, "Sync: gs://%s/%s", gcs_file->bucket.c_str(),
gcs_file->object.c_str());
Flush(file, status);
}
void Close(const TF_WritableFile* file, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
TF_VLog(3, "Close: gs://%s/%s", gcs_file->bucket.c_str(),
gcs_file->object.c_str());
if (gcs_file->sync_need) {
Flush(file, status);
}
@ -427,6 +446,8 @@ GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client)
if (absl::SimpleAtoi(std::getenv(kMaxStaleness), &value)) {
max_staleness = value;
}
TF_VLog(1, "GCS cache max size = %u ; block size = %u ; max staleness = %u",
max_bytes, block_size, max_staleness);
file_block_cache = std::make_unique<RamFileBlockCache>(
block_size, max_bytes, max_staleness,
@ -503,13 +524,18 @@ void Cleanup(TF_Filesystem* filesystem) {
static void UncachedStatForObject(const std::string& bucket,
const std::string& object, GcsFileStat* stat,
gcs::Client* gcs_client, TF_Status* status) {
auto metadata = gcs_client->GetObjectMetadata(bucket, object);
auto metadata = gcs_client->GetObjectMetadata(
bucket, object, gcs::Fields("generation,size,timeStorageClassUpdated"));
if (!metadata) return TF_SetStatusFromGCSStatus(metadata.status(), status);
stat->generation_number = metadata->generation();
stat->base.length = metadata->size();
stat->base.mtime_nsec =
metadata->time_storage_class_updated().time_since_epoch().count();
stat->base.is_directory = object.back() == '/';
TF_VLog(1,
"Stat of: gs://%s/%s -- length: %u generation: %u; mtime_nsec: %u;",
bucket.c_str(), object.c_str(), stat->base.length,
stat->generation_number, stat->base.mtime_nsec);
return TF_SetStatus(status, TF_OK, "");
}
@ -544,9 +570,10 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
if (TF_GetCode(status) != TF_OK) return -1;
if (!gcs_file->file_block_cache->ValidateAndUpdateFileSignature(
path, stat.generation_number)) {
std::cout
<< "File signature has been changed. Refreshing the cache. Path: "
<< path;
TF_VLog(
1,
"File signature has been changed. Refreshing the cache. Path: %s",
path.c_str());
}
read = gcs_file->file_block_cache->Read(path, offset, n, buffer, status);
} else {
@ -578,6 +605,7 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
(gcs_file->compose ? 0 : -1)});
// We are responsible for freeing the pointer returned by TF_GetTempFileName
free(temp_file_name);
TF_VLog(3, "GcsWritableFile: %s", path);
TF_SetStatus(status, TF_OK, "");
}
@ -607,7 +635,8 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
} else {
// If compose is true, we do not download anything.
// Instead we only check if this file exists on server or not.
auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object,
gcs::Fields("size"));
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) == TF_OK) {
file->plugin_file = new tf_writable_file::GCSFile(
@ -623,7 +652,8 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
return;
}
}
TF_VLog(3, "GcsWritableFile: %s with existing file %s", path,
temp_file_name.c_str());
TF_SetStatus(status, TF_OK, "");
}
@ -638,7 +668,8 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object,
gcs::Fields("size"));
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
@ -663,28 +694,190 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
}
}
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
static void StatForObject(GCSFile* gcs_file, const std::string& path,
const std::string& bucket, const std::string& object,
GcsFileStat* stat, TF_Status* status) {
if (object.empty())
return TF_SetStatus(
status, TF_INVALID_ARGUMENT,
absl::StrCat("'object' must be a non-empty string. (File: ", path, ")")
.c_str());
TF_SetStatus(status, TF_OK, "");
gcs_file->stat_cache->LookupOrCompute(
path, stat,
[gcs_file, bucket, object](const std::string& path, GcsFileStat* stat,
TF_Status* status) {
UncachedStatForObject(bucket, object, stat, &gcs_file->gcs_client,
status);
},
status);
}
static bool ObjectExists(GCSFile* gcs_file, const std::string& path,
const std::string& bucket, const std::string& object,
TF_Status* status) {
GcsFileStat stat;
StatForObject(gcs_file, path, bucket, object, &stat, status);
if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_NOT_FOUND)
return false;
if (TF_GetCode(status) == TF_NOT_FOUND) {
TF_SetStatus(status, TF_OK, "");
return false;
}
return !stat.base.is_directory;
}
static bool BucketExists(GCSFile* gcs_file, const std::string& bucket,
TF_Status* status) {
auto metadata =
gcs_file->gcs_client.GetBucketMetadata(bucket, gcs::Fields(""));
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_NOT_FOUND)
return false;
if (TF_GetCode(status) == TF_NOT_FOUND) {
TF_SetStatus(status, TF_OK, "");
return false;
}
return true;
}
static std::vector<std::string> GetChildrenBounded(
GCSFile* gcs_file, std::string dir, uint64_t max_results, bool recursive,
bool include_self_directory_marker, TF_Status* status) {
std::string bucket, prefix;
MaybeAppendSlash(&dir);
ParseGCSPath(dir, true, &bucket, &prefix, status);
std::vector<std::string> result;
uint64_t count = 0;
std::string delimiter = recursive ? "" : "/";
for (auto&& item : gcs_file->gcs_client.ListObjectsAndPrefixes(
bucket, gcs::Prefix(prefix), gcs::Delimiter(delimiter),
gcs::Fields("items(name),prefixes"))) {
if (count == max_results) {
TF_SetStatus(status, TF_OK, "");
return result;
}
if (!item) {
TF_SetStatusFromGCSStatus(item.status(), status);
return result;
}
auto value = *std::move(item);
std::string children = absl::holds_alternative<std::string>(value)
? absl::get<std::string>(value)
: absl::get<gcs::ObjectMetadata>(value).name();
auto pos = children.find(prefix);
if (pos != 0) {
TF_SetStatus(status, TF_INTERNAL,
absl::StrCat("Unexpected response: the returned file name ",
children, " doesn't match the prefix ", prefix)
.c_str());
return result;
}
children.erase(0, prefix.length());
if (!children.empty() || include_self_directory_marker) {
result.emplace_back(children);
}
++count;
}
return result;
}
static bool FolderExists(GCSFile* gcs_file, std::string dir,
TF_Status* status) {
ExpiringLRUCache<GcsFileStat>::ComputeFunc compute_func =
[gcs_file](const std::string& dir, GcsFileStat* stat, TF_Status* status) {
auto children =
GetChildrenBounded(gcs_file, dir, 1, true, true, status);
if (TF_GetCode(status) != TF_OK) return;
if (!children.empty()) {
stat->base = {0, 0, true};
return TF_SetStatus(status, TF_OK, "");
} else {
return TF_SetStatus(status, TF_INVALID_ARGUMENT, "Not a directory!");
}
};
GcsFileStat stat;
MaybeAppendSlash(&dir);
gcs_file->stat_cache->LookupOrCompute(dir, &stat, compute_func, status);
if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_INVALID_ARGUMENT)
return false;
if (TF_GetCode(status) == TF_INVALID_ARGUMENT) {
TF_SetStatus(status, TF_OK, "");
return false;
}
return true;
}
static void ClearFileCaches(GCSFile* gcs_file, const std::string& path) {
absl::ReaderMutexLock l(&gcs_file->block_cache_lock);
gcs_file->file_block_cache->RemoveFile(path);
gcs_file->stat_cache->Delete(path);
}
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
if (object.empty()) {
auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket);
TF_SetStatusFromGCSStatus(bucket_metadata.status(), status);
bool result = BucketExists(gcs_file, bucket, status);
if (result) return TF_SetStatus(status, TF_OK, "");
}
GcsFileStat stat;
StatForObject(gcs_file, path, bucket, object, &stat, status);
if (TF_GetCode(status) != TF_NOT_FOUND) return;
bool result = FolderExists(gcs_file, path, status);
if (TF_GetCode(status) != TF_OK || (TF_GetCode(status) == TF_OK && result))
return;
return TF_SetStatus(
status, TF_NOT_FOUND,
absl::StrCat("The path ", path, " does not exist.").c_str());
}
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string dir = path;
MaybeAppendSlash(&dir);
TF_VLog(3,
"CreateDir: creating directory with path: %s and "
"path_with_slash: %s",
path, dir.c_str());
std::string bucket, object;
ParseGCSPath(dir, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
if (object.empty()) {
bool is_directory = BucketExists(gcs_file, bucket, status);
if (TF_GetCode(status) != TF_OK) return;
if (!is_directory)
TF_SetStatus(status, TF_NOT_FOUND,
absl::StrCat("The specified bucket ", dir, " was not found.")
.c_str());
return;
}
MaybeAppendSlash(&object);
auto object_metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
TF_SetStatusFromGCSStatus(object_metadata.status(), status);
if (TF_GetCode(status) == TF_NOT_FOUND) {
auto insert_metadata =
gcs_file->gcs_client.InsertObject(bucket, object, "");
TF_SetStatusFromGCSStatus(insert_metadata.status(), status);
} else if (TF_GetCode(status) == TF_OK) {
TF_SetStatus(status, TF_ALREADY_EXISTS, path);
PathExists(filesystem, dir.c_str(), status);
if (TF_GetCode(status) == TF_OK) {
// Use the original name for a correct error here.
TF_VLog(3, "CreateDir: directory already exists, not uploading %s", path);
return TF_SetStatus(status, TF_ALREADY_EXISTS, path);
}
auto metadata = gcs_file->gcs_client.InsertObject(
bucket, object, "",
// Adding this parameter means HTTP_CODE_PRECONDITION_FAILED
// will be returned if the object already exists, so avoid reuploading.
gcs::IfGenerationMatch(0), gcs::Fields(""));
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) == TF_FAILED_PRECONDITION)
TF_SetStatus(status, TF_ALREADY_EXISTS, path);
}
// TODO(vnvo2409): `RecursivelyCreateDir` should use `CreateDir` instead of the
@ -700,79 +893,31 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path,
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket, object);
TF_SetStatusFromGCSStatus(gcs_status, status);
if (TF_GetCode(status) == TF_OK) ClearFileCaches(gcs_file, path);
}
// Checks that the directory is empty (i.e no objects with this prefix exist).
// Deletes the GCS directory marker if it exists.
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
MaybeAppendSlash(&object);
// A directory is considered empty either if there are no matching objects
// with the corresponding name prefix or if there is exactly one matching
// object and it is the directory marker. Therefore we need to retrieve
// at most two children for the prefix to detect if a directory is empty.
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
int object_count = 0;
for (auto&& metadata :
gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) {
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
}
++object_count;
// We consider a path is a non-empty directory in two cases:
// - There are more than two objects whose keys start with the name of this
// directory.
// - There is one object whose key contains the name of this directory ( but
// not equal ).
if (object_count > 1 || metadata->name() != object) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Cannot delete a non-empty directory.");
return;
}
}
auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket, object);
TF_SetStatusFromGCSStatus(gcs_status, status);
}
// TODO(vnvo2409): `DeleteRecursively` needs `GetChildrens` but there will be
// some differents compared to the default implementation. Will be refactored.
static void DeleteRecursively(const TF_Filesystem* filesystem, const char* path,
uint64_t* undeleted_files,
uint64_t* undeleted_dirs, TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
auto childrens = GetChildrenBounded(gcs_file, path, 2, true, true, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto gcs_status = gcs::DeleteByPrefix(gcs_file->gcs_client, bucket, object);
TF_SetStatusFromGCSStatus(gcs_status, status);
if (TF_GetCode(status) != TF_OK) return;
*undeleted_dirs = 0;
*undeleted_files = 0;
}
// TODO(vnvo2409): `RewriteObjectBlocking` will set `status` to `TF_NOT_FOUND`
// if the object does not exist. In that case, we will have to check if the
// `src` is a directory or not to set the correspondent `status` (i.e
// `TF_NOT_FOUND` if path `src` does not exist, `TF_FAILED_PRECONDITION` if
// path `src` is a directory).
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
std::string bucket_src, object_src;
ParseGCSPath(src, false, &bucket_src, &object_src, status);
if (TF_GetCode(status) != TF_OK) return;
std::string bucket_dst, object_dst;
ParseGCSPath(dst, false, &bucket_dst, &object_dst, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto metadata = gcs_file->gcs_client.RewriteObjectBlocking(
bucket_src, object_src, bucket_dst, object_dst);
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (childrens.size() > 1 || (childrens.size() == 1 && !childrens[0].empty()))
return TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Cannot delete a non-empty directory.");
if (childrens.size() == 1 && childrens[0].empty()) {
// This is the directory marker object. Delete it.
std::string dir = path;
MaybeAppendSlash(&dir);
DeleteFile(filesystem, dir.c_str(), status);
return;
}
auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket_src, object_src);
TF_SetStatusFromGCSStatus(gcs_status, status);
TF_SetStatus(status, TF_OK, "");
}
void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
@ -787,35 +932,11 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto metadata = gcs_file->gcs_client.RewriteObjectBlocking(
bucket_src, object_src, bucket_dst, object_dst);
bucket_src, object_src, bucket_dst, object_dst,
gcs::Fields("done,rewriteToken"));
TF_SetStatusFromGCSStatus(metadata.status(), status);
}
// TODO(vnvo2409): This approach can cause a problem when our path is
// `path/to/dir` and there is an object with key `path/to/directory`. Will be
// fixed when refactoring.
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
for (auto&& metadata :
gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) {
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
}
// We consider a path exists if there is at least one object whose key
// contains the path.
return TF_SetStatus(status, TF_OK, "");
}
return TF_SetStatus(
status, TF_NOT_FOUND,
absl::StrCat("The path ", path, " does not exist.").c_str());
}
bool IsDirectory(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string bucket, object;
@ -824,41 +945,133 @@ bool IsDirectory(const TF_Filesystem* filesystem, const char* path,
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
if (object.empty()) {
auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket);
TF_SetStatusFromGCSStatus(bucket_metadata.status(), status);
if (TF_GetCode(status) == TF_OK)
return true;
else
return false;
bool result = BucketExists(gcs_file, bucket, status);
if (TF_GetCode(status) != TF_OK) return false;
if (!result)
TF_SetStatus(
status, TF_NOT_FOUND,
absl::StrCat("The specified bucket gs://", bucket, " was not found.")
.c_str());
return result;
}
// We check if there is an object with this key on the GCS server.
auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
if (metadata) {
TF_SetStatus(status, TF_OK, "");
if (metadata->name().back() == '/')
return true;
else
return false;
}
bool is_folder = FolderExists(gcs_file, path, status);
if (TF_GetCode(status) != TF_OK) return false;
if (is_folder) return true;
// If there is no object with this key on the GCS server. We check if there is
// any object whose key contains that path.
MaybeAppendSlash(&object);
for (auto&& metadata :
gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) {
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return false;
}
TF_SetStatus(status, TF_OK, "");
return true;
bool is_object = ObjectExists(gcs_file, path, bucket, object, status);
if (TF_GetCode(status) != TF_OK) return false;
if (is_object) {
TF_SetStatus(
status, TF_FAILED_PRECONDITION,
absl::StrCat("The specified path ", path, " is not a directory.")
.c_str());
return false;
}
TF_SetStatus(status, TF_NOT_FOUND,
absl::StrCat("The path ", path, " does not exist.").c_str());
return false;
}
static void RenameObject(const TF_Filesystem* filesystem,
const std::string& src, const std::string& dst,
TF_Status* status) {
TF_VLog(3, "RenameObject: started %s to %s", src.c_str(), dst.c_str());
std::string bucket_src, object_src;
ParseGCSPath(src, false, &bucket_src, &object_src, status);
if (TF_GetCode(status) != TF_OK) return;
std::string bucket_dst, object_dst;
ParseGCSPath(dst, false, &bucket_dst, &object_dst, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto metadata = gcs_file->gcs_client.RewriteObjectBlocking(
bucket_src, object_src, bucket_dst, object_dst,
gcs::Fields("done,rewriteToken"));
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) != TF_OK) return;
TF_VLog(3, "RenameObject: finished %s to %s", src.c_str(), dst.c_str());
ClearFileCaches(gcs_file, dst);
DeleteFile(filesystem, src.c_str(), status);
}
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
if (!IsDirectory(filesystem, src, status)) {
if (TF_GetCode(status) == TF_FAILED_PRECONDITION) {
TF_SetStatus(status, TF_OK, "");
RenameObject(filesystem, src, dst, status);
}
return;
}
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
std::vector<std::string> childrens =
GetChildrenBounded(gcs_file, src, UINT64_MAX, true, true, status);
if (TF_GetCode(status) != TF_OK) return;
std::string src_dir = src;
std::string dst_dir = dst;
MaybeAppendSlash(&src_dir);
MaybeAppendSlash(&dst_dir);
for (const std::string& children : childrens) {
RenameObject(filesystem, src_dir + children, dst_dir + children, status);
if (TF_GetCode(status) != TF_OK) return;
}
TF_SetStatus(status, TF_OK, "");
}
void DeleteRecursively(const TF_Filesystem* filesystem, const char* path,
uint64_t* undeleted_files, uint64_t* undeleted_dirs,
TF_Status* status) {
if (!undeleted_files || !undeleted_dirs)
return TF_SetStatus(
status, TF_INTERNAL,
"'undeleted_files' and 'undeleted_dirs' cannot be nullptr.");
*undeleted_files = 0;
*undeleted_dirs = 0;
if (!IsDirectory(filesystem, path, status)) {
*undeleted_dirs = 1;
return;
}
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
std::vector<std::string> childrens =
GetChildrenBounded(gcs_file, path, UINT64_MAX, true, true, status);
if (TF_GetCode(status) != TF_OK) return;
std::string dir = path;
MaybeAppendSlash(&dir);
for (const std::string& children : childrens) {
const std::string& full_path = dir + children;
DeleteFile(filesystem, full_path.c_str(), status);
if (TF_GetCode(status) != TF_OK) {
if (IsDirectory(filesystem, full_path.c_str(), status))
// The object is a directory marker.
(*undeleted_dirs)++;
else
(*undeleted_files)++;
}
}
}
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
std::vector<std::string> childrens =
GetChildrenBounded(gcs_file, path, UINT64_MAX, false, false, status);
if (TF_GetCode(status) != TF_OK) return -1;
int num_entries = childrens.size();
*entries = static_cast<char**>(
plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
for (int i = 0; i < num_entries; i++)
(*entries)[i] = strdup(childrens[i].c_str());
TF_SetStatus(status, TF_OK, "");
return num_entries;
}
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status) {
std::string bucket, object;
@ -867,7 +1080,8 @@ void Stat(const TF_Filesystem* filesystem, const char* path,
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
if (object.empty()) {
auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket);
auto bucket_metadata =
gcs_file->gcs_client.GetBucketMetadata(bucket, gcs::Fields(""));
TF_SetStatusFromGCSStatus(bucket_metadata.status(), status);
if (TF_GetCode(status) == TF_OK) {
stats->is_directory = true;
@ -882,8 +1096,9 @@ void Stat(const TF_Filesystem* filesystem, const char* path,
stats->mtime_nsec = 0;
return TF_SetStatus(status, TF_OK, "");
}
if (TF_GetCode(status) == TF_OK) {
auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
if (TF_GetCode(status) == TF_FAILED_PRECONDITION) {
auto metadata = gcs_file->gcs_client.GetObjectMetadata(
bucket, object, gcs::Fields("size,timeStorageClassUpdated"));
if (metadata) {
stats->is_directory = false;
stats->length = metadata.value().size();
@ -896,6 +1111,29 @@ void Stat(const TF_Filesystem* filesystem, const char* path,
}
}
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
// Only validate the name.
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return -1;
TF_FileStatistics stat;
Stat(filesystem, path, &stat, status);
return stat.length;
}
static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) {
return strdup(uri);
}
static void FlushCaches(const TF_Filesystem* filesystem) {
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
absl::ReaderMutexLock l(&gcs_file->block_cache_lock);
gcs_file->file_block_cache->Flush();
gcs_file->stat_cache->Clear();
}
} // namespace tf_gcs_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
@ -912,6 +1150,13 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->read_only_memory_region_ops = static_cast<TF_ReadOnlyMemoryRegionOps*>(
plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
ops->read_only_memory_region_ops->cleanup =
tf_read_only_memory_region::Cleanup;
ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data;
ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_gcs_filesystem::Init;
@ -921,6 +1166,20 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file =
tf_gcs_filesystem::NewAppendableFile;
ops->filesystem_ops->new_read_only_memory_region_from_file =
tf_gcs_filesystem::NewReadOnlyMemoryRegionFromFile;
ops->filesystem_ops->create_dir = tf_gcs_filesystem::CreateDir;
ops->filesystem_ops->delete_file = tf_gcs_filesystem::DeleteFile;
ops->filesystem_ops->delete_dir = tf_gcs_filesystem::DeleteDir;
ops->filesystem_ops->delete_recursively =
tf_gcs_filesystem::DeleteRecursively;
ops->filesystem_ops->copy_file = tf_gcs_filesystem::CopyFile;
ops->filesystem_ops->path_exists = tf_gcs_filesystem::PathExists;
ops->filesystem_ops->is_directory = tf_gcs_filesystem::IsDirectory;
ops->filesystem_ops->stat = tf_gcs_filesystem::Stat;
ops->filesystem_ops->get_children = tf_gcs_filesystem::GetChildren;
ops->filesystem_ops->translate_name = tf_gcs_filesystem::TranslateName;
ops->filesystem_ops->flush_caches = tf_gcs_filesystem::FlushCaches;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {

View File

@ -87,6 +87,24 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
const char* path,
TF_ReadOnlyMemoryRegion* region,
TF_Status* status);
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status);
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status);
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
TF_Status* status);
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status);
} // namespace tf_gcs_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)
#define EXPECT_TF_OK(x) EXPECT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)
static const char* content = "abcdefghijklmnopqrstuvwxyz1234567890";
// We will work with content_view instead of content.
@ -94,6 +95,70 @@ class GCSFilesystemTest : public ::testing::Test {
return translated_name;
}
std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile* file)>
GetWriter() {
std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile * file)> writer(
new TF_WritableFile, [](TF_WritableFile* file) {
if (file != nullptr) {
if (file->plugin_file != nullptr) tf_writable_file::Cleanup(file);
delete file;
}
});
writer->plugin_file = nullptr;
return writer;
}
std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile* file)>
GetReader() {
std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile * file)>
reader(new TF_RandomAccessFile, [](TF_RandomAccessFile* file) {
if (file != nullptr) {
if (file->plugin_file != nullptr)
tf_random_access_file::Cleanup(file);
delete file;
}
});
reader->plugin_file = nullptr;
return reader;
}
void WriteString(const std::string& path, const std::string& content) {
auto writer = GetWriter();
tf_gcs_filesystem::NewWritableFile(filesystem_, path.c_str(), writer.get(),
status_);
if (TF_GetCode(status_) != TF_OK) return;
tf_writable_file::Append(writer.get(), content.c_str(), content.length(),
status_);
if (TF_GetCode(status_) != TF_OK) return;
tf_writable_file::Close(writer.get(), status_);
if (TF_GetCode(status_) != TF_OK) return;
}
std::string ReadAll(const std::string& path) {
auto reader = GetReader();
tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
reader.get(), status_);
if (TF_GetCode(status_) != TF_OK) return "";
auto file_size =
tf_gcs_filesystem::GetFileSize(filesystem_, path.c_str(), status_);
if (TF_GetCode(status_) != TF_OK) return "";
std::string content;
content.resize(file_size);
auto read = tf_random_access_file::Read(reader.get(), 0, file_size,
&content[0], status_);
if (TF_GetCode(status_) != TF_OK) return "";
if (read >= 0) content.resize(read);
if (file_size != content.size())
TF_SetStatus(
status_, TF_DATA_LOSS,
std::string("expected " + std::to_string(file_size) + " got " +
std::to_string(content.size()) + " bytes")
.c_str());
return content;
}
protected:
TF_Filesystem* filesystem_;
TF_Status* status_;
@ -326,6 +391,145 @@ TEST_F(GCSFilesystemTest, ReadOnlyMemoryRegion) {
delete region;
}
TEST_F(GCSFilesystemTest, PathExists) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_);
const std::string path = GetURIForPath("PathExists");
tf_gcs_filesystem::PathExists(filesystem_, path.c_str(), status_);
EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status_)) << TF_Message(status_);
TF_SetStatus(status_, TF_OK, "");
WriteString(path, "test");
ASSERT_TF_OK(status_);
tf_gcs_filesystem::PathExists(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
}
TEST_F(GCSFilesystemTest, GetChildren) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_);
const std::string base = GetURIForPath("GetChildren");
tf_gcs_filesystem::CreateDir(filesystem_, base.c_str(), status_);
EXPECT_TF_OK(status_);
const std::string file = io::JoinPath(base, "TestFile.csv");
WriteString(file, "test");
EXPECT_TF_OK(status_);
const std::string subdir = io::JoinPath(base, "SubDir");
tf_gcs_filesystem::CreateDir(filesystem_, subdir.c_str(), status_);
EXPECT_TF_OK(status_);
const std::string subfile = io::JoinPath(subdir, "TestSubFile.csv");
WriteString(subfile, "test");
EXPECT_TF_OK(status_);
char** entries;
auto num_entries = tf_gcs_filesystem::GetChildren(filesystem_, base.c_str(),
&entries, status_);
EXPECT_TF_OK(status_);
std::vector<std::string> childrens;
for (int i = 0; i < num_entries; ++i) {
childrens.push_back(entries[i]);
}
std::sort(childrens.begin(), childrens.end());
EXPECT_EQ(std::vector<string>({"SubDir/", "TestFile.csv"}), childrens);
}
TEST_F(GCSFilesystemTest, DeleteFile) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_);
const std::string path = GetURIForPath("DeleteFile");
WriteString(path, "test");
ASSERT_TF_OK(status_);
tf_gcs_filesystem::DeleteFile(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
tf_gcs_filesystem::PathExists(filesystem_, path.c_str(), status_);
EXPECT_EQ(TF_GetCode(status_), TF_NOT_FOUND);
}
TEST_F(GCSFilesystemTest, CreateDir) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_);
const std::string dir = GetURIForPath("CreateDir");
tf_gcs_filesystem::CreateDir(filesystem_, dir.c_str(), status_);
EXPECT_TF_OK(status_);
TF_FileStatistics stat;
tf_gcs_filesystem::Stat(filesystem_, dir.c_str(), &stat, status_);
EXPECT_TF_OK(status_);
EXPECT_TRUE(stat.is_directory);
}
TEST_F(GCSFilesystemTest, DeleteDir) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_);
const std::string dir = GetURIForPath("DeleteDir");
const std::string file = io::JoinPath(dir, "DeleteDirFile.csv");
WriteString(file, "test");
ASSERT_TF_OK(status_);
tf_gcs_filesystem::DeleteDir(filesystem_, dir.c_str(), status_);
EXPECT_EQ(TF_GetCode(status_), TF_FAILED_PRECONDITION);
TF_SetStatus(status_, TF_OK, "");
tf_gcs_filesystem::DeleteFile(filesystem_, file.c_str(), status_);
EXPECT_TF_OK(status_);
tf_gcs_filesystem::DeleteDir(filesystem_, dir.c_str(), status_);
EXPECT_TF_OK(status_);
TF_FileStatistics stat;
tf_gcs_filesystem::Stat(filesystem_, dir.c_str(), &stat, status_);
EXPECT_EQ(TF_GetCode(status_), TF_NOT_FOUND) << TF_Message(status_);
}
TEST_F(GCSFilesystemTest, StatFile) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_);
const std::string path = GetURIForPath("StatFile");
WriteString(path, "test");
ASSERT_TF_OK(status_);
TF_FileStatistics stat;
tf_gcs_filesystem::Stat(filesystem_, path.c_str(), &stat, status_);
EXPECT_TF_OK(status_);
EXPECT_EQ(4, stat.length);
EXPECT_FALSE(stat.is_directory);
}
TEST_F(GCSFilesystemTest, RenameFile) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_);
const std::string src = GetURIForPath("RenameFileSrc");
const std::string dst = GetURIForPath("RenameFileDst");
WriteString(src, "test");
ASSERT_TF_OK(status_);
tf_gcs_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), status_);
EXPECT_TF_OK(status_);
auto result = ReadAll(dst);
EXPECT_TF_OK(status_);
EXPECT_EQ("test", result);
}
TEST_F(GCSFilesystemTest, RenameFileOverwrite) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_);
const std::string src = GetURIForPath("RenameFileOverwriteSrc");
const std::string dst = GetURIForPath("RenameFileOverwriteDst");
WriteString(src, "test_old");
ASSERT_TF_OK(status_);
WriteString(dst, "test_new");
ASSERT_TF_OK(status_);
tf_gcs_filesystem::PathExists(filesystem_, dst.c_str(), status_);
EXPECT_TF_OK(status_);
tf_gcs_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), status_);
EXPECT_TF_OK(status_);
auto result = ReadAll(dst);
EXPECT_TF_OK(status_);
EXPECT_EQ("test_old", result);
}
// These tests below are ported from
// `//tensorflow/core/platform/cloud:gcs_file_system_test`
TEST_F(GCSFilesystemTest, NewRandomAccessFile_NoBlockCache) {

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/logging.h"
#include "tensorflow/c/tf_status.h"
namespace tf_gcs_filesystem {
@ -65,8 +66,8 @@ class RamFileBlockCache {
pruning_thread_.reset(
TF_StartThread(&thread_options, "TF_prune_FBC", PruneThread, this));
}
std::cout << "GCS file block cache is "
<< (IsCacheEnabled() ? "enabled" : "disabled") << ".\n";
TF_VLog(1, "GCS file block cache is %s.\n",
(IsCacheEnabled() ? "enabled" : "disabled"));
}
~RamFileBlockCache() {

View File

@ -1,5 +1,5 @@
# Experimental hadoop filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test")
package(
licenses = ["notice"], # Apache 2.0
@ -33,3 +33,38 @@ cc_library(
"@com_google_absl//absl/synchronization",
],
)
# This test is set to manual because it requires downloading the Hadoop
# distribution to run. To run this test:
# 1. Ensure $JAVA_HOME is set to the location of a JDK 8 installation.
# 2. Download the binary Hadoop distribution from:
# http://hadoop.apache.org/releases.html
# 3. Extract the Hadoop distribution and run:
# source libexec/hadoop-config.sh
# 4. Optionally set up HDFS cluster configurations (optionally Kerberos) within
# $HADOOP_HDFS_HOME/etc/hadoop if you want to test against real
# distributed HDFS cluster
# 5. bazel test \
# --test_env=LD_LIBRARY_PATH=$JAVA_HOME/jre/lib/amd64/server \
# --test_env=HADOOP_HDFS_HOME=$HADOOP_HDFS_HOME \
# --test_env=CLASSPATH=$($HADOOP_HDFS_HOME/bin/hadoop classpath --glob) \
# :hadoop_file_system_test
# To test against the real distributed cluster, add the following option for
# bazel test:
# --test_env=HADOOP_TEST_TMPDIR=hdfs://cluster/test/tmp/dir
tf_cc_test(
name = "hadoop_filesystem_test",
srcs = [
"hadoop_filesystem_test.cc",
],
tags = [
"manual",
"notap",
],
deps = [
":hadoop_filesystem_impl",
"//tensorflow/core/platform:path",
"//tensorflow/core/platform:stacktrace_handler",
"//tensorflow/core/platform:test",
],
)

View File

@ -37,11 +37,17 @@ static void plugin_memory_free(void* ptr) { free(ptr); }
void ParseHadoopPath(const std::string& fname, std::string* scheme,
std::string* namenode, std::string* path) {
size_t scheme_end = fname.find("://") + 2;
*scheme = fname.substr(0, scheme_end + 1);
// We don't want `://` in scheme.
*scheme = fname.substr(0, scheme_end - 2);
size_t nn_end = fname.find("/", scheme_end + 1);
if (nn_end == std::string::npos) return;
if (nn_end == std::string::npos) {
*namenode = fname.substr(scheme_end + 1);
*path = "";
return;
}
*namenode = fname.substr(scheme_end + 1, nn_end - scheme_end - 1);
*path = fname.substr(nn_end + 1);
// We keep `/` in path.
*path = fname.substr(nn_end);
}
void SplitArchiveNameAndPath(std::string* path, std::string* nn,
@ -54,7 +60,7 @@ void SplitArchiveNameAndPath(std::string* path, std::string* nn,
}
// Case of hadoop archive. Namenode is the path to the archive.
std::ostringstream namenodestream;
namenodestream << "har://" << nn
namenodestream << "har://" << *nn
<< path->substr(0, index_end_archive_name + 4);
*nn = namenodestream.str();
path->erase(0, index_end_archive_name + 4);
@ -247,8 +253,8 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* dst = buffer;
bool eof_retried = false;
int64_t r = 0;
while (TF_GetCode(status) == TF_OK && !eof_retried) {
int64_t read = 0;
while (TF_GetCode(status) == TF_OK && n > 0) {
// We lock inside the loop rather than outside so we don't block other
// concurrent readers.
absl::MutexLock l(&hdfs_file->mu);
@ -257,12 +263,13 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
// of int32. -2 offset can avoid JVM OutOfMemoryError.
size_t read_n =
(std::min)(n, static_cast<size_t>(std::numeric_limits<int>::max() - 2));
r = libhdfs->hdfsPread(fs, handle, static_cast<tOffset>(offset), dst,
static_cast<tSize>(read_n));
int64_t r = libhdfs->hdfsPread(fs, handle, static_cast<tOffset>(offset),
dst, static_cast<tSize>(read_n));
if (r > 0) {
dst += r;
n -= r;
offset += r;
read += r;
} else if (!eof_retried && r == 0) {
// Always reopen the file upon reaching EOF to see if there's more data.
// If writers are streaming contents while others are concurrently
@ -274,11 +281,13 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
TF_SetStatusFromIOError(status, errno, path);
return -1;
}
handle = libhdfs->hdfsOpenFile(fs, hdfs_path, O_RDONLY, 0, 0, 0);
if (handle == nullptr) {
hdfs_file->handle =
libhdfs->hdfsOpenFile(fs, hdfs_path, O_RDONLY, 0, 0, 0);
if (hdfs_file->handle == nullptr) {
TF_SetStatusFromIOError(status, errno, path);
return -1;
}
handle = hdfs_file->handle;
eof_retried = true;
} else if (eof_retried && r == 0) {
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
@ -288,7 +297,7 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
TF_SetStatusFromIOError(status, errno, path);
}
}
return r;
return read;
}
} // namespace tf_random_access_file
@ -308,7 +317,7 @@ typedef struct HDFSFile {
handle(handle) {}
} HDFSFile;
static void Cleanup(TF_WritableFile* file) {
void Cleanup(TF_WritableFile* file) {
auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
hdfs_file->libhdfs->hdfsCloseFile(hdfs_file->fs, hdfs_file->handle);
hdfs_file->fs = nullptr;
@ -433,6 +442,23 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
auto handle = libhdfs->hdfsOpenFile(fs, hdfs_path.c_str(), O_WRONLY, 0, 0, 0);
if (handle == nullptr) return TF_SetStatusFromIOError(status, errno, path);
file->plugin_file =
new tf_writable_file::HDFSFile(hdfs_path, fs, libhdfs, handle);
TF_SetStatus(status, TF_OK, "");
}
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
auto handle = libhdfs->hdfsOpenFile(fs, hdfs_path.c_str(),
O_WRONLY | O_APPEND, 0, 0, 0);
if (handle == nullptr) return TF_SetStatusFromIOError(status, errno, path);
@ -442,6 +468,202 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_SetStatus(status, TF_OK, "");
}
void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
const char* path,
TF_ReadOnlyMemoryRegion* region,
TF_Status* status) {
// hadoopReadZero() technically supports this call with the following
// caveats:
// - It only works up to 2 GB. We'd have to Stat() the file to ensure that
// it fits.
// - If not on the local filesystem, the entire file will be read, making
// it inefficient for callers that assume typical mmap() behavior.
TF_SetStatus(status, TF_UNIMPLEMENTED,
"HDFS does not support ReadOnlyMemoryRegion");
}
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
if (libhdfs->hdfsExists(fs, hdfs_path.c_str()) == 0)
TF_SetStatus(status, TF_OK, "");
else
TF_SetStatus(status, TF_NOT_FOUND,
(std::string(path) + " not found").c_str());
}
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
auto info = libhdfs->hdfsGetPathInfo(fs, hdfs_path.c_str());
if (info == nullptr) return TF_SetStatusFromIOError(status, errno, path);
stats->length = static_cast<int64_t>(info->mSize);
stats->mtime_nsec = static_cast<int64_t>(info->mLastMod) * 1e9;
stats->is_directory = info->mKind == kObjectKindDirectory;
libhdfs->hdfsFreeFileInfo(info, 1);
TF_SetStatus(status, TF_OK, "");
}
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return -1;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
auto info = libhdfs->hdfsGetPathInfo(fs, hdfs_path.c_str());
if (info == nullptr) {
TF_SetStatusFromIOError(status, errno, path);
return -1;
}
TF_SetStatus(status, TF_OK, "");
auto size = static_cast<int64_t>(info->mSize);
libhdfs->hdfsFreeFileInfo(info, 1);
return size;
}
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
if (libhdfs->hdfsDelete(fs, hdfs_path.c_str(), /*recursive=*/0) != 0)
TF_SetStatusFromIOError(status, errno, path);
else
TF_SetStatus(status, TF_OK, "");
}
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
if (libhdfs->hdfsCreateDirectory(fs, hdfs_path.c_str()) != 0)
TF_SetStatusFromIOError(status, errno, path);
else
TF_SetStatus(status, TF_OK, "");
}
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
// Count the number of entries in the directory, and only delete if it's
// non-empty. This is consistent with the interface, but note that there's
// a race condition where a file may be added after this check, in which
// case the directory will still be deleted.
int entries = 0;
auto info = libhdfs->hdfsListDirectory(fs, hdfs_path.c_str(), &entries);
if (info != nullptr) libhdfs->hdfsFreeFileInfo(info, entries);
// Due to HDFS bug HDFS-8407, we can't distinguish between an error and empty
// folder, especially for Kerberos enable setup, EAGAIN is quite common when
// the call is actually successful. Check again by Stat.
if (info == nullptr && errno != 0) {
TF_FileStatistics stat;
Stat(filesystem, path, &stat, status);
if (TF_GetCode(status) != TF_OK) return;
}
if (entries > 0)
return TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Cannot delete a non-empty directory.");
if (libhdfs->hdfsDelete(fs, hdfs_path.c_str(), /*recursive=*/1) != 0)
TF_SetStatusFromIOError(status, errno, path);
else
TF_SetStatus(status, TF_OK, "");
}
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, src, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path_src, hdfs_path_dst;
ParseHadoopPath(src, &scheme, &namenode, &hdfs_path_src);
ParseHadoopPath(dst, &scheme, &namenode, &hdfs_path_dst);
if (libhdfs->hdfsExists(fs, hdfs_path_dst.c_str()) == 0 &&
libhdfs->hdfsDelete(fs, hdfs_path_dst.c_str(), /*recursive=*/0) != 0)
return TF_SetStatusFromIOError(status, errno, dst);
if (libhdfs->hdfsRename(fs, hdfs_path_src.c_str(), hdfs_path_dst.c_str()) !=
0)
TF_SetStatusFromIOError(status, errno, src);
else
TF_SetStatus(status, TF_OK, "");
}
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return -1;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
// hdfsListDirectory returns nullptr if the directory is empty. Do a separate
// check to verify the directory exists first.
TF_FileStatistics stat;
Stat(filesystem, path, &stat, status);
if (TF_GetCode(status) != TF_OK) return -1;
int num_entries = 0;
auto info = libhdfs->hdfsListDirectory(fs, hdfs_path.c_str(), &num_entries);
if (info == nullptr) {
if (stat.is_directory) {
// Assume it's an empty directory.
TF_SetStatus(status, TF_OK, "");
return 0;
}
TF_SetStatusFromIOError(status, errno, path);
return -1;
}
*entries = static_cast<char**>(
plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
auto BaseName = [](const std::string& name) {
return name.substr(name.find_last_of('/') + 1);
};
for (int i = 0; i < num_entries; i++) {
(*entries)[i] = strdup(BaseName(info[i].mName).c_str());
}
libhdfs->hdfsFreeFileInfo(info, num_entries);
TF_SetStatus(status, TF_OK, "");
return num_entries;
}
// TODO(vnvo2409): Implement later
} // namespace tf_hadoop_filesystem

View File

@ -15,7 +15,62 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
#include <string>
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
void ParseHadoopPath(const std::string& fname, std::string* scheme,
std::string* namenode, std::string* path);
void SplitArchiveNameAndPath(std::string* path, std::string* nn,
TF_Status* status);
class LibHDFS;
namespace tf_random_access_file {
void Cleanup(TF_RandomAccessFile* file);
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status);
} // namespace tf_random_access_file
namespace tf_writable_file {
void Cleanup(TF_WritableFile* file);
void Append(const TF_WritableFile* file, const char* buffer, size_t n,
TF_Status* status);
int64_t Tell(const TF_WritableFile* file, TF_Status* status);
void Sync(const TF_WritableFile* file, TF_Status* status);
void Flush(const TF_WritableFile* file, TF_Status* status);
void Close(const TF_WritableFile* file, TF_Status* status);
} // namespace tf_writable_file
namespace tf_hadoop_filesystem {
void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem);
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
TF_RandomAccessFile* file, TF_Status* status);
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status);
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status);
void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
const char* path,
TF_ReadOnlyMemoryRegion* region,
TF_Status* status);
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status);
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status);
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status);
} // namespace tf_hadoop_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_

View File

@ -0,0 +1,418 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/stacktrace_handler.h"
#include "tensorflow/core/platform/test.h"
#include "third_party/hadoop/hdfs.h"
#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)
#define EXPECT_TF_OK(x) EXPECT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)
namespace tensorflow {
namespace {
class HadoopFileSystemTest : public ::testing::Test {
public:
void SetUp() override {
status_ = TF_NewStatus();
filesystem_ = new TF_Filesystem;
tf_hadoop_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
}
void TearDown() override {
TF_DeleteStatus(status_);
tf_hadoop_filesystem::Cleanup(filesystem_);
delete filesystem_;
}
std::string TmpDir(const std::string& path) {
char* test_dir = getenv("HADOOP_TEST_TMPDIR");
if (test_dir != nullptr) {
return io::JoinPath(std::string(test_dir), path);
} else {
return "file://" + io::JoinPath(testing::TmpDir(), path);
}
}
std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile* file)>
GetWriter() {
std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile * file)> writer(
new TF_WritableFile, [](TF_WritableFile* file) {
if (file != nullptr) {
if (file->plugin_file != nullptr) tf_writable_file::Cleanup(file);
delete file;
}
});
writer->plugin_file = nullptr;
return writer;
}
std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile* file)>
GetReader() {
std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile * file)>
reader(new TF_RandomAccessFile, [](TF_RandomAccessFile* file) {
if (file != nullptr) {
if (file->plugin_file != nullptr)
tf_random_access_file::Cleanup(file);
delete file;
}
});
reader->plugin_file = nullptr;
return reader;
}
void WriteString(const std::string& path, const std::string& content) {
auto writer = GetWriter();
tf_hadoop_filesystem::NewWritableFile(filesystem_, path.c_str(),
writer.get(), status_);
if (TF_GetCode(status_) != TF_OK) return;
tf_writable_file::Append(writer.get(), content.c_str(), content.length(),
status_);
if (TF_GetCode(status_) != TF_OK) return;
tf_writable_file::Close(writer.get(), status_);
if (TF_GetCode(status_) != TF_OK) return;
}
std::string ReadAll(const std::string& path) {
auto reader = GetReader();
tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
reader.get(), status_);
if (TF_GetCode(status_) != TF_OK) return "";
auto file_size =
tf_hadoop_filesystem::GetFileSize(filesystem_, path.c_str(), status_);
if (TF_GetCode(status_) != TF_OK) return "";
std::string content;
content.resize(file_size);
auto read = tf_random_access_file::Read(reader.get(), 0, file_size,
&content[0], status_);
if (TF_GetCode(status_) != TF_OK) return "";
if (read >= 0) content.resize(read);
if (file_size != content.size())
TF_SetStatus(
status_, TF_DATA_LOSS,
std::string("expected " + std::to_string(file_size) + " got " +
std::to_string(content.size()) + " bytes")
.c_str());
return content;
}
protected:
TF_Filesystem* filesystem_;
TF_Status* status_;
};
TEST_F(HadoopFileSystemTest, RandomAccessFile) {
const std::string path = TmpDir("RandomAccessFile");
const std::string content = "abcdefghijklmn";
WriteString(path, content);
ASSERT_TF_OK(status_);
auto reader = GetReader();
tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
reader.get(), status_);
EXPECT_TF_OK(status_);
std::string result;
result.resize(content.size());
auto read = tf_random_access_file::Read(reader.get(), 0, content.size(),
&result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(content.size(), result.size());
EXPECT_EQ(content, result);
result.clear();
result.resize(4);
read = tf_random_access_file::Read(reader.get(), 2, 4, &result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(4, result.size());
EXPECT_EQ(content.substr(2, 4), result);
}
TEST_F(HadoopFileSystemTest, WritableFile) {
auto writer = GetWriter();
const std::string path = TmpDir("WritableFile");
tf_hadoop_filesystem::NewWritableFile(filesystem_, path.c_str(), writer.get(),
status_);
EXPECT_TF_OK(status_);
tf_writable_file::Append(writer.get(), "content1,", strlen("content1,"),
status_);
EXPECT_TF_OK(status_);
auto pos = tf_writable_file::Tell(writer.get(), status_);
EXPECT_TF_OK(status_);
EXPECT_EQ(pos, 9);
tf_writable_file::Append(writer.get(), "content2", strlen("content2"),
status_);
EXPECT_TF_OK(status_);
tf_writable_file::Flush(writer.get(), status_);
EXPECT_TF_OK(status_);
tf_writable_file::Sync(writer.get(), status_);
EXPECT_TF_OK(status_);
tf_writable_file::Close(writer.get(), status_);
EXPECT_TF_OK(status_);
auto content = ReadAll(path);
EXPECT_TF_OK(status_);
EXPECT_EQ("content1,content2", content);
}
TEST_F(HadoopFileSystemTest, PathExists) {
const std::string path = TmpDir("PathExists");
tf_hadoop_filesystem::PathExists(filesystem_, path.c_str(), status_);
EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status_)) << TF_Message(status_);
TF_SetStatus(status_, TF_OK, "");
WriteString(path, "test");
ASSERT_TF_OK(status_);
tf_hadoop_filesystem::PathExists(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
}
TEST_F(HadoopFileSystemTest, GetChildren) {
const std::string base = TmpDir("GetChildren");
tf_hadoop_filesystem::CreateDir(filesystem_, base.c_str(), status_);
EXPECT_TF_OK(status_);
const std::string file = io::JoinPath(base, "TestFile.csv");
WriteString(file, "test");
EXPECT_TF_OK(status_);
const std::string subdir = io::JoinPath(base, "SubDir");
tf_hadoop_filesystem::CreateDir(filesystem_, subdir.c_str(), status_);
EXPECT_TF_OK(status_);
const std::string subfile = io::JoinPath(subdir, "TestSubFile.csv");
WriteString(subfile, "test");
EXPECT_TF_OK(status_);
char** entries;
auto num_entries = tf_hadoop_filesystem::GetChildren(
filesystem_, base.c_str(), &entries, status_);
EXPECT_TF_OK(status_);
std::vector<std::string> childrens;
for (int i = 0; i < num_entries; ++i) {
childrens.push_back(entries[i]);
}
std::sort(childrens.begin(), childrens.end());
EXPECT_EQ(std::vector<string>({"SubDir", "TestFile.csv"}), childrens);
}
TEST_F(HadoopFileSystemTest, DeleteFile) {
const std::string path = TmpDir("DeleteFile");
WriteString(path, "test");
ASSERT_TF_OK(status_);
tf_hadoop_filesystem::DeleteFile(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
}
TEST_F(HadoopFileSystemTest, GetFileSize) {
const std::string path = TmpDir("GetFileSize");
WriteString(path, "test");
ASSERT_TF_OK(status_);
auto file_size =
tf_hadoop_filesystem::GetFileSize(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
EXPECT_EQ(4, file_size);
}
TEST_F(HadoopFileSystemTest, CreateDirStat) {
const std::string path = TmpDir("CreateDirStat");
tf_hadoop_filesystem::CreateDir(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
TF_FileStatistics stat;
tf_hadoop_filesystem::Stat(filesystem_, path.c_str(), &stat, status_);
EXPECT_TF_OK(status_);
EXPECT_TRUE(stat.is_directory);
}
TEST_F(HadoopFileSystemTest, DeleteDir) {
const std::string path = TmpDir("DeleteDir");
tf_hadoop_filesystem::DeleteDir(filesystem_, path.c_str(), status_);
EXPECT_NE(TF_GetCode(status_), TF_OK);
tf_hadoop_filesystem::CreateDir(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
tf_hadoop_filesystem::DeleteDir(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
TF_FileStatistics stat;
tf_hadoop_filesystem::Stat(filesystem_, path.c_str(), &stat, status_);
EXPECT_NE(TF_GetCode(status_), TF_OK);
}
TEST_F(HadoopFileSystemTest, RenameFile) {
const std::string src = TmpDir("RenameFileSrc");
const std::string dst = TmpDir("RenameFileDst");
WriteString(src, "test");
ASSERT_TF_OK(status_);
tf_hadoop_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(),
status_);
EXPECT_TF_OK(status_);
auto result = ReadAll(dst);
EXPECT_TF_OK(status_);
EXPECT_EQ("test", result);
}
TEST_F(HadoopFileSystemTest, RenameFileOverwrite) {
const std::string src = TmpDir("RenameFileOverwriteSrc");
const std::string dst = TmpDir("RenameFileOverwriteDst");
WriteString(src, "test_old");
ASSERT_TF_OK(status_);
WriteString(dst, "test_new");
ASSERT_TF_OK(status_);
tf_hadoop_filesystem::PathExists(filesystem_, dst.c_str(), status_);
EXPECT_TF_OK(status_);
tf_hadoop_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(),
status_);
EXPECT_TF_OK(status_);
auto result = ReadAll(dst);
EXPECT_TF_OK(status_);
EXPECT_EQ("test_old", result);
}
TEST_F(HadoopFileSystemTest, StatFile) {
const std::string path = TmpDir("StatFile");
WriteString(path, "test");
ASSERT_TF_OK(status_);
TF_FileStatistics stat;
tf_hadoop_filesystem::Stat(filesystem_, path.c_str(), &stat, status_);
EXPECT_TF_OK(status_);
EXPECT_EQ(4, stat.length);
EXPECT_FALSE(stat.is_directory);
}
TEST_F(HadoopFileSystemTest, WriteWhileReading) {
const std::string path = TmpDir("WriteWhileReading");
// Skip the test if we're not testing on HDFS. Hadoop's local filesystem
// implementation makes no guarantees that writable files are readable while
// being written.
if (path.find_first_of("hdfs://") != 0) GTEST_SKIP();
auto writer = GetWriter();
tf_hadoop_filesystem::NewWritableFile(filesystem_, path.c_str(), writer.get(),
status_);
EXPECT_TF_OK(status_);
const std::string content1 = "content1";
tf_writable_file::Append(writer.get(), content1.c_str(), content1.size(),
status_);
EXPECT_TF_OK(status_);
tf_writable_file::Flush(writer.get(), status_);
EXPECT_TF_OK(status_);
auto reader = GetReader();
tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
reader.get(), status_);
EXPECT_TF_OK(status_);
std::string result;
result.resize(content1.size());
auto read = tf_random_access_file::Read(reader.get(), 0, content1.size(),
&result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(content1, result);
const std::string content2 = "content2";
tf_writable_file::Append(writer.get(), content2.c_str(), content2.size(),
status_);
EXPECT_TF_OK(status_);
tf_writable_file::Flush(writer.get(), status_);
EXPECT_TF_OK(status_);
result.resize(content2.size());
read = tf_random_access_file::Read(reader.get(), content1.size(),
content2.size(), &result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(content2, result);
tf_writable_file::Close(writer.get(), status_);
EXPECT_TF_OK(status_);
}
TEST_F(HadoopFileSystemTest, HarSplit) {
const std::string har_path =
"har://hdfs-root/user/j.doe/my_archive.har/dir0/dir1/file.txt";
std::string scheme, namenode, path;
ParseHadoopPath(har_path, &scheme, &namenode, &path);
EXPECT_EQ("har", scheme);
EXPECT_EQ("hdfs-root", namenode);
EXPECT_EQ("/user/j.doe/my_archive.har/dir0/dir1/file.txt", path);
SplitArchiveNameAndPath(&path, &namenode, status_);
EXPECT_TF_OK(status_);
EXPECT_EQ("har://hdfs-root/user/j.doe/my_archive.har", namenode);
EXPECT_EQ("/dir0/dir1/file.txt", path);
}
TEST_F(HadoopFileSystemTest, NoHarExtension) {
const std::string har_path =
"har://hdfs-root/user/j.doe/my_archive/dir0/dir1/file.txt";
std::string scheme, namenode, path;
ParseHadoopPath(har_path, &scheme, &namenode, &path);
EXPECT_EQ("har", scheme);
EXPECT_EQ("hdfs-root", namenode);
EXPECT_EQ("/user/j.doe/my_archive/dir0/dir1/file.txt", path);
SplitArchiveNameAndPath(&path, &namenode, status_);
EXPECT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT) << TF_Message(status_);
}
TEST_F(HadoopFileSystemTest, HarRootPath) {
const std::string har_path = "har://hdfs-root/user/j.doe/my_archive.har";
std::string scheme, namenode, path;
ParseHadoopPath(har_path, &scheme, &namenode, &path);
EXPECT_EQ("har", scheme);
EXPECT_EQ("hdfs-root", namenode);
EXPECT_EQ("/user/j.doe/my_archive.har", path);
SplitArchiveNameAndPath(&path, &namenode, status_);
EXPECT_TF_OK(status_);
EXPECT_EQ("har://hdfs-root/user/j.doe/my_archive.har", namenode);
EXPECT_EQ("/", path);
}
TEST_F(HadoopFileSystemTest, WriteLargeFile) {
if (std::getenv("HADOOP_TEST_LARGE_FILE") != "1") GTEST_SKIP();
const std::string path = TmpDir("WriteLargeFile");
const size_t file_size =
static_cast<size_t>(std::numeric_limits<tSize>::max()) + 1024;
// Fake a test string.
std::string source(file_size, {});
for (size_t i = 0; i < file_size; ++i) source[i] = (i % 128);
WriteString(path, source);
ASSERT_TF_OK(status_);
auto result = ReadAll(path);
EXPECT_TF_OK(status_);
EXPECT_EQ(source, result);
}
// NewAppendableFile() is not testable. Local filesystem maps to
// ChecksumFileSystem in Hadoop, where appending is an unsupported operation.
} // namespace
} // namespace tensorflow
GTEST_API_ int main(int argc, char** argv) {
tensorflow::testing::InstallStacktraceHandler();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -26,6 +26,8 @@ cc_library(
}),
deps = [
":aws_crypto",
":aws_logging",
"//tensorflow/c:logging",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@aws",
@ -45,6 +47,18 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "aws_logging",
srcs = ["aws_logging.cc"],
hdrs = ["aws_logging.h"],
deps = [
"//tensorflow/c:logging",
"@aws",
"@com_google_absl//absl/synchronization",
],
alwayslink = 1,
)
tf_cc_test(
name = "s3_filesystem_test",
srcs = [

View File

@ -0,0 +1,159 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h"
#include <aws/core/Aws.h>
#include <aws/core/utils/logging/AWSLogging.h>
#include <aws/core/utils/logging/LogSystemInterface.h>
#include <cstdarg>
#include <cstdio>
#include <sstream>
#include "absl/synchronization/mutex.h"
#include "tensorflow/c/logging.h"
static constexpr char kAWSLoggingTag[] = "AWSLogging";
static const std::map<const std::string, const Aws::Utils::Logging::LogLevel>
log_levels_string_to_aws = {
{"off", Aws::Utils::Logging::LogLevel::Off},
{"fatal", Aws::Utils::Logging::LogLevel::Fatal},
{"error", Aws::Utils::Logging::LogLevel::Error},
{"warn", Aws::Utils::Logging::LogLevel::Warn},
{"info", Aws::Utils::Logging::LogLevel::Info},
{"debug", Aws::Utils::Logging::LogLevel::Debug},
{"trace", Aws::Utils::Logging::LogLevel::Trace}};
static const std::map<const int, const Aws::Utils::Logging::LogLevel>
log_levels_tf_to_aws = {{0, Aws::Utils::Logging::LogLevel::Info},
{1, Aws::Utils::Logging::LogLevel::Warn},
{2, Aws::Utils::Logging::LogLevel::Error},
{3, Aws::Utils::Logging::LogLevel::Fatal}};
namespace tf_s3_filesystem {
AWSLogSystem::AWSLogSystem(Aws::Utils::Logging::LogLevel log_level)
: log_level_(log_level) {}
void AWSLogSystem::LogMessage(Aws::Utils::Logging::LogLevel log_level,
const std::string& message) {
if (message == "Initializing Curl library") return;
switch (log_level) {
case Aws::Utils::Logging::LogLevel::Info:
TF_Log(TF_INFO, message.c_str());
break;
case Aws::Utils::Logging::LogLevel::Warn:
TF_Log(TF_WARNING, message.c_str());
break;
case Aws::Utils::Logging::LogLevel::Error:
TF_Log(TF_ERROR, message.c_str());
break;
case Aws::Utils::Logging::LogLevel::Fatal:
TF_Log(TF_FATAL, message.c_str());
break;
default:
// this will match for DEBUG, TRACE
TF_Log(TF_INFO, message.c_str());
break;
}
}
void AWSLogSystem::Log(Aws::Utils::Logging::LogLevel log_level, const char* tag,
const char* format, ...) {
char buffer[256];
va_list args;
va_start(args, format);
vsnprintf(buffer, 256, format, args);
va_end(args);
LogMessage(log_level, buffer);
}
void AWSLogSystem::LogStream(Aws::Utils::Logging::LogLevel log_level,
const char* tag,
const Aws::OStringStream& message_stream) {
LogMessage(log_level, message_stream.rdbuf()->str().c_str());
}
void AWSLogSystem::Flush() { return; }
static Aws::Utils::Logging::LogLevel TfLogLevelToAwsLogLevel(int level) {
// Converts TF Log Levels INFO, WARNING, ERROR and FATAL to the AWS enum
// values for the levels
if (log_levels_tf_to_aws.find(level) != log_levels_tf_to_aws.end()) {
return log_levels_tf_to_aws.at(level);
} else {
// default to fatal
return Aws::Utils::Logging::LogLevel::Fatal;
}
}
static Aws::Utils::Logging::LogLevel ParseAwsLogLevelFromEnv() {
// defaults to FATAL log level for the AWS SDK
// this is because many normal tensorflow operations are logged as errors in
// the AWS SDK such as checking if a file exists can log an error in AWS SDK
// if the file does not actually exist. Another such case is when reading a
// file till the end, TensorFlow expects to see an InvalidRange exception at
// the end, but this would be an error in the AWS SDK. This confuses users,
// hence the default setting.
Aws::Utils::Logging::LogLevel log_level =
Aws::Utils::Logging::LogLevel::Fatal;
const char* aws_env_var_val = getenv("AWS_LOG_LEVEL");
if (aws_env_var_val != nullptr) {
std::string maybe_integer_str(aws_env_var_val, strlen(aws_env_var_val));
std::istringstream ss(maybe_integer_str);
int level;
ss >> level;
if (ss.fail()) {
// wasn't a number
// expecting a string
std::string level_str = maybe_integer_str;
if (log_levels_string_to_aws.find(level_str) !=
log_levels_string_to_aws.end()) {
log_level = log_levels_string_to_aws.at(level_str);
}
} else {
// backwards compatibility
// valid number, but this number follows the standard TensorFlow log
// levels need to convert this to AWS SDK logging level number
log_level = TfLogLevelToAwsLogLevel(level);
}
}
return log_level;
}
static bool initialized = false;
ABSL_CONST_INIT static absl::Mutex s3_logging_mutex(absl::kConstInit);
void AWSLogSystem::InitializeAWSLogging() {
absl::MutexLock l(&s3_logging_mutex);
if (!initialized) {
Aws::Utils::Logging::InitializeAWSLogging(Aws::MakeShared<AWSLogSystem>(
kAWSLoggingTag, ParseAwsLogLevelFromEnv()));
initialized = true;
return;
}
}
void AWSLogSystem::ShutdownAWSLogging() {
absl::MutexLock l(&s3_logging_mutex);
if (initialized) {
Aws::Utils::Logging::ShutdownAWSLogging();
initialized = false;
return;
}
}
} // namespace tf_s3_filesystem

View File

@ -0,0 +1,64 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_LOGGING_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_LOGGING_H_
#include <aws/core/utils/logging/LogLevel.h>
#include <aws/core/utils/logging/LogSystemInterface.h>
#include <atomic>
#include <string>
namespace tf_s3_filesystem {
class AWSLogSystem : public Aws::Utils::Logging::LogSystemInterface {
public:
static void InitializeAWSLogging();
static void ShutdownAWSLogging();
explicit AWSLogSystem(Aws::Utils::Logging::LogLevel log_level);
virtual ~AWSLogSystem() = default;
// Gets the currently configured log level.
Aws::Utils::Logging::LogLevel GetLogLevel(void) const override {
return log_level_;
}
// Set a new log level. This has the immediate effect of changing the log.
void SetLogLevel(Aws::Utils::Logging::LogLevel log_level) {
log_level_.store(log_level);
}
// Does a printf style output to ProcessFormattedStatement. Don't use this,
// it's unsafe. See LogStream.
void Log(Aws::Utils::Logging::LogLevel log_level, const char* tag,
const char* format, ...) override;
// Writes the stream to ProcessFormattedStatement.
void LogStream(Aws::Utils::Logging::LogLevel log_level, const char* tag,
const Aws::OStringStream& messageStream) override;
// Flushes the buffered messages if the logger supports buffering
void Flush() override;
private:
void LogMessage(Aws::Utils::Logging::LogLevel log_level,
const std::string& message);
std::atomic<Aws::Utils::Logging::LogLevel> log_level_;
};
} // namespace tf_s3_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_LOGGING_H_

View File

@ -38,6 +38,8 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h"
#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h"
#include "tensorflow/c/logging.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for S3 environments.
@ -186,6 +188,8 @@ static void GetS3Client(tf_s3_filesystem::S3File* s3_file) {
absl::MutexLock l(&s3_file->initialization_lock);
if (s3_file->s3_client.get() == nullptr) {
tf_s3_filesystem::AWSLogSystem::InitializeAWSLogging();
Aws::SDKOptions options;
options.cryptoOptions.sha256Factory_create_fn = []() {
return Aws::MakeShared<tf_s3_filesystem::AWSSHA256Factory>(
@ -250,6 +254,7 @@ static void ShutdownClient(Aws::S3::S3Client* s3_client) {
delete s3_client;
Aws::SDKOptions options;
Aws::ShutdownAPI(options);
tf_s3_filesystem::AWSLogSystem::ShutdownAWSLogging();
}
}
@ -281,6 +286,7 @@ void Cleanup(TF_RandomAccessFile* file) {
static int64_t ReadS3Client(S3File* s3_file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
TF_VLog(3, "ReadFile using S3Client\n");
Aws::S3::Model::GetObjectRequest get_object_request;
get_object_request.WithBucket(s3_file->bucket).WithKey(s3_file->object);
Aws::String bytes =
@ -306,12 +312,14 @@ static int64_t ReadS3Client(S3File* s3_file, uint64_t offset, size_t n,
static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
TF_VLog(3, "Using TransferManager\n");
auto create_download_stream = [&]() {
return Aws::New<TFS3UnderlyingStream>(
"S3ReadStream",
Aws::New<Aws::Utils::Stream::PreallocatedStreamBuf>(
"S3ReadStream", reinterpret_cast<unsigned char*>(buffer), n));
};
TF_VLog(3, "Created stream to read with transferManager\n");
auto handle = s3_file->transfer_manager->DownloadFile(
s3_file->bucket, s3_file->object, offset, n, create_download_stream);
handle->WaitUntilFinished();
@ -322,6 +330,10 @@ static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n,
Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE &&
retries++ < kDownloadRetries) {
// Only failed parts will be downloaded again.
TF_VLog(
1,
"Retrying read of s3://%s/%s after failure. Current retry count: %u\n",
s3_file->bucket.c_str(), s3_file->object.c_str(), retries);
s3_file->transfer_manager->RetryDownload(handle);
handle->WaitUntilFinished();
}
@ -341,6 +353,8 @@ static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n,
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
auto s3_file = static_cast<S3File*>(file->plugin_file);
TF_VLog(1, "ReadFilefromS3 s3://%s/%s from %u for n: %u\n",
s3_file->bucket.c_str(), s3_file->object.c_str(), offset, n);
if (s3_file->use_multi_part_download)
return ReadS3TransferManager(s3_file, offset, n, buffer, status);
else
@ -416,6 +430,8 @@ void Sync(const TF_WritableFile* file, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
return;
}
TF_VLog(1, "WriteFileToS3: s3://%s/%s\n", s3_file->bucket.c_str(),
s3_file->object.c_str());
auto position = static_cast<int64_t>(s3_file->outfile->tellp());
auto handle = s3_file->transfer_manager->UploadFile(
s3_file->outfile, s3_file->bucket, s3_file->object,
@ -426,6 +442,10 @@ void Sync(const TF_WritableFile* file, TF_Status* status) {
while (handle->GetStatus() == Aws::Transfer::TransferStatus::FAILED &&
retries++ < kUploadRetries) {
// if multipart upload was used, only the failed parts will be re-sent
TF_VLog(1,
"Retrying upload of s3://%s/%s after failure. Current retry count: "
"%u\n",
s3_file->bucket.c_str(), s3_file->object.c_str(), retries);
s3_file->transfer_manager->RetryUpload(s3_file->outfile, handle);
handle->WaitUntilFinished();
}
@ -613,6 +633,7 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status) {
TF_VLog(1, "Stat on path: %s\n", path);
Aws::String bucket, object;
ParseS3Path(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
@ -737,6 +758,8 @@ static void SimpleCopyFile(const Aws::String& source,
const Aws::String& bucket_dst,
const Aws::String& object_dst, S3File* s3_file,
TF_Status* status) {
TF_VLog(1, "SimpleCopyFile from %s to %s/%s\n", bucket_dst.c_str(),
object_dst.c_str());
Aws::S3::Model::CopyObjectRequest copy_object_request;
copy_object_request.WithCopySource(source)
.WithBucket(bucket_dst)
@ -801,6 +824,8 @@ static void MultiPartCopy(const Aws::String& source,
const Aws::String& object_dst, const size_t num_parts,
const uint64_t file_size, S3File* s3_file,
TF_Status* status) {
TF_VLog(1, "MultiPartCopy from %s to %s/%s\n", bucket_dst.c_str(),
object_dst.c_str());
Aws::S3::Model::CreateMultipartUploadRequest create_multipart_upload_request;
create_multipart_upload_request.WithBucket(bucket_dst).WithKey(object_dst);
@ -827,6 +852,8 @@ static void MultiPartCopy(const Aws::String& source,
auto chunk_size =
s3_file->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD];
TF_VLog(1, "Copying from %s in %u parts of size %u each\n", source.c_str(),
num_parts, chunk_size);
size_t retries = 0;
while (retries++ < 3) {
// Queue up parts.
@ -891,6 +918,9 @@ static void MultiPartCopy(const Aws::String& source,
status);
} else {
// Retry.
TF_Log(TF_ERROR,
"Retrying failed copy of part %u due to an error with S3\n",
part_number);
num_finished_parts--;
}
}
@ -967,6 +997,7 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
TF_VLog(1, "DeleteFile: %s\n", path);
Aws::String bucket, object;
ParseS3Path(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
@ -985,6 +1016,7 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path,
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
TF_VLog(1, "CreateDir: %s\n", path);
Aws::String bucket, object;
ParseS3Path(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
@ -1026,6 +1058,7 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
TF_VLog(1, "DeleteDir: %s\n", path);
Aws::String bucket, object;
ParseS3Path(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
@ -1060,6 +1093,7 @@ void DeleteDir(const TF_Filesystem* filesystem, const char* path,
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
TF_VLog(1, "RenameFile from: %s to %s\n", src, dst);
Aws::String bucket_src, object_src;
ParseS3Path(src, false, &bucket_src, &object_src, status);
if (TF_GetCode(status) != TF_OK) return;
@ -1120,6 +1154,7 @@ void RenameFile(const TF_Filesystem* filesystem, const char* src,
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status) {
TF_VLog(1, "GetChildren for path: %s\n", path);
Aws::String bucket, prefix;
ParseS3Path(path, true, &bucket, &prefix, status);
if (TF_GetCode(status) != TF_OK) return -1;

View File

@ -3,6 +3,24 @@ package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "array_grad",
srcs = ["array_grad.cc"],
hdrs = [
"array_grad.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients_internal",
"//tensorflow/core/lib/llvm_rtti",
],
)
cc_library(
name = "math_grad",
srcs = ["math_grad.cc"],
@ -13,11 +31,63 @@ cc_library(
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients",
"//tensorflow/c/eager:gradients_internal",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
],
)
cc_library(
name = "nn_grad",
srcs = ["nn_grad.cc"],
hdrs = [
"nn_grad.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:gradients_internal",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "gradients",
hdrs = [
"array_grad.h",
"math_grad.h",
"nn_grad.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":array_grad",
":math_grad",
":nn_grad",
],
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"array_grad.h",
"math_grad.h",
"nn_grad.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/array_grad.h"
namespace tensorflow {
namespace gradients {
namespace {
using std::vector;
class IdentityNGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(grad_inputs.size(), nullptr);
for (int i = 0; i < grad_inputs.size(); i++) {
auto grad_input = grad_inputs[i];
// TODO(srbs): Should we add a copy contructor to AbstractTensorHandle
// that takes care of this similar to `Tensor`?
if (grad_input) {
grad_input->Ref();
}
(*grad_outputs)[i] = grad_input;
}
return Status::OK();
}
~IdentityNGradientFunction() override {}
};
} // namespace
BackwardFunction* IdentityNRegisterer(const ForwardOperation& op) {
auto gradient_function = new IdentityNGradientFunction;
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_
#include "tensorflow/c/eager/gradients.h"
namespace tensorflow {
namespace gradients {
BackwardFunction* IdentityNRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_

View File

@ -14,9 +14,16 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
using tensorflow::ops::Identity;
using std::vector;
using tensorflow::ops::Conj;
using tensorflow::ops::MatMul;
using tensorflow::ops::Mul;
namespace tensorflow {
namespace gradients {
@ -24,30 +31,184 @@ namespace {
class AddGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
std::vector<AbstractTensorHandle*>* grad_outputs) override {
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2);
std::vector<AbstractTensorHandle*> identity_outputs(1);
// TODO(b/145674566): Handle name unification in tracing code.
// TODO(b/161805092): Support broadcasting.
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(identity_outputs),
"Identity0"));
(*grad_outputs)[0] = identity_outputs[0];
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(identity_outputs),
"Identity1"));
(*grad_outputs)[1] = identity_outputs[0];
DCHECK(grad_inputs[0]);
(*grad_outputs)[0] = grad_inputs[0];
(*grad_outputs)[1] = grad_inputs[0];
(*grad_outputs)[0]->Ref();
(*grad_outputs)[1]->Ref();
return Status::OK();
}
~AddGradientFunction() override {}
};
class ExpGradientFunction : public GradientFunction {
public:
explicit ExpGradientFunction(AbstractTensorHandle* exp) : exp_(exp) {
exp->Ref();
}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
vector<AbstractTensorHandle*> conj_outputs(1);
std::string name = "Conj_Exp_Grad";
TF_RETURN_IF_ERROR(Conj(ctx->ctx, {exp_.get()},
absl::MakeSpan(conj_outputs), name.c_str()));
AbstractTensorHandlePtr conj_output_releaser(conj_outputs[0]);
grad_outputs->resize(1);
name = "Mul_Exp_Grad";
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {conj_outputs[0], grad_inputs[0]},
absl::MakeSpan(*grad_outputs), name.c_str()));
return Status::OK();
}
~ExpGradientFunction() override {}
private:
AbstractTensorHandlePtr exp_;
};
class MatMulGradientFunction : public GradientFunction {
public:
explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
AttrBuilder f_attrs)
: forward_inputs(f_inputs), forward_attrs(f_attrs) {}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
/* Given upstream grad U and a matmul op A*B, the gradients are:
*
* dA = U * B.T
* dB = A.T * U
*
* where A.T means `transpose(A)`
*/
AbstractTensorHandle* upstream_grad = grad_inputs[0];
grad_outputs->resize(2);
// Get transpose attrs
bool t_a;
TF_RETURN_IF_ERROR(forward_attrs.Get("transpose_a", &t_a));
bool t_b;
TF_RETURN_IF_ERROR(forward_attrs.Get("transpose_b", &t_b));
// Conj each input
vector<AbstractTensorHandle*> conj_outputs(1);
std::string name = "Conj_A_MatMul_Grad";
TF_RETURN_IF_ERROR(Conj(ctx->ctx, {forward_inputs[0]},
absl::MakeSpan(conj_outputs), name.c_str()));
AbstractTensorHandle* A = conj_outputs[0];
name = "Conj_B_MatMul_Grad";
TF_RETURN_IF_ERROR(Conj(ctx->ctx, {forward_inputs[1]},
absl::MakeSpan(conj_outputs), name.c_str()));
AbstractTensorHandle* B = conj_outputs[0];
// Calc Grad
vector<AbstractTensorHandle*> matmul_A_outputs(1);
vector<AbstractTensorHandle*> matmul_B_outputs(1);
std::string name_grad_A = "MatMul_Grad_A";
std::string name_grad_B = "MatMul_Grad_B";
if (!t_a && !t_b) {
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, B},
absl::MakeSpan(matmul_A_outputs),
name_grad_A.c_str(),
/*transpose_a = */ false,
/*transpose_b = */ true));
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {A, upstream_grad},
absl::MakeSpan(matmul_B_outputs),
name_grad_B.c_str(),
/*transpose_a = */ true,
/*transpose_b = */ false));
} else if (!t_a && t_b) {
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, B},
absl::MakeSpan(matmul_A_outputs),
name_grad_A.c_str(),
/*transpose_a = */ false,
/*transpose_b = */ false));
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, A},
absl::MakeSpan(matmul_B_outputs),
name_grad_B.c_str(),
/*transpose_a = */ true,
/*transpose_b = */ false));
} else if (t_a && !t_b) {
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {B, upstream_grad},
absl::MakeSpan(matmul_A_outputs),
name_grad_A.c_str(),
/*transpose_a = */ false,
/*transpose_b = */ true));
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {A, upstream_grad},
absl::MakeSpan(matmul_B_outputs),
name_grad_B.c_str(),
/*transpose_a = */ false,
/*transpose_b = */ false));
} else { // t_a && t_b
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {B, upstream_grad},
absl::MakeSpan(matmul_A_outputs),
name_grad_A.c_str(),
/*transpose_a = */ true,
/*transpose_b = */ true));
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, A},
absl::MakeSpan(matmul_B_outputs),
name_grad_B.c_str(),
/*transpose_a = */ true,
/*transpose_b = */ true));
}
// Gradient for A
(*grad_outputs)[0] = matmul_A_outputs[0];
// Gradient for B
(*grad_outputs)[1] = matmul_B_outputs[0];
return Status::OK();
}
~MatMulGradientFunction() override {}
private:
vector<AbstractTensorHandle*> forward_inputs;
AttrBuilder forward_attrs;
};
} // namespace
GradientFunction* AddRegisterer(const ForwardOperation& op) {
return new AddGradientFunction;
BackwardFunction* AddRegisterer(const ForwardOperation& op) {
auto gradient_function = new AddGradientFunction;
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
BackwardFunction* ExpRegisterer(const ForwardOperation& op) {
auto gradient_function = new ExpGradientFunction(op.outputs[0]);
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
BackwardFunction* MatMulRegisterer(const ForwardOperation& op) {
auto gradient_function = new MatMulGradientFunction(op.inputs, op.attrs);
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
} // namespace gradients
} // namespace tensorflow

View File

@ -19,8 +19,10 @@ limitations under the License.
namespace tensorflow {
namespace gradients {
GradientFunction* AddRegisterer(const ForwardOperation& op);
BackwardFunction* AddRegisterer(const ForwardOperation& op);
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_

View File

@ -0,0 +1,133 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
using std::vector;
using tensorflow::ops::Mul;
using tensorflow::ops::ReluGrad;
namespace tensorflow {
namespace gradients {
namespace {
class ReluGradientFunction : public GradientFunction {
public:
explicit ReluGradientFunction(vector<AbstractTensorHandle*> f_outputs)
: forward_outputs(f_outputs) {}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
AbstractTensorHandle* upstream_grad = grad_inputs[0];
AbstractTensorHandle* activations = forward_outputs[0];
grad_outputs->resize(1);
vector<AbstractTensorHandle*> relugrad_outputs(1);
// Calculate Grad
std::string name = "relu_grad";
TF_RETURN_IF_ERROR(ReluGrad(ctx->ctx, {upstream_grad, activations},
absl::MakeSpan(relugrad_outputs),
name.c_str()));
(*grad_outputs)[0] = relugrad_outputs[0];
return Status::OK();
}
~ReluGradientFunction() override {}
private:
vector<AbstractTensorHandle*> forward_outputs;
};
Status BroadcastMul(AbstractContext* ctx, AbstractTensorHandle* vec,
AbstractTensorHandle* mat,
absl::Span<AbstractTensorHandle*> outputs) {
if (!isa<ImmediateExecutionContext>(ctx)) {
// TODO(b/168850692): Fix this.
return errors::Unimplemented(
"BroadcastMul is not supported in tracing mode yet.");
}
auto imm_ctx = dyn_cast<ImmediateExecutionContext>(ctx);
AbstractTensorPtr minus_1(imm_ctx->CreateInt32Scalar(-1));
ImmediateTensorHandlePtr dim(imm_ctx->CreateLocalHandle(minus_1.get()));
vector<AbstractTensorHandle*> expand_dims_outputs(1);
TF_RETURN_IF_ERROR(ops::ExpandDims(ctx, {vec, dim.get()},
absl::MakeSpan(expand_dims_outputs),
"ExpandDims"));
TF_RETURN_IF_ERROR(
ops::Mul(ctx, {expand_dims_outputs[0], mat}, outputs, "Mul"));
expand_dims_outputs[0]->Unref();
return Status::OK();
}
class SparseSoftmaxCrossEntropyWithLogitsGradientFunction
: public GradientFunction {
public:
explicit SparseSoftmaxCrossEntropyWithLogitsGradientFunction(
vector<AbstractTensorHandle*> f_outputs)
: forward_outputs(f_outputs) {}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2);
// Grad for Softmax Input
vector<AbstractTensorHandle*> mul_outputs(1);
TF_RETURN_IF_ERROR(BroadcastMul(
ctx->ctx, grad_inputs[0], forward_outputs[1],
absl::MakeSpan(mul_outputs))); // upstream_grad * local softmax grad
(*grad_outputs)[0] = mul_outputs[0];
// Grad for labels is null
(*grad_outputs)[1] = nullptr;
return Status::OK();
}
~SparseSoftmaxCrossEntropyWithLogitsGradientFunction() override {}
private:
vector<AbstractTensorHandle*> forward_outputs;
};
} // namespace
BackwardFunction* ReluRegisterer(const ForwardOperation& op) {
auto gradient_function = new ReluGradientFunction(op.outputs);
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
const ForwardOperation& op) {
auto gradient_function =
new SparseSoftmaxCrossEntropyWithLogitsGradientFunction(op.outputs);
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_
#include "tensorflow/c/eager/gradients.h"
namespace tensorflow {
namespace gradients {
BackwardFunction* ReluRegisterer(const ForwardOperation& op);
BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_

View File

@ -15,6 +15,7 @@ cc_library(
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
@ -22,3 +23,80 @@ cc_library(
"//tensorflow/core/platform:errors",
],
)
cc_library(
name = "math_ops",
srcs = [
"math_ops.cc",
],
hdrs = [
"math_ops.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":array_ops",
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/core:framework",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
],
)
cc_library(
name = "nn_ops",
srcs = [
"nn_ops.cc",
],
hdrs = [
"nn_ops.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
],
)
cc_library(
name = "ops",
hdrs = [
"array_ops.h",
"math_ops.h",
"nn_ops.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":array_ops",
":math_ops",
":nn_ops",
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/core/lib/llvm_rtti",
],
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"array_ops.h",
"math_ops.h",
"nn_ops.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)

View File

@ -14,11 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace ops {
// Creates an Identity op.
Status Identity(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
@ -34,5 +35,51 @@ Status Identity(AbstractContext* ctx,
return identity_op->Execute(outputs, &num_retvals);
}
Status ZerosLike(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr z_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(z_op->Reset("ZerosLike", /*raw_device_name=*/nullptr));
if (isa<tensorflow::tracing::TracingOperation>(z_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(z_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(z_op->AddInput(inputs[0]));
int num_retvals = 1;
return z_op->Execute(outputs, &num_retvals);
}
Status Shape(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr shape_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(shape_op->Reset("Shape", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(shape_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(shape_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(shape_op->AddInput(inputs[0])); // input
int num_retvals = 1;
TF_RETURN_IF_ERROR(shape_op->Execute(outputs, &num_retvals));
return Status::OK();
}
Status ExpandDims(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(op->Reset("ExpandDims", /*raw_device_name=*/nullptr));
if (isa<tensorflow::tracing::TracingOperation>(op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(op->AddInput(inputs[1]));
int num_retvals = 1;
return op->Execute(outputs, &num_retvals);
}
} // namespace ops
} // namespace tensorflow

View File

@ -15,16 +15,30 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
#define TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
namespace tensorflow {
namespace ops {
Status Identity(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status ZerosLike(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Shape(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status ExpandDims(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow

View File

@ -0,0 +1,166 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace ops {
using tensorflow::tracing::TracingOperation;
Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr mul_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(mul_op->Reset("Mul", /*raw_device_name=*/nullptr));
if (isa<TracingOperation>(mul_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[1]));
int num_retvals = 1;
return mul_op->Execute(outputs, &num_retvals);
}
Status Conj(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
auto dtype = inputs[0]->DataType();
if (DataTypeIsFloating(BaseType(dtype)) ||
DataTypeIsInteger(BaseType(dtype))) {
TF_RETURN_IF_ERROR(Identity(ctx, inputs, outputs, name));
} else {
return errors::Unimplemented("Conj does not support complex types yet.");
}
return Status::OK();
}
Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr add_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(add_op->Reset("AddV2", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(add_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(add_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(add_op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(add_op->AddInput(inputs[1]));
int num_retvals = 1;
TF_RETURN_IF_ERROR(add_op->Execute(outputs, &num_retvals));
return Status::OK();
}
Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr sub_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(sub_op->Reset("Sub", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(sub_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(sub_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[1]));
int num_retvals = 1;
TF_RETURN_IF_ERROR(sub_op->Execute(outputs, &num_retvals));
return Status::OK();
}
Status MatMul(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
bool transpose_a = false, bool transpose_b = false) {
AbstractOperationPtr matmul_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(matmul_op->Reset("MatMul", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(matmul_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(matmul_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[1]));
TF_RETURN_IF_ERROR(matmul_op->SetAttrBool("transpose_a", transpose_a));
TF_RETURN_IF_ERROR(matmul_op->SetAttrBool("transpose_b", transpose_b));
int num_retvals = 1;
TF_RETURN_IF_ERROR(matmul_op->Execute(outputs, &num_retvals));
return Status::OK();
}
Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr neg_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(neg_op->Reset("Neg", /*raw_device_name=*/nullptr));
if (isa<TracingOperation>(neg_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(neg_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(neg_op->AddInput(inputs[0]));
int num_retvals = 1;
return neg_op->Execute(outputs, &num_retvals);
}
Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr sum_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(sum_op->Reset("Sum", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(sum_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(sum_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[0])); // input_vals
TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[1])); // reduction_indices
int num_retvals = 1;
TF_RETURN_IF_ERROR(sum_op->Execute(outputs, &num_retvals));
return Status::OK();
}
Status DivNoNan(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr div_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(div_op->Reset("DivNoNan", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(div_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(div_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(div_op->AddInput(inputs[0])); // x
TF_RETURN_IF_ERROR(div_op->AddInput(inputs[1])); // y
int num_retvals = 1;
TF_RETURN_IF_ERROR(div_op->Execute(
outputs, &num_retvals)); // z = x / y, (z_i = 0 if y_i = 0)
return Status::OK();
}
} // namespace ops
} // namespace tensorflow

View File

@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_
#define TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
namespace tensorflow {
namespace ops {
Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Conj(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status MatMul(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
bool transpose_a, bool transpose_b);
Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status DivNoNan(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_

View File

@ -0,0 +1,85 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace ops {
// Softmax Loss given scores and labels, used by the SoftMaxLossGradient
Status SparseSoftmaxCrossEntropyWithLogits(
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr sm_loss_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(sm_loss_op->Reset("SparseSoftmaxCrossEntropyWithLogits",
/*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(sm_loss_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(sm_loss_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[0])); // input scores
TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[1])); // labels
// Outputs will contain: [loss_vals, gradients].
int num_retvals = 2;
TF_RETURN_IF_ERROR(sm_loss_op->Execute(outputs, &num_retvals));
return Status::OK();
}
// Computes Relu gradient given input features
Status ReluGrad(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr relugrad_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
relugrad_op->Reset("ReluGrad", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(relugrad_op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(relugrad_op.get())
->SetOpName(name));
}
TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[0])); // upstream grads
TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[1])); // relu inputs
int num_retvals = 1;
TF_RETURN_IF_ERROR(relugrad_op->Execute(outputs, &num_retvals));
return Status::OK();
}
Status Relu(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr relu_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(relu_op->Reset("Relu", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(relu_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(relu_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(relu_op->AddInput(inputs[0]));
int num_retvals = 1;
TF_RETURN_IF_ERROR(relu_op->Execute(outputs, &num_retvals));
return Status::OK();
}
} // namespace ops
} // namespace tensorflow

View File

@ -0,0 +1,41 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_
#define TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
namespace tensorflow {
namespace ops {
Status SparseSoftmaxCrossEntropyWithLogits(
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status ReluGrad(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Relu(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_

View File

@ -44,7 +44,9 @@ cc_library(
],
deps = [
":concrete_function",
":signature_def_function",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
],
)
@ -70,6 +72,26 @@ cc_library(
],
)
cc_library(
name = "signature_def_function",
hdrs = [
"signature_def_function.h",
],
deps = [
":signature_def_function_metadata",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "signature_def_function_metadata",
hdrs = [
"signature_def_function_metadata.h",
],
)
cc_library(
name = "test_utils",
testonly = True,
@ -115,6 +137,7 @@ cc_library(
":concrete_function",
":saved_model_api",
":saved_model_utils",
":signature_def_function",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
@ -206,13 +229,30 @@ tf_cc_test(
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/core:all_kernels",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:core_cpu_lib",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core",
"//tensorflow/core/common_runtime/eager:tensor_handle",
],
)
tf_cc_test(
name = "signature_flattening_test",
srcs = [
"signature_flattening_test.cc",
],
deps = [
":saved_model_utils",
"//tensorflow/c/experimental/saved_model/core:tf_concrete_function_test_protos",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime/eager:core",
],
)

View File

@ -26,10 +26,14 @@ limitations under the License.
namespace tensorflow {
// Note that ConcreteFunctions's lifetimes are effectively bound
// to the SavedModel they are loaded from, since they retain pointers
// to the TensorHandles owned by the SavedModel, and the FunctionDef
// of the SavedModel.
// ConcreteFunctions correspond to an instance of a tf.function with a known set
// of inputs (either through get_concrete_function) or an input_signature.
// ConcreteFunction attempts to preserve the user-facing semantics of the
// tf.function python API and can take a limited set of types as arguments
// (to be modeled in tensorflow::Value), not just Tensors.
// SavedModelAPI's ConcreteFunctions' lifetimes are bound to the SavedModel they
// are loaded from, since they retain pointers to the TensorHandles owned by the
// SavedModel, and the FunctionDef of the SavedModel.
// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock
// TFRT integration with TF Serving. Do not add more virtual implementations of
// this class. Eventually we want to remove this virtual base class indirection
@ -39,8 +43,8 @@ class ConcreteFunction {
virtual ~ConcreteFunction() = default;
// This method returns the "Call" Op used to execute the function.
virtual Status GetCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) = 0;
virtual Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) const = 0;
virtual const FunctionMetadata& GetFunctionMetadata() const = 0;
};

View File

@ -37,10 +37,11 @@ static const char kNoSharingResourceID[] =
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
const char* raw_device_name,
ImmediateTensorHandlePtr* handle) {
ImmediateOpPtr varhandle_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", nullptr));
TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", raw_device_name));
TF_RETURN_IF_ERROR(varhandle_op->SetAttrType("dtype", dtype));
// Note that if shape is unknown rank, shape.dim_sizes() will be empty, and

View File

@ -31,6 +31,7 @@ namespace internal {
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
const char* raw_device_name,
ImmediateTensorHandlePtr* handle);
// Executes an AssignVariableOp using `ctx`, assigning the variable associated

View File

@ -55,7 +55,7 @@ TEST_F(VariableOpsTest, CreateVariableSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
ImmediateTensorHandlePtr handle;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &handle));
context(), DT_FLOAT, {}, nullptr, &handle));
// The created TensorHandle should be a DT_Resource
EXPECT_EQ(handle->DataType(), DT_RESOURCE);
}
@ -65,7 +65,7 @@ TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
ImmediateTensorHandlePtr handle;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &handle));
context(), DT_FLOAT, {}, nullptr, &handle));
// Destroy the variable
TF_EXPECT_OK(internal::DestroyResource(context(), handle.get()));
@ -76,7 +76,7 @@ TEST_F(VariableOpsTest, AssignVariableAndReadSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
ImmediateTensorHandlePtr variable;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &variable));
context(), DT_FLOAT, {}, nullptr, &variable));
// Create a Scalar float TensorHandle with value 42, and assign it to
// the variable.

View File

@ -28,6 +28,26 @@ cc_library(
],
)
cc_library(
name = "flat_tensor_function",
srcs = [
"flat_tensor_function.cc",
],
hdrs = [
"flat_tensor_function.h",
],
deps = [
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "variable",
srcs = [
@ -68,7 +88,7 @@ cc_library(
"tf_concrete_function.h",
],
deps = [
":tensorhandle_convertible",
":flat_tensor_function",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
@ -81,3 +101,26 @@ cc_library(
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "tf_signature_def_function",
srcs = [
"tf_signature_def_function.cc",
],
hdrs = [
"tf_signature_def_function.h",
],
deps = [
":flat_tensor_function",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core:signature_def_function",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:span",
],
)

View File

@ -0,0 +1,85 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
#include <memory>
#include <string>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
FlatTensorFunction::FlatTensorFunction(
const std::string& name,
std::vector<ImmediateExecutionTensorHandle*> captures,
ImmediateExecutionContext* ctx)
: name_(name), captures_(std::move(captures)), ctx_(ctx) {}
FlatTensorFunction::~FlatTensorFunction() {
Status status = ctx_->RemoveFunction(name_);
if (!status.ok()) {
LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
<< status.error_message();
}
}
Status FlatTensorFunction::Create(
const FunctionDef* function_def,
std::vector<ImmediateExecutionTensorHandle*> captures,
ImmediateExecutionContext* ctx, std::unique_ptr<FlatTensorFunction>* out) {
TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def));
out->reset(new FlatTensorFunction(function_def->signature().name(),
std::move(captures), ctx));
return Status();
}
Status FlatTensorFunction::MakeCallOp(
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
out->reset(ctx_->CreateOperation());
// In eager mode, TF2 python executes functions by constructing an op with
// the name of the functiondef:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
// In graph mode, we create a PartitionedCallOp instead:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573
// TODO(bmzhao): After discussing with Allen, we should execute this via a
// PartitionedCallOp for compatibility with "tooling that assumes functions in
// graphs are PartitionedCallOps".
TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
// Adding the user-provided inputs to the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(inputs));
absl::Span<AbstractTensorHandle* const> captures(
reinterpret_cast<AbstractTensorHandle* const*>(captures_.data()),
captures_.size());
// Adding the captures of the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(captures));
return Status();
}
} // namespace tensorflow

View File

@ -0,0 +1,84 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
// FlatTensorFunction models a TF2 eager runtime view of a callable function,
// taking + returning flat lists of tensors, including any captures.
// Effectively, it is a thin wrapper around a FunctionDef owned by the
// EagerContext, and any TensorHandle captures associated with the function. The
// MakeCallOp method handles the logic of marshaling captures after the user
// provided inputs automatically.
// Note(bmzhao): This class is mainly intended to house low-level reusable
// function logic between SignatureDefFunction and ConcreteFunction, which
// present higher level interfaces. This type does *not* hold any "function
// metadata".
class FlatTensorFunction {
public:
// Factory for creating a FlatTensorFunction.
//
// Params:
// function_def - The function_def associated with the created
// FlatTensorFunction. FlatTensorFunction will register this
// function_def with `ctx` on creation, and de-register it on
// destruction. function_def must be non-null, but
// otherwise has no lifetime requirements.
// captures - The captured TensorHandles associated with this
// FlatTensorFunction.
// ctx - A handle to the Tensorflow runtime. This MUST be non-null and
// outlive TFConcreteFunction.
// out - The output FlatTensorFunction.
static Status Create(const FunctionDef* function_def,
std::vector<ImmediateExecutionTensorHandle*> captures,
ImmediateExecutionContext* ctx,
std::unique_ptr<FlatTensorFunction>* out);
// This method creates a "Call" Op used to execute the function.
Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) const;
~FlatTensorFunction();
private:
FlatTensorFunction(const std::string& name,
std::vector<ImmediateExecutionTensorHandle*> captures,
ImmediateExecutionContext* ctx);
FlatTensorFunction(const FlatTensorFunction&) = delete;
FlatTensorFunction& operator=(const FlatTensorFunction&) = delete;
// Name of the FunctionDef corresponding to this TFConcreteFunction
std::string name_;
std::vector<ImmediateExecutionTensorHandle*> captures_;
ImmediateExecutionContext* ctx_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/errors.h"
@ -33,32 +33,20 @@ limitations under the License.
namespace tensorflow {
TFConcreteFunction::TFConcreteFunction(
const std::string& name,
std::vector<ImmediateExecutionTensorHandle*> captures,
FunctionMetadata metadata, ImmediateExecutionContext* ctx)
: name_(name),
captures_(std::move(captures)),
metadata_(std::move(metadata)),
ctx_(ctx) {}
TFConcreteFunction::~TFConcreteFunction() {
Status status = ctx_->RemoveFunction(name_);
if (!status.ok()) {
LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
<< status.error_message();
}
}
TFConcreteFunction::TFConcreteFunction(std::unique_ptr<FlatTensorFunction> func,
FunctionMetadata metadata)
: func_(std::move(func)), metadata_(std::move(metadata)) {}
Status TFConcreteFunction::Create(
const FunctionDef* function_def,
std::vector<ImmediateExecutionTensorHandle*> captures,
FunctionMetadata metadata, ImmediateExecutionContext* ctx,
std::unique_ptr<TFConcreteFunction>* out) {
TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def));
out->reset(new TFConcreteFunction(function_def->signature().name(),
std::move(captures), std::move(metadata),
ctx));
std::unique_ptr<FlatTensorFunction> func;
TF_RETURN_IF_ERROR(FlatTensorFunction::Create(
function_def, std::move(captures), ctx, &func));
out->reset(new TFConcreteFunction(std::move(func), std::move(metadata)));
return Status();
}
@ -66,30 +54,9 @@ const FunctionMetadata& TFConcreteFunction::GetFunctionMetadata() const {
return metadata_;
}
Status TFConcreteFunction::GetCallOp(
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) {
out->reset(ctx_->CreateOperation());
// In eager mode, TF2 python executes functions by constructing an op with
// the name of the functiondef:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
// In graph mode, we create a PartitionedCallOp instead:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573
// TODO(bmzhao): After discussing with Allen, we should execute this via a
// PartitionedCallOp for compatibility with "tooling that assumes functions in
// graphs are PartitionedCallOps".
TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
// Adding the user-provided inputs to the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(inputs));
absl::Span<AbstractTensorHandle* const> captures(
reinterpret_cast<AbstractTensorHandle**>(captures_.data()),
captures_.size());
// Adding the captures of the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(captures));
return Status();
Status TFConcreteFunction::MakeCallOp(
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
return func_->MakeCallOp(inputs, out);
}
} // namespace tensorflow

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
@ -58,26 +58,22 @@ class TFConcreteFunction : public ConcreteFunction {
std::unique_ptr<TFConcreteFunction>* out);
// This method returns the "Call" Op used to execute the function.
Status GetCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) override;
Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) const override;
const FunctionMetadata& GetFunctionMetadata() const override;
~TFConcreteFunction() override;
~TFConcreteFunction() override = default;
private:
TFConcreteFunction(const std::string& name,
std::vector<ImmediateExecutionTensorHandle*> captures,
FunctionMetadata metadata, ImmediateExecutionContext* ctx);
TFConcreteFunction(std::unique_ptr<FlatTensorFunction> func,
FunctionMetadata metadata);
TFConcreteFunction(const TFConcreteFunction&) = delete;
TFConcreteFunction& operator=(const TFConcreteFunction&) = delete;
// Name of the FunctionDef corresponding to this TFConcreteFunction
std::string name_;
std::vector<ImmediateExecutionTensorHandle*> captures_;
std::unique_ptr<FlatTensorFunction> func_;
FunctionMetadata metadata_;
ImmediateExecutionContext* ctx_;
};
} // namespace tensorflow

Some files were not shown because too many files have changed in this diff Show More