Merge branch 'master' into tanh_logistic_fixes
This commit is contained in:
commit
595c3e3d13
14
.bazelrc
14
.bazelrc
@ -49,7 +49,6 @@
|
||||
# rocm: Build with AMD GPU support (rocm).
|
||||
# mkl: Enable full mkl support.
|
||||
# tensorrt: Enable Tensorrt support.
|
||||
# ngraph: Enable ngraph support.
|
||||
# numa: Enable numa using hwloc.
|
||||
# noaws: Disable AWS S3 storage support
|
||||
# nogcp: Disable GCS support.
|
||||
@ -159,6 +158,7 @@ build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain
|
||||
# environment variable "TF_MKL_ROOT" every time before build.
|
||||
build:mkl --define=build_with_mkl=true --define=enable_mkl=true
|
||||
build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
|
||||
build:mkl --define=build_with_openmp=true
|
||||
build:mkl -c opt
|
||||
|
||||
# config to build OneDNN backend with a user specified threadpool.
|
||||
@ -172,8 +172,15 @@ build:mkl_threadpool -c opt
|
||||
build:mkl_opensource_only --define=build_with_mkl=true --define=enable_mkl=true
|
||||
build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0
|
||||
build:mkl_opensource_only --define=build_with_mkl_opensource=true
|
||||
build:mkl_opensource_only --define=build_with_openmp=true
|
||||
build:mkl_opensource_only -c opt
|
||||
|
||||
# Config setting to build with oneDNN for Arm.
|
||||
build:mkl_aarch64 --define=build_with_mkl_aarch64=true --define=enable_mkl=true
|
||||
build:mkl_aarch64 --define=tensorflow_mkldnn_contraction_kernel=0
|
||||
build:mkl_aarch64 --define=build_with_mkl_opensource=true
|
||||
build:mkl_aarch64 -c opt
|
||||
|
||||
# This config refers to building with CUDA available. It does not necessarily
|
||||
# mean that we build CUDA op kernels.
|
||||
build:using_cuda --define=using_cuda=true
|
||||
@ -212,7 +219,6 @@ build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
|
||||
build:rocm --action_env TF_NEED_ROCM=1
|
||||
|
||||
# Options extracted from configure script
|
||||
build:ngraph --define=with_ngraph_support=true
|
||||
build:numa --define=with_numa_support=true
|
||||
|
||||
# Options to disable default on features
|
||||
@ -277,7 +283,7 @@ build:ios --copt=-w
|
||||
build:linux --copt=-w
|
||||
build:linux --host_copt=-w
|
||||
build:macos --copt=-w
|
||||
build:windows --copt=/w
|
||||
build:windows --copt=/W0
|
||||
|
||||
# Tensorflow uses M_* math constants that only get defined by MSVC headers if
|
||||
# _USE_MATH_DEFINES is defined.
|
||||
@ -288,9 +294,11 @@ build:windows --host_copt=/D_USE_MATH_DEFINES
|
||||
build:linux --define=PREFIX=/usr
|
||||
build:linux --define=LIBDIR=$(PREFIX)/lib
|
||||
build:linux --define=INCLUDEDIR=$(PREFIX)/include
|
||||
build:linux --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include
|
||||
build:macos --define=PREFIX=/usr
|
||||
build:macos --define=LIBDIR=$(PREFIX)/lib
|
||||
build:macos --define=INCLUDEDIR=$(PREFIX)/include
|
||||
build:macos --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include
|
||||
# TF_SYSTEM_LIBS do not work on windows.
|
||||
|
||||
# By default, build TF in C++ 14 mode.
|
||||
|
6
.github/bot_config.yml
vendored
6
.github/bot_config.yml
vendored
@ -12,12 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
#
|
||||
# THIS IS A GENERATED DOCKERFILE.
|
||||
#
|
||||
# This file was assembled from multiple pieces, whose use is documented
|
||||
# throughout. Please refer to the TensorFlow dockerfiles documentation
|
||||
# for more information.
|
||||
|
||||
# A list of assignees
|
||||
assignees:
|
||||
|
@ -1,4 +1,3 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -12,13 +11,18 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
set -e
|
||||
set -x
|
||||
# ============================================================================
|
||||
|
||||
source tensorflow/tools/ci_build/release/common.sh
|
||||
|
||||
# Copy and rename to tensorflow
|
||||
for f in $(ls py_test_dir/tensorflow-*cp3*-cp3*m-win_amd64.whl); do
|
||||
copy_to_new_project_name "${f}" tensorflow_gpu
|
||||
done
|
||||
on:
|
||||
workflow_dispatch: # Allow manual triggers
|
||||
schedule:
|
||||
- cron: 0 4 * * * # 4am UTC is 9pm PDT and 8pm PST
|
||||
name: Set nightly branch to master HEAD
|
||||
jobs:
|
||||
master-to-nightly:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: zofrex/mirror-branch@v1
|
||||
name: Set nightly branch to master HEAD
|
||||
with:
|
||||
target-branch: 'nightly'
|
@ -4,7 +4,7 @@
|
||||
/tensorflow/core/common_runtime/eager @qqfish @kkimdev
|
||||
/tenosrflow/core/debug @caisq
|
||||
/tensorflow/core/nccl/ @azaks2 @chsigg
|
||||
/tensorflow/core/platform/windows/ @gunan @mihaimaruseac
|
||||
/tensorflow/core/platform/windows/ @mihaimaruseac
|
||||
/tensorflow/lite/experimental/micro @petewarden @advaitjain
|
||||
/tensorflow/python/autograph/ @mdanatg @kkimdev
|
||||
/tensorflow/python/debug @caisq
|
||||
|
36
README.md
36
README.md
@ -103,23 +103,22 @@ open-source software development:
|
||||
|
||||
### Official Builds
|
||||
|
||||
Build Type | Status | Artifacts
|
||||
------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
|
||||
**Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Linux XLA** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
|
||||
**macOS** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Android** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
|
||||
**Raspberry Pi 0 and 1** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
|
||||
**Raspberry Pi 2 and 3** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
|
||||
**Libtensorflow MacOS CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
**Libtensorflow Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
**Libtensorflow Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
**Libtensorflow Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
**Libtensorflow Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
|
||||
Build Type | Status | Artifacts
|
||||
----------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
|
||||
**Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Linux XLA** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
|
||||
**macOS** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Android** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
|
||||
**Raspberry Pi 0 and 1** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
|
||||
**Raspberry Pi 2 and 3** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
|
||||
**Libtensorflow MacOS CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
|
||||
**Libtensorflow Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
|
||||
**Libtensorflow Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
|
||||
**Libtensorflow Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
|
||||
**Libtensorflow Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
|
||||
|
||||
### Community Supported Builds
|
||||
|
||||
@ -145,12 +144,13 @@ Build Type
|
||||
* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
|
||||
* [TensorFlow Official Models](https://github.com/tensorflow/models/tree/master/official)
|
||||
* [TensorFlow Examples](https://github.com/tensorflow/examples)
|
||||
* [TensorFlow in Practice from Coursera](https://www.coursera.org/specializations/tensorflow-in-practice)
|
||||
* [DeepLearning.AI TensorFlow Developer Professional Certificate](https://www.coursera.org/specializations/tensorflow-in-practice)
|
||||
* [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment)
|
||||
* [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2)
|
||||
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
|
||||
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
|
||||
* [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp)
|
||||
* [TensorFlow Codelabs](https://codelabs.developers.google.com/?cat=TensorFlow)
|
||||
* [TensorFlow Chat Room on StackOverflow (not actively monitored by the
|
||||
TensorFlow team)](https://chat.stackoverflow.com/rooms/216694/tensorflow)
|
||||
* [TensorFlow Blog](https://blog.tensorflow.org)
|
||||
|
700
RELEASE.md
700
RELEASE.md
@ -1,3 +1,54 @@
|
||||
# Release 2.5.0
|
||||
|
||||
<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
* <DOCUMENT BREAKING CHANGES HERE>
|
||||
* <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>
|
||||
|
||||
## Known Caveats
|
||||
|
||||
* <CAVEATS REGARDING THE RELEASE (BUT NOT BREAKING CHANGES).>
|
||||
* <ADDING/BUMPING DEPENDENCIES SHOULD GO HERE>
|
||||
* <KNWON LACK OF SUPPORT ON SOME PLATFORM, SHOULD GO HERE>
|
||||
|
||||
## Major Features and Improvements
|
||||
|
||||
* <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
|
||||
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
|
||||
|
||||
* TPU embedding support
|
||||
* Added `profile_data_directory` to `EmbeddingConfigSpec` in
|
||||
`_tpu_estimator_embedding.py`. This allows embedding lookup statistics
|
||||
gathered at runtime to be used in embedding layer partitioning decisions.
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
|
||||
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
|
||||
* <NOTES SHOULD BE GROUPED PER AREA>
|
||||
* `tf.keras`:
|
||||
* Improvements to Keras preprocessing layers:
|
||||
* Discretization combiner implemented, with additional arg `epsilon`.
|
||||
|
||||
* `tf.data`:
|
||||
* Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used
|
||||
to control how external state should be handled during dataset
|
||||
serialization or iterator checkpointing.
|
||||
|
||||
* `tf.lite`:
|
||||
* NNAPI
|
||||
* Removed deprecated `Interpreter::UseNNAPI(bool)` C++ API.
|
||||
* Use `NnApiDelegate()` and related delegate configuration methods
|
||||
directly.
|
||||
|
||||
## Thanks to our Contributors
|
||||
|
||||
This release contains contributions from many people at Google, as well as:
|
||||
|
||||
<INSERT>, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
|
||||
|
||||
# Release 2.4.0
|
||||
|
||||
<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
|
||||
@ -6,6 +57,15 @@
|
||||
|
||||
* <DOCUMENT BREAKING CHANGES HERE>
|
||||
* <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>
|
||||
* Certain float32 ops run in lower precsion on Ampere based GPUs, including
|
||||
matmuls and convolutions, due to the use of
|
||||
[TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/).
|
||||
Specifically, inputs to such ops are rounded from 23 bits of precision to 10
|
||||
bits of precision. This is unlikely to cause issues in practice for deep
|
||||
learning models. In some cases, TensorFloat-32 is also used for complex64 ops.
|
||||
TensorFloat-32 can be disabled by running
|
||||
`config.experimental.enable_tensor_float_32_execution(False)`. The "Major
|
||||
Features and Improvements" section has more details.
|
||||
* The byte layout for string tensors across the C-API has been updated to match
|
||||
TF Core/C++; i.e., a contiguous array of `tensorflow::tstring`/`TF_TString`s.
|
||||
* C-API functions `TF_StringDecode`, `TF_StringEncode`, and
|
||||
@ -34,6 +94,7 @@
|
||||
shape assumptions (note that you can pass shapes with `None` entries for axes
|
||||
that are meant to be dynamic). You can also disable the input checking
|
||||
entirely by setting `model.input_spec = None`.
|
||||
* TF pip packages now use CUDA11 and cuDNN 8.0.2.
|
||||
* XLA:CPU and XLA:GPU devices are no longer registered by default. Use
|
||||
`TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be
|
||||
removed).
|
||||
@ -46,6 +107,49 @@
|
||||
* `tf.data.experimental.service.WorkerServer` now takes a config tuple
|
||||
instead of individual arguments. Usages should be updated to
|
||||
`tf.data.experimental.service.WorkerServer(worker_config)`.
|
||||
* `tf.quantization.quantize_and_dequantize_v2` has been introduced, which
|
||||
updates the gradient definition for quantization which is outside the range
|
||||
to be 0. To simulate the V1 the behavior of
|
||||
tf.quantization.quantize_and_dequantize(...) use
|
||||
tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...).
|
||||
* `tf.distribute.Strategy.experimental_make_numpy_dataset` is removed. Please
|
||||
use `tf.data.Dataset.from_tensor_slices` instead.
|
||||
* `experimental_hints` in `tf.distribute.StrategyExtended.reduce_to`,
|
||||
`tf.distribute.StrategyExtended.batch_reduce_to`,
|
||||
`tf.distribute.ReplicaContext.all_reduce` are renamed to `options`.
|
||||
`tf.distribute.experimental.CollectiveHints` is renamed
|
||||
`tf.distribute.experimental.CommunicationOptions`.
|
||||
`tf.distribute.experimental.CollectiveCommunication` is renamed
|
||||
`tf.distribute.experimental.CommunicationImplementation`.
|
||||
* `tf.keras.mixed_precision.experimental`:
|
||||
* `AutoCastVariable.dtype` now refers to the actual variable dtype, not the
|
||||
dtype it will be casted to.
|
||||
* When mixed precision is enabled, `tf.keras.layers.Embedding` now outputs a
|
||||
float16 or bfloat16 tensor instead of a float32 tensor.
|
||||
* The property
|
||||
`tf.keras.mixed_precision.experimental.LossScaleOptimizer.loss_scale` is now
|
||||
a tensor, not a `LossScale` object. This means to get a loss scale of a
|
||||
`LossScaleOptimizer` as a tensor, you must now call `opt.loss_scale` instead
|
||||
of `opt.loss_scale()`.
|
||||
* The property `should_cast_variables` has been removed from
|
||||
`tf.keras.mixed_precision.experimental.Policy`
|
||||
* When passing a `tf.mixed_precision.experimental.DynamicLossScale` to
|
||||
`tf.keras.mixed_precision.experimental.LossScaleOptimizer`, the
|
||||
`DynamicLossScale`'s multiplier must be 2.
|
||||
* When passing a `tf.mixed_precision.experimental.DynamicLossScale` to
|
||||
`tf.keras.mixed_precision.experimental.LossScaleOptimizer`, the weights of
|
||||
the `DynanmicLossScale` are copied into the `LossScaleOptimizer` instead of
|
||||
being reused. This means modifying the weights of the `DynamicLossScale`
|
||||
will no longer affect the weights of the LossScaleOptimizer, and vice versa.
|
||||
* The global policy can no longer be set to a non-floating point policy in
|
||||
`tf.keras.mixed_precision.experimental.set_policy`
|
||||
* In `Layer.call`, `AutoCastVariable`s will no longer be casted within
|
||||
`MirroredStrategy.run` or `ReplicaContext.merge_call`. This is because a
|
||||
thread local variable is used to determine whether `AutoCastVariable`s are
|
||||
casted, and those two functions run with a different thread. Note this only
|
||||
applies if one of these two functions is called within `Layer.call`; if one
|
||||
of those two functions calls `Layer.call`, `AutoCastVariable`s will still be
|
||||
casted.
|
||||
|
||||
## Known Caveats
|
||||
|
||||
@ -57,141 +161,231 @@
|
||||
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
|
||||
* A new module named `tf.experimental.numpy` is added, which is a NumPy-compatible API for writing TF programs. This module provides class `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are provided. Their inter-operation with TF facilities is seamless in most cases. See [tensorflow/python/ops/numpy_ops/README.md](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/numpy_ops/README.md) for details of what operations are supported and what are the differences from NumPy.
|
||||
* A major refactoring of the internals of the Keras Functional API has been completed, that should improve the reliability, stability, and performance of constructing Functional models.
|
||||
* Support for
|
||||
[TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/)
|
||||
on Ampere based GPUs has been added. TensorFloat-32, or TF32 for short, is a
|
||||
math mode for NVIDIA Ampere GPUs which causes certain float32 ops, such as
|
||||
matrix multiplications and convolutions, to run much faster on Ampere GPUs but
|
||||
with reduced precision. This reduced precision has not been found to effect
|
||||
convergence quality of deep learning models in practice. TensorFloat-32 is
|
||||
enabled by default, but can be disabled with
|
||||
`tf.config.experimental.enable_tensor_float_32_execution`.
|
||||
|
||||
* `tf.distribute`:
|
||||
* `MultiWorkerMirroredStrategy` is graduated out of experimental.
|
||||
* Peer failure will no longer cause the cluster to hang.
|
||||
* Major issues with saving are fixed.
|
||||
* See [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for a tutorial.
|
||||
* Deprecated `experimental_distribute_datasets_from_function` method and renamed it to `distribute_datasets_from_function` as it is no longer experimental.
|
||||
* The `tf.keras.mixed_precision` API has been made non-experimental. The major
|
||||
changes to the new non-experimental API are:
|
||||
* `tf.keras.mixed_precision.Policy` no longer takes in a
|
||||
`tf.mixed_precision.experimental.LossScale` in the constructor, and no
|
||||
longer has a `LossScale` associated with it. Instead, `Model.compile` will
|
||||
automatically wrap the optimizer with a `LossScaleOptimizer` using dynamic
|
||||
loss scaling if `Policy.name` is "mixed_float16".
|
||||
* `tf.keras.mixed_precision.LossScaleOptimizer`'s constructor takes in
|
||||
different arguments. In particular, it no longer takes in a `LossScale`, and
|
||||
there is no longer a `LossScale` associated with the `LossScaleOptimizer`.
|
||||
Instead, `LossScaleOptimizer` directly implements fixed or dynamic loss
|
||||
scaling. See the documentation of
|
||||
`tf.keras.mixed_precision.experimental.LossScaleOptimizer` for details on
|
||||
the differences between the experimental `LossScaleOptimizer` and the new
|
||||
non-experimental `LossScaleOptimizer`.
|
||||
* `tf.mixed_precision.experimental.LossScale` and its subclasses are
|
||||
deprecated, as all of its functionality now exists within
|
||||
`tf.keras.mixed_precision.LossScaleOptimizer`
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
|
||||
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
|
||||
* <NOTES SHOULD BE GROUPED PER AREA>
|
||||
* Security:
|
||||
* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
|
||||
([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
|
||||
* Fixes three vulnerabilities in conversion to DLPack format
|
||||
([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
|
||||
[CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
|
||||
[CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
|
||||
* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
|
||||
([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
|
||||
[CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
|
||||
* Fixes several vulnerabilities in `RaggedCountSparseOutput` and
|
||||
`SparseCountSparseOutput` operations
|
||||
([CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196),
|
||||
[CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197),
|
||||
[CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198),
|
||||
[CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199),
|
||||
[CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200),
|
||||
[CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201))
|
||||
* Fixes an integer truncation vulnerability in code using the work sharder API
|
||||
([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
|
||||
* Fixes a format string vulnerability in `tf.strings.as_string`
|
||||
([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
|
||||
* Fixes segfault raised by calling session-only ops in eager mode
|
||||
([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
|
||||
* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams`
|
||||
([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
|
||||
* Fixes segfaults caused by incomplete `SavedModel` validation
|
||||
([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
|
||||
* Fixes a data corruption due to a bug in negative indexing support in TFLite
|
||||
([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
|
||||
* Fixes a data corruption due to dimension mismatch in TFLite
|
||||
([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
|
||||
* Fixes several vulnerabilities in TFLite saved model format
|
||||
([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
|
||||
[CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
|
||||
[CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
|
||||
* Fixes several vulnerabilities in TFLite implementation of segment sum
|
||||
([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212),
|
||||
[CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
|
||||
[CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
|
||||
* TF Core:
|
||||
* `tf.types.experimental.TensorLike` is a new `Union` type that can be used as
|
||||
type annotation for variables representing a Tensor or a value that can be
|
||||
converted to Tensor by `tf.convert_to_tensor`.
|
||||
* Calling ops with a python constants or numpy values is now consistent with
|
||||
tf.convert_to_tensor behavior. This avoids operations like tf.reshape
|
||||
truncating inputs such as from int64 to int32.
|
||||
* Added `tf.sparse.map_values` to apply a function to the `.value`s of `SparseTensror` arguments.
|
||||
* The Python bitwise operators for `Tensor` (`__and__`, `__or__`, `__xor__`
|
||||
and `__invert__` now support non-`bool` arguments and apply the
|
||||
corresponding bitwise ops. `bool` arguments continue to be supported and
|
||||
dispatch to logical ops. This brings them more in line with Python and NumPy
|
||||
benavior.
|
||||
* Added `tf.SparseTensor.with_values`. This returns a new SparseTensor with
|
||||
the same sparsity pattern, but with new provided values. It is similar to
|
||||
the `with_values` function of `RaggedTensor`.
|
||||
* Added `StatelessCase` op, and uses it if none of case branches has stateful ops.
|
||||
* Added `tf.config.experimental.get_memory_usage` to return total memory usage
|
||||
of the device.
|
||||
* `tf.data`:
|
||||
* Added new `tf.data.experimental.service.register_dataset` and
|
||||
`tf.data.experimental.service.from_dataset_id` APIs to enable one process
|
||||
to register a dataset with the tf.data service, and another process to
|
||||
consume data from the dataset.
|
||||
* Added support for tf.data service dispatcher fault tolerance. To enable
|
||||
fault tolerance, configure a `work_dir` when running your dispatcher
|
||||
server and set `dispatcher_fault_tolerance=True`. The dispatcher will
|
||||
store its state to `work_dir`, so that on restart it can continue from its
|
||||
previous state after restart.
|
||||
* Added tf.data service support for sharing dataset graphs via shared
|
||||
filesystem instead of over RPC. This reduces load on the dispatcher,
|
||||
improving performance of distributing datasets. For this to work, the
|
||||
dispatcher's `work_dir` must be accessible from workers. If the worker
|
||||
fails to read from the `work_dir`, it falls back to using RPC for dataset
|
||||
graph transfer.
|
||||
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
|
||||
the complement of `select_cols`; at most one of these should be specified.
|
||||
* We have implemented an optimization which reorders data-discarding
|
||||
transformations such as `take` and `shard` to happen earlier in the
|
||||
dataset when it is safe to do so. The optimization can be disabled via
|
||||
the `experimental_optimization.reorder_data_discarding_ops` dataset
|
||||
option.
|
||||
* `tf.data.Options` were previously immutable and can now be overriden.
|
||||
* `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors
|
||||
with a new `output_signature` argument, which allows `from_generator` to
|
||||
produce any type describable by a `tf.TypeSpec`.
|
||||
* `tf.data.experimental.AUTOTUNE` is now available in the core API as
|
||||
`tf.data.AUTOTUNE`.
|
||||
* `tf.image`:
|
||||
* Added deterministic `tf.image.stateless_random_*` functions for each
|
||||
`tf.image.random_*` function. Added a new op
|
||||
`stateless_sample_distorted_bounding_box` which is a determinstic
|
||||
version of `sample_distorted_bounding_box` op. Given the same seed, these
|
||||
stateless functions/ops produce the same results independent of how many
|
||||
times the function is called, and independent of global seed settings.
|
||||
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
|
||||
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
|
||||
* <NOTES SHOULD BE GROUPED PER AREA>
|
||||
* Security:
|
||||
* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
|
||||
([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
|
||||
* Fixes three vulnerabilities in conversion to DLPack format
|
||||
([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
|
||||
[CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
|
||||
[CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
|
||||
* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
|
||||
([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
|
||||
[CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
|
||||
* Fixes several vulnerabilities in `RaggedCountSparseOutput` and
|
||||
`SparseCountSparseOutput` operations
|
||||
([CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196),
|
||||
[CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197),
|
||||
[CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198),
|
||||
[CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199),
|
||||
[CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200),
|
||||
[CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201))
|
||||
* Fixes an integer truncation vulnerability in code using the work sharder
|
||||
API
|
||||
([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
|
||||
* Fixes a format string vulnerability in `tf.strings.as_string`
|
||||
([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
|
||||
* Fixes segfault raised by calling session-only ops in eager mode
|
||||
([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
|
||||
* Fixes data leak and potential ASLR violation from
|
||||
`tf.raw_ops.StringNGrams`
|
||||
([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
|
||||
* Fixes segfaults caused by incomplete `SavedModel` validation
|
||||
([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
|
||||
* Fixes a data corruption due to a bug in negative indexing support in
|
||||
TFLite
|
||||
([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
|
||||
* Fixes a data corruption due to dimension mismatch in TFLite
|
||||
([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
|
||||
* Fixes several vulnerabilities in TFLite saved model format
|
||||
([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
|
||||
[CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
|
||||
[CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
|
||||
* Fixes several vulnerabilities in TFLite implementation of segment sum
|
||||
([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212),
|
||||
[CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
|
||||
[CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
|
||||
* Fixes a segfault in `tf.quantization.quantize_and_dequantize`
|
||||
([CVE-2020-15265](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15265))
|
||||
* Fixes an undefined behavior float cast causing a crash
|
||||
([CVE-2020-15266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15266))
|
||||
* TF Core:
|
||||
* `tf.types.experimental.TensorLike` is a new `Union` type that can be
|
||||
used as type annotation for variables representing a Tensor or a value
|
||||
that can be converted to Tensor by `tf.convert_to_tensor`.
|
||||
* Calling ops with a python constants or numpy values is now consistent
|
||||
with tf.convert_to_tensor behavior. This avoids operations like
|
||||
tf.reshape truncating inputs such as from int64 to int32.
|
||||
* Added `tf.sparse.map_values` to apply a function to the `.value`s of
|
||||
`SparseTensor` arguments.
|
||||
* The Python bitwise operators for `Tensor` (`__and__`, `__or__`,
|
||||
`__xor__` and `__invert__` now support non-`bool` arguments and apply
|
||||
the corresponding bitwise ops. `bool` arguments continue to be supported
|
||||
and dispatch to logical ops. This brings them more in line with Python
|
||||
and NumPy behavior.
|
||||
* Added `tf.SparseTensor.with_values`. This returns a new SparseTensor
|
||||
with the same sparsity pattern, but with new provided values. It is
|
||||
similar to the `with_values` function of `RaggedTensor`.
|
||||
* Added `StatelessCase` op, and uses it if none of case branches has
|
||||
stateful ops.
|
||||
* Added `tf.config.experimental.get_memory_usage` to return total memory
|
||||
usage of the device.
|
||||
* Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`.
|
||||
* Improve shape inference of nested function calls by supporting constant folding across Arg nodes which makes more static values available to shape inference functions.
|
||||
* `tf.data`:
|
||||
* tf.data service:
|
||||
* Added new `tf.data.experimental.service.register_dataset` and
|
||||
`tf.data.experimental.service.from_dataset_id` APIs to enable one
|
||||
process to register a dataset with the tf.data service, and another
|
||||
process to consume data from the dataset.
|
||||
* Added support for dispatcher fault tolerance. To enable fault tolerance,
|
||||
configure a `work_dir` when running your dispatcher server and set
|
||||
`dispatcher_fault_tolerance=True`. The dispatcher will store its state
|
||||
to `work_dir`, so that on restart it can continue from its previous
|
||||
state after restart.
|
||||
* Added support for sharing dataset graphs via shared filesystem instead
|
||||
of over RPC. This reduces load on the dispatcher, improving performance
|
||||
of distributing datasets. For this to work, the dispatcher's `work_dir`
|
||||
must be accessible from workers. If the worker fails to read from the
|
||||
`work_dir`, it falls back to using RPC for dataset graph transfer.
|
||||
* Added support for a new "distributed_epoch" processing mode. This
|
||||
processing mode distributes a dataset across all tf.data workers,
|
||||
instead of having each worker process the full dataset. See
|
||||
[the tf.data service docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#understand_processing_mode)
|
||||
to learn more.
|
||||
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
|
||||
the complement of `select_cols`; at most one of these should be
|
||||
specified.
|
||||
* We have implemented an optimization which reorders data-discarding
|
||||
transformations such as `take` and `shard` to happen earlier in the
|
||||
dataset when it is safe to do so. The optimization can be disabled via
|
||||
the `experimental_optimization.reorder_data_discarding_ops` dataset
|
||||
option.
|
||||
* `tf.data.Options` were previously immutable and can now be overridden.
|
||||
* `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors
|
||||
with a new `output_signature` argument, which allows `from_generator` to
|
||||
produce any type describable by a `tf.TypeSpec`.
|
||||
* `tf.data.experimental.AUTOTUNE` is now available in the core API as
|
||||
`tf.data.AUTOTUNE`.
|
||||
* `tf.image`:
|
||||
* Added deterministic `tf.image.stateless_random_*` functions for each
|
||||
`tf.image.random_*` function. Added a new op
|
||||
`stateless_sample_distorted_bounding_box` which is a deterministic
|
||||
version of `sample_distorted_bounding_box` op. Given the same seed,
|
||||
these stateless functions/ops produce the same results independent of
|
||||
how many times the function is called, and independent of global seed
|
||||
settings.
|
||||
* `tf.distribute`:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* `tf.keras`:
|
||||
* Improvements from the functional API refactoring:
|
||||
* Functional model construction does not need to maintain a global workspace graph, removing memory leaks especially when building many models or very large models.
|
||||
* Functional model construction should be ~8-10% faster on average.
|
||||
* Functional models can now contain non-symbolic values in their call inputs inside of the first positional argument.
|
||||
* Several classes of TF ops that were not reliably converted to Keras layers during functional API construction should now work, e.g. `tf.image.ssim_multiscale`
|
||||
* Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
|
||||
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
|
||||
as an alternative to accepting a `callable` loss.
|
||||
* Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
|
||||
to match FTRL paper (https://research.google.com/pubs/archive/41159.pdf).
|
||||
* Added `mobilenet_v3` to keras application model.
|
||||
* `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
|
||||
customization of how gradients are aggregated across devices, as well as
|
||||
`gradients_transformers` to allow for custom gradient transformations
|
||||
(such as gradient clipping).
|
||||
* The `steps_per_execution` argument in `compile()` is no longer
|
||||
experimental; if you were passing `experimental_steps_per_execution`,
|
||||
rename it to `steps_per_execution` in your code. This argument controls
|
||||
the number of batches to run during each `tf.function` call when calling
|
||||
`fit()`. Running multiple batches inside a single `tf.function` call can
|
||||
greatly improve performance on TPUs or small models with a large Python
|
||||
overhead.
|
||||
* `tf.function` / AutoGraph:
|
||||
* Added `experimental_follow_type_hints` argument for `tf.function`. When
|
||||
True, the function may use type annotations to optimize the tracing
|
||||
performance.
|
||||
* Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
|
||||
* AutoGraph now allows creating new symbols inside a TensorFLow loop, if
|
||||
the values of these symbols at an iteration does not depend on the previous
|
||||
iteration. These types of loops must run at least one iteration, and will
|
||||
raise a runtime error otherwise.
|
||||
* (Experimental) Parameter server training:
|
||||
* Replaced the existing
|
||||
`tf.distribute.experimental.ParameterServerStrategy` symbol with
|
||||
a new class that is for parameter server training in TF2. Usage with
|
||||
the old symbol, usually with Estimator, should be replaced with
|
||||
`tf.compat.v1.distribute.experimental.ParameterServerStrategy`.
|
||||
* Added `tf.distribute.experimental.coordinator.*` namespace,
|
||||
including the main API `ClusterCoordinator` for coordinating the
|
||||
training cluster, the related data structure `RemoteValue`
|
||||
and `PerWorkerValue`.
|
||||
* `tf.keras`:
|
||||
* Improvements from the functional API refactoring:
|
||||
* Functional model construction does not need to maintain a global
|
||||
workspace graph, removing memory leaks especially when building many
|
||||
models or very large models.
|
||||
* Functional model construction should be ~8-10% faster on average.
|
||||
* Functional models can now contain non-symbolic values in their call
|
||||
inputs inside of the first positional argument.
|
||||
* Several classes of TF ops that were not reliably converted to Keras
|
||||
layers during functional API construction should now work, e.g.
|
||||
`tf.image.ssim_multiscale`
|
||||
* Error messages when Functional API construction goes wrong (and when
|
||||
ops cannot be converted to Keras layers automatically) should be
|
||||
clearer and easier to understand.
|
||||
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
|
||||
as an alternative to accepting a `callable` loss.
|
||||
* Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
|
||||
to match FTRL paper
|
||||
(https://research.google.com/pubs/archive/41159.pdf).
|
||||
* Added `mobilenet_v3` to keras application model.
|
||||
* `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
|
||||
customization of how gradients are aggregated across devices, as well as
|
||||
`gradients_transformers` to allow for custom gradient transformations
|
||||
(such as gradient clipping).
|
||||
* The `steps_per_execution` argument in `compile()` is no longer
|
||||
experimental; if you were passing `experimental_steps_per_execution`,
|
||||
rename it to `steps_per_execution` in your code. This argument controls
|
||||
the number of batches to run during each `tf.function` call when calling
|
||||
`fit()`. Running multiple batches inside a single `tf.function` call can
|
||||
greatly improve performance on TPUs or small models with a large Python
|
||||
overhead.
|
||||
* Improvements to Keras preprocessing layers:
|
||||
* TextVectorization can now accept a vocabulary list or file as an
|
||||
init arg.
|
||||
* TextVectorization, StringLookup, and IntegerLookup can now accept a
|
||||
vocabulary file via the `set_vocab_from_file` method.
|
||||
* Normalization can now accept mean and variance values as init args.
|
||||
* In `Attention` and `AdditiveAttention` layers, the `call()` method now
|
||||
accepts a `return_attention_scores` argument. When set to
|
||||
True, the layer returns the attention scores as an additional output
|
||||
argument.
|
||||
* Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints
|
||||
with the same implementation as their `tf.losses` equivalent.
|
||||
* For Keras model, the individual call of `Model.evaluate` uses no cached
|
||||
data for evaluation, while `Model.fit` uses cached data when
|
||||
`validation_data` arg is provided for better performance.
|
||||
* Added a `save_traces` argument to `model.save`/
|
||||
`tf.keras.models.save_model` which determines whether the SavedModel
|
||||
format stores the Keras model/layer call functions. The traced functions
|
||||
allow Keras to revive custom models and layers without the original
|
||||
class definition, but if this isn't required the tracing can be
|
||||
disabled with the added option.
|
||||
* `tf.function` / AutoGraph:
|
||||
* Added `experimental_follow_type_hints` argument for `tf.function`. When
|
||||
True, the function may use type annotations to optimize the tracing
|
||||
performance.
|
||||
* Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
|
||||
* AutoGraph now allows creating new symbols inside a TensorFLow loop, if
|
||||
the values of these symbols at an iteration does not depend on the
|
||||
previous iteration. These types of loops must run at least one
|
||||
iteration, and will raise a runtime error otherwise.
|
||||
|
||||
Example:
|
||||
|
||||
@ -200,51 +394,106 @@
|
||||
outputs = train_step(batch)
|
||||
tf.print('final outputs', outputs)
|
||||
```
|
||||
|
||||
See tensorflow/python/autograph/g3doc/reference/limitations.md for more
|
||||
info.
|
||||
|
||||
* `tf.lite`:
|
||||
* `DynamicBuffer::AddJoinedString()` will now add a separator if the first
|
||||
string to be joined is empty.
|
||||
* `TFLiteConverter`:
|
||||
* Support optional flags `inference_input_type` and `inference_output_type` for full integer quantized models. This allows users to modify the model input and output type to integer types (`tf.int8`, `tf.uint8`) instead of defaulting to float type (`tf.float32`).
|
||||
* Deprecate `Interpreter::UseNNAPI(bool)` C++ API
|
||||
* Prefer using `NnApiDelegate()` and related delegate configuration methods directly.
|
||||
* Add NNAPI Delegation support for requantization use cases by converting the operation into a dequantize-quantize pair.
|
||||
* TFLite Profiler for Android is available. See the detailed
|
||||
[guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android).
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
* `TFLiteConverter`:
|
||||
* Support optional flags `inference_input_type` and
|
||||
`inference_output_type` for full integer quantized models. This
|
||||
allows users to modify the model input and output type to integer
|
||||
types (`tf.int8`, `tf.uint8`) instead of defaulting to float type
|
||||
(`tf.float32`).
|
||||
* TFLite Profiler for Android is available. See the detailed
|
||||
[guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android).
|
||||
* NNAPI
|
||||
* Added NNAPI Delegation support for requantization use cases by
|
||||
converting the operation into a dequantize-quantize pair.
|
||||
* Removed deprecated `Interpreter.setUseNNAPI(boolean)` Java API.
|
||||
* Use `Interpreter.Options.setUseNNAPI` instead.
|
||||
* Deprecate `Interpreter::UseNNAPI(bool)` C++ API.
|
||||
* Use `NnApiDelegate()` and related delegate configuration methods
|
||||
directly.
|
||||
* Deprecate `Interpreter::SetAllowFp16PrecisionForFp32(bool)` C++ API
|
||||
* Prefer controlling this via delegate options, e.g.
|
||||
`tflite::StatefulNnApiDelegate::Options::allow_fp16' or
|
||||
`TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
|
||||
* `DynamicBuffer::AddJoinedString()` will now add a separator if the first
|
||||
string to be joined is empty.
|
||||
* Added support for cumulative sum (cumsum), both as builtin op and MLIR conversion.
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
* `tf.random`:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
* Math and Linear Algebra:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
* Add `tf.math.erfcinv`, the inverse to `tf.math.erfc`.
|
||||
|
||||
* TPU Enhancements:
|
||||
* Added support for the `beta` parameter of the FTRL optimizer for TPU
|
||||
embeddings. Users of other TensorFlow platforms can implement equivalent
|
||||
behavior by adjusting the `l2` parameter.
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
* Added support for the `beta` parameter of the FTRL optimizer for TPU
|
||||
embeddings. Users of other TensorFlow platforms can implement equivalent
|
||||
behavior by adjusting the `l2` parameter.
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
* XLA Support:
|
||||
* xla.experimental.compile is deprecated, use
|
||||
`tf.function(experimental_compile=True)` instead
|
||||
* Added `tf.function.experimental_get_compiler_ir` which returns compiler IR
|
||||
(currently 'hlo' and 'optimized_hlo') for given input for given function.
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
* xla.experimental.compile is deprecated, use
|
||||
`tf.function(experimental_compile=True)` instead
|
||||
* Added `tf.function.experimental_get_compiler_ir` which returns compiler
|
||||
IR (currently 'hlo' and 'optimized_hlo') for given input for given
|
||||
function.
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
* Tracing and Debugging:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
* `tf.train.Checkpoint`:
|
||||
* Now accepts a `root` argument in the initialization, which generates a
|
||||
checkpoint with a root object. This allows users to create a `Checkpoint`
|
||||
object that is compatible with Keras `model.save_weights()` and
|
||||
`model.load_weights`. The checkpoint is also compatible with the
|
||||
checkpoint saved in the `variables/` folder in the SavedModel.
|
||||
* When restoring, `save_path` can be a path to a SavedModel. The function
|
||||
will automatically find the checkpoint in the SavedModel.
|
||||
|
||||
* Now accepts a `root` argument in the initialization, which generates a
|
||||
checkpoint with a root object. This allows users to create a
|
||||
`Checkpoint` object that is compatible with Keras `model.save_weights()`
|
||||
and `model.load_weights`. The checkpoint is also compatible with the
|
||||
checkpoint saved in the `variables/` folder in the SavedModel.
|
||||
* When restoring, `save_path` can be a path to a SavedModel. The function
|
||||
will automatically find the checkpoint in the SavedModel.
|
||||
|
||||
* `tf.nn`:
|
||||
* `tf.nn.max_pool2d` now supports explicit padding.
|
||||
|
||||
* `tf.nn.max_pool2d` now supports explicit padding.
|
||||
|
||||
* `tf.debugging`:
|
||||
|
||||
* `tf.debugging.assert_shapes()` now works on `SparseTensor`s (#36268).
|
||||
|
||||
* `tf.print`:
|
||||
|
||||
* Bug fix in `tf.print()` with `OrderedDict` where if an `OrderedDict`
|
||||
didn't have the keys sorted, the keys and values were not being printed
|
||||
in accordance with their correct mapping.
|
||||
|
||||
* `TensorRT`
|
||||
|
||||
* We now issue a warning when the `session_config` parameter for the TF1
|
||||
converter is used or the `rewrite_config_template` field in the TF2
|
||||
converter parameter object is used.
|
||||
|
||||
* Other:
|
||||
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
|
||||
and "denylist" where possible. Please see
|
||||
https://developers.google.com/style/word-list#blacklist for more context.
|
||||
<ADD RELEASE NOTES HERE>
|
||||
|
||||
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
|
||||
and "denylist" where possible. Please see
|
||||
https://developers.google.com/style/word-list#blacklist for more
|
||||
context.
|
||||
* Add `tf.config.experimental.mlir_bridge_rollout` which will help us
|
||||
rollout the new MLIR TPU bridge.
|
||||
* Added `tf.experimental.register_filesystem_plugin` to load modular
|
||||
filesystem plugins from Python
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
## Thanks to our Contributors
|
||||
|
||||
@ -492,42 +741,87 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
|
||||
# Release 2.3.0
|
||||
|
||||
## Major Features and Improvements
|
||||
* `tf.data` adds two new mechanisms to solve input pipeline bottlenecks and save resources:
|
||||
* [snapshot](https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot)
|
||||
* [tf.data service](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service).
|
||||
|
||||
In addition checkout the detailed [guide](https://www.tensorflow.org/guide/data_performance_analysis) for analyzing input pipeline performance with TF Profiler.
|
||||
* `tf.data` adds two new mechanisms to solve input pipeline bottlenecks and
|
||||
save resources:
|
||||
|
||||
* [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy) is now a stable API and no longer considered experimental for TensorFlow. (earlier `tf.distribute.experimental.TPUStrategy`).
|
||||
* [snapshot](https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot)
|
||||
* [tf.data service](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service).
|
||||
|
||||
* [TF Profiler](https://www.tensorflow.org/guide/profiler) introduces two new tools: a memory profiler to visualize your model’s memory usage over time and a [python tracer](https://www.tensorflow.org/guide/profiler#events) which allows you to trace python function calls in your model. Usability improvements include better diagnostic messages and [profile options](https://tensorflow.org/guide/profiler#collect_performance_data) to customize the host and device trace verbosity level.
|
||||
In addition checkout the detailed
|
||||
[guide](https://www.tensorflow.org/guide/data_performance_analysis) for
|
||||
analyzing input pipeline performance with TF Profiler.
|
||||
|
||||
* Introduces experimental support for Keras Preprocessing Layers API ([`tf.keras.layers.experimental.preprocessing.*`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing?version=nightly)) to handle data preprocessing operations, with support for composite tensor inputs. Please see below for additional details on these layers.
|
||||
* [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy)
|
||||
is now a stable API and no longer considered experimental for TensorFlow.
|
||||
(earlier `tf.distribute.experimental.TPUStrategy`).
|
||||
|
||||
* TFLite now properly supports dynamic shapes during conversion and inference. We’ve also added opt-in support on Android and iOS for [XNNPACK](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack), a highly optimized set of CPU kernels, as well as opt-in support for [executing quantized models on the GPU](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu_advanced.md#running-quantized-models-experimental).
|
||||
* [TF Profiler](https://www.tensorflow.org/guide/profiler) introduces two new
|
||||
tools: a memory profiler to visualize your model’s memory usage over time
|
||||
and a [python tracer](https://www.tensorflow.org/guide/profiler#events)
|
||||
which allows you to trace python function calls in your model. Usability
|
||||
improvements include better diagnostic messages and
|
||||
[profile options](https://tensorflow.org/guide/profiler#collect_performance_data)
|
||||
to customize the host and device trace verbosity level.
|
||||
|
||||
* Libtensorflow packages are available in GCS starting this release. We have also started to [release a nightly version of these packages](https://github.com/tensorflow/tensorflow#official-builds).
|
||||
* Introduces experimental support for Keras Preprocessing Layers API
|
||||
([`tf.keras.layers.experimental.preprocessing.*`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing?version=nightly))
|
||||
to handle data preprocessing operations, with support for composite tensor
|
||||
inputs. Please see below for additional details on these layers.
|
||||
|
||||
* The experimental Python API [`tf.debugging.experimental.enable_dump_debug_info()`](https://www.tensorflow.org/api_docs/python/tf/debugging/experimental/enable_dump_debug_info) now allows you to instrument a TensorFlow program and dump debugging information to a directory on the file system. The directory can be read and visualized by a new interactive dashboard in TensorBoard 2.3 called [Debugger V2](https://www.tensorflow.org/tensorboard/debugger_v2), which reveals the details of the TensorFlow program including graph structures, history of op executions at the Python (eager) and intra-graph levels, the runtime dtype, shape, and numerical composistion of tensors, as well as their code locations.
|
||||
* TFLite now properly supports dynamic shapes during conversion and inference.
|
||||
We’ve also added opt-in support on Android and iOS for
|
||||
[XNNPACK](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack),
|
||||
a highly optimized set of CPU kernels, as well as opt-in support for
|
||||
[executing quantized models on the GPU](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu_advanced.md#running-quantized-models-experimental).
|
||||
|
||||
* Libtensorflow packages are available in GCS starting this release. We have
|
||||
also started to
|
||||
[release a nightly version of these packages](https://github.com/tensorflow/tensorflow#official-builds).
|
||||
|
||||
* The experimental Python API
|
||||
[`tf.debugging.experimental.enable_dump_debug_info()`](https://www.tensorflow.org/api_docs/python/tf/debugging/experimental/enable_dump_debug_info)
|
||||
now allows you to instrument a TensorFlow program and dump debugging
|
||||
information to a directory on the file system. The directory can be read and
|
||||
visualized by a new interactive dashboard in TensorBoard 2.3 called
|
||||
[Debugger V2](https://www.tensorflow.org/tensorboard/debugger_v2), which
|
||||
reveals the details of the TensorFlow program including graph structures,
|
||||
history of op executions at the Python (eager) and intra-graph levels, the
|
||||
runtime dtype, shape, and numerical composition of tensors, as well as their
|
||||
code locations.
|
||||
|
||||
## Breaking Changes
|
||||
* Increases the **minimum bazel version** required to build TF to **3.1.0**.
|
||||
* `tf.data`
|
||||
* Makes the following (breaking) changes to the `tf.data`.
|
||||
* C++ API: - `IteratorBase::RestoreInternal`, `IteratorBase::SaveInternal`, and `DatasetBase::CheckExternalState` become pure-virtual and subclasses are now expected to provide an implementation.
|
||||
* The deprecated `DatasetBase::IsStateful` method is removed in favor of `DatasetBase::CheckExternalState`.
|
||||
* Deprecated overrides of `DatasetBase::MakeIterator` and `MakeIteratorFromInputElement` are removed.
|
||||
* The signature of `tensorflow::data::IteratorBase::SaveInternal` and `tensorflow::data::IteratorBase::SaveInput` has been extended with `SerializationContext` argument to enable overriding the default policy for the handling external state during iterator checkpointing. This is not a backwards compatible change and all subclasses of `IteratorBase` *need to be updated* accordingly.
|
||||
* `tf.keras`
|
||||
* Add a new `BackupAndRestore` callback for handling distributed training failures & restarts. Please take a look at this [tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for details on how to use the callback.
|
||||
* `tf.image.extract_glimpse` has been updated to correctly process the case
|
||||
where `centered=False` and `normalized=False`. This is a breaking change as
|
||||
the output is different from (incorrect) previous versions. Note this
|
||||
breaking change only impacts `tf.image.extract_glimpse` and
|
||||
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
|
||||
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
|
||||
exsiting C++ kernel `ExtractGlimpse` does not change either, so saved
|
||||
models using `tf.raw_ops.ExtractGlimpse` will not be impacted.
|
||||
|
||||
* Increases the **minimum bazel version** required to build TF to **3.1.0**.
|
||||
* `tf.data`
|
||||
* Makes the following (breaking) changes to the `tf.data`.
|
||||
* C++ API: - `IteratorBase::RestoreInternal`,
|
||||
`IteratorBase::SaveInternal`, and `DatasetBase::CheckExternalState`
|
||||
become pure-virtual and subclasses are now expected to provide an
|
||||
implementation.
|
||||
* The deprecated `DatasetBase::IsStateful` method is removed in favor of
|
||||
`DatasetBase::CheckExternalState`.
|
||||
* Deprecated overrides of `DatasetBase::MakeIterator` and
|
||||
`MakeIteratorFromInputElement` are removed.
|
||||
* The signature of `tensorflow::data::IteratorBase::SaveInternal` and
|
||||
`tensorflow::data::IteratorBase::SaveInput` has been extended with
|
||||
`SerializationContext` argument to enable overriding the default policy
|
||||
for the handling external state during iterator checkpointing. This is
|
||||
not a backwards compatible change and all subclasses of `IteratorBase`
|
||||
*need to be updated* accordingly.
|
||||
* `tf.keras`
|
||||
* Add a new `BackupAndRestore` callback for handling distributed training
|
||||
failures & restarts. Please take a look at this
|
||||
[tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
|
||||
for details on how to use the callback.
|
||||
* `tf.image.extract_glimpse` has been updated to correctly process the case
|
||||
where `centered=False` and `normalized=False`. This is a breaking change as
|
||||
the output is different from (incorrect) previous versions. Note this
|
||||
breaking change only impacts `tf.image.extract_glimpse` and
|
||||
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
|
||||
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
|
||||
existing C++ kernel `ExtractGlimpse` does not change either, so saved models
|
||||
using `tf.raw_ops.ExtractGlimpse` will not be impacted.
|
||||
|
||||
## Known Caveats
|
||||
* `tf.lite`
|
||||
@ -571,6 +865,7 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
|
||||
* Add `tf.saved_model.LoadOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/LoadOptions?hl=en) as arg with default value `None` to choose the I/O device for loading models and weights.
|
||||
* Update `tf.saved_model.SaveOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/SaveOptions?hl=en) as arg with default value `None` to choose the I/O device for saving models and weights.
|
||||
* Mutable tables now restore checkpointed values when loaded from SavedModel.
|
||||
* The user object metadata field in the SavedModel proto has been deprecated as part of the updates to Keras SavedModel. Keras was the only consumer of this field prior to the update.
|
||||
* GPU
|
||||
* TF 2.3 includes PTX kernels only for [compute capability](https://developer.nvidia.com/cuda-gpus) 7.0 to reduce the TF pip binary size. Earlier releases included PTX for a variety of older compute capabilities.
|
||||
* Remove environmental variable `TF_USE_CUDNN`.
|
||||
@ -599,6 +894,7 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
|
||||
* Fix the issue that `strategy.reduce()` inside `tf.function` may raise exceptions when the values to reduce are from loops or if-clauses.
|
||||
* Fix the issue that `tf.distribute.MirroredStrategy` cannot be used together with `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
|
||||
* Add a `tf.distribute.cluster_resolver.TPUClusterResolver.connect` API to simplify TPU initialization.
|
||||
* Add `tf.distribute.Strategy.gather` and `tf.distribute.ReplicaContext.all_gather` methods to gather and concatenate `tf.distribute.DistributedValues` across workers and devices.
|
||||
|
||||
### `tf.keras`:
|
||||
* Introduces experimental preprocessing layers API (`tf.keras.layers.experimental.preprocessing`) to handle data preprocessing operations such as categorical feature encoding, text vectorization, data normalization, and data discretization (binning). The newly added layers provide a replacement for the legacy feature column API, and support composite tensor inputs.
|
||||
@ -1097,7 +1393,7 @@ This release contains contributions from many people at Google, as well as:
|
||||
8bitmp3, Aaron Ma, AbdüLhamit Yilmaz, Abhai Kollara, aflc, Ag Ramesh, Albert Z. Guo, Alex Torres, amoitra, Andrii Prymostka, angeliand, Anshuman Tripathy, Anthony Barbier, Anton Kachatkou, Anubh-V, Anuja Jakhade, Artem Ryabov, autoih, Bairen Yi, Bas Aarts, Basit Ayantunde, Ben Barsdell, Bhavani Subramanian, Brett Koonce, candy.dc, Captain-Pool, caster, cathy, Chong Yan, Choong Yin Thong, Clayne Robison, Colle, Dan Ganea, David Norman, David Refaeli, dengziming, Diego Caballero, Divyanshu, djshen, Douman, Duncan Riach, EFanZh, Elena Zhelezina, Eric Schweitz, Evgenii Zheltonozhskii, Fei Hu, fo40225, Fred Reiss, Frederic Bastien, Fredrik Knutsson, fsx950223, fwcore, George Grzegorz Pawelczak, George Sterpu, Gian Marco Iodice, Giorgio Arena, giuros01, Gomathi Ramamurthy, Guozhong Zhuang, Haifeng Jin, Haoyu Wu, HarikrishnanBalagopal, HJYOO, Huang Chen-Yi, Ilham Firdausi Putra, Imran Salam, Jared Nielsen, Jason Zaman, Jasper Vicenti, Jeff Daily, Jeff Poznanovic, Jens Elofsson, Jerry Shih, jerryyin, Jesper Dramsch, jim.meyer, Jongwon Lee, Jun Wan, Junyuan Xie, Kaixi Hou, kamalkraj, Kan Chen, Karthik Muthuraman, Keiji Ariyama, Kevin Rose, Kevin Wang, Koan-Sin Tan, kstuedem, Kwabena W. Agyeman, Lakshay Tokas, latyas, Leslie-Fang-Intel, Li, Guizi, Luciano Resende, Lukas Folle, Lukas Geiger, Mahmoud Abuzaina, Manuel Freiberger, Mark Ryan, Martin Mlostek, Masaki Kozuki, Matthew Bentham, Matthew Denton, mbhuiyan, mdfaijul, Muhwan Kim, Nagy Mostafa, nammbash, Nathan Luehr, Nathan Wells, Niranjan Hasabnis, Oleksii Volkovskyi, Olivier Moindrot, olramde, Ouyang Jin, OverLordGoldDragon, Pallavi G, Paul Andrey, Paul Wais, pkanwar23, Pooya Davoodi, Prabindh Sundareson, Rajeshwar Reddy T, Ralovich, Kristof, Refraction-Ray, Richard Barnes, richardbrks, Robert Herbig, Romeo Kienzler, Ryan Mccormick, saishruthi, Saket Khandelwal, Sami Kama, Sana Damani, Satoshi Tanaka, Sergey Mironov, Sergii Khomenko, Shahid, Shawn Presser, ShengYang1, Siddhartha Bagaria, Simon Plovyt, skeydan, srinivasan.narayanamoorthy, Stephen Mugisha, sunway513, Takeshi Watanabe, Taylor Jakobson, TengLu, TheMindVirus, ThisIsIsaac, Tim Gates, Timothy Liu, Tomer Gafner, Trent Lo, Trevor Hickey, Trevor Morris, vcarpani, Wei Wang, Wen-Heng (Jack) Chung, wenshuai, Wenshuai-Xiaomi, wenxizhu, william, William D. Irons, Xinan Jiang, Yannic, Yasir Modak, Yasuhiro Matsumoto, Yong Tang, Yongfeng Gu, Youwei Song, Zaccharie Ramzi, Zhang, Zhenyu Guo, 王振华 (Zhenhua Wang), 韩董, 이중건 Isaac Lee
|
||||
|
||||
# Release 1.15.0
|
||||
This is the last 1.x release for TensorFlow. We do not expect to update the 1.x branch with features, although we will issue patch releases to fix vulnerabilities for at least one year.
|
||||
This is the last 1.x release for TensorFlow. We do not expect to update the 1.x branch with features, although we will issue patch releases to fix vulnerabilities for at least one year.
|
||||
|
||||
## Major Features and Improvements
|
||||
* As [announced](https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0), `tensorflow` pip package will by default include GPU support (same as `tensorflow-gpu` now) for the platforms we currently have GPU support (Linux and Windows). It will work on machines with and without Nvidia GPUs. `tensorflow-gpu` will still be available, and CPU-only packages can be downloaded at `tensorflow-cpu` for users who are concerned about package size.
|
||||
@ -1107,7 +1403,7 @@ This enables writing forward compatible code: by explicitly importing either `te
|
||||
* Add toggles `tf.enable_control_flow_v2()` and `tf.disable_control_flow_v2()` for enabling/disabling v2 control flow.
|
||||
* Enable v2 control flow as part of `tf.enable_v2_behavior()` and `TF2_BEHAVIOR=1`.
|
||||
* AutoGraph translates Python control flow into TensorFlow expressions, allowing users to write regular Python inside `tf.function`-decorated functions. AutoGraph is also applied in functions used with `tf.data`, `tf.distribute` and `tf.keras` APIS.
|
||||
* Adds `enable_tensor_equality()`, which switches the behavior such that:
|
||||
* Adds `enable_tensor_equality()`, which switches the behavior such that:
|
||||
* Tensors are no longer hashable.
|
||||
* Tensors can be compared with `==` and `!=`, yielding a Boolean Tensor with element-wise comparison results. This will be the default behavior in 2.0.
|
||||
|
||||
@ -1263,12 +1559,12 @@ For information on upgrading your existing TensorFlow 1.x models, please refer t
|
||||
* TensorFlow 2.0.0 is built using devtoolset7 (GCC7) on Ubuntu 16. This may lead to ABI incompatibilities with extensions built against earlier versions of TensorFlow.
|
||||
* Tensorflow code now produces 2 different pip packages: tensorflow_core containing all the code (in the future it will contain only the private implementation) and tensorflow which is a virtual pip package doing forwarding to tensorflow_core (and in the future will contain only the public API of tensorflow). We don't expect this to be breaking, unless you were importing directly from the implementation.
|
||||
Removed the `freeze_graph` command line tool; `SavedModel` should be used in place of frozen graphs.
|
||||
|
||||
|
||||
* `tf.contrib`:
|
||||
* `tf.contrib` has been deprecated, and functionality has been either migrated to the core TensorFlow API, to an ecosystem project such as [tensorflow/addons](https://www.github.com/tensorflow/addons) or [tensorflow/io](https://www.github.com/tensorflow/io), or removed entirely.
|
||||
* Remove `tf.contrib.timeseries` dependency on TF distributions.
|
||||
* Replace contrib references with `tf.estimator.experimental.*` for apis in `early_stopping.py`.
|
||||
|
||||
|
||||
* `tf.estimator`:
|
||||
* Premade estimators in the tf.estimator.DNN/Linear/DNNLinearCombined family have been updated to use `tf.keras.optimizers` instead of the `tf.compat.v1.train.Optimizer`s. If you do not pass in an `optimizer=` arg or if you use a string, the premade estimator will use the Keras optimizer. This is checkpoint breaking, as the optimizers have separate variables. A checkpoint converter tool for converting optimizers is included with the release, but if you want to avoid any change, switch to the v1 version of the estimator: `tf.compat.v1.estimator.DNN/Linear/DNNLinearCombined*`.
|
||||
* Default aggregation for canned Estimators is now `SUM_OVER_BATCH_SIZE`. To maintain previous default behavior, please pass `SUM` as the loss aggregation method.
|
||||
@ -1276,13 +1572,13 @@ For information on upgrading your existing TensorFlow 1.x models, please refer t
|
||||
* `Estimator.export_savedmodel` has been renamed to `export_saved_model`.
|
||||
* When saving to SavedModel, Estimators will strip default op attributes. This is almost always the correct behavior, as it is more forwards compatible, but if you require that default attributes to be saved with the model, please use `tf.compat.v1.Estimator`.
|
||||
* Feature Columns have been upgraded to be more Eager-friendly and to work with Keras. As a result, `tf.feature_column.input_layer` has been deprecated in favor of `tf.keras.layers.DenseFeatures`. v1 feature columns have direct analogues in v2 except for `shared_embedding_columns`, which are not cross-compatible with v1 and v2. Use `tf.feature_column.shared_embeddings` instead.
|
||||
|
||||
|
||||
* `tf.keras`:
|
||||
* `OMP_NUM_THREADS` is no longer used by the default Keras config. To configure the number of threads, use `tf.config.threading` APIs.
|
||||
* `tf.keras.model.save_model` and `model.save` now defaults to saving a TensorFlow SavedModel. HDF5 files are still supported.
|
||||
* Deprecated `tf.keras.experimental.export_saved_model` and `tf.keras.experimental.function`. Please use `tf.keras.models.save_model(..., save_format='tf')` and `tf.keras.models.load_model` instead.
|
||||
* Layers now default to float32, and automatically cast their inputs to the layer's dtype. If you had a model that used float64, it will probably silently use float32 in TensorFlow 2, and a warning will be issued that starts with `Layer <layer-name>` is casting an input tensor from dtype float64 to the layer's dtype of float32. To fix, either set the default dtype to float64 with `tf.keras.backend.set_floatx('float64')`, or pass `dtype='float64'` to each of the Layer constructors. See `tf.keras.layers.Layer` for more information.
|
||||
|
||||
|
||||
* `tf.lite`:
|
||||
* Removed `lite.OpHint`, `lite.experimental`, and `lite.constant` from 2.0 API.
|
||||
* Tensors are no longer hashable, but instead compare element-wise with `==` and `!=`. Use `tf.compat.v1.disable_tensor_equality()` to return to the previous behavior.
|
||||
@ -1517,8 +1813,8 @@ If you experience any snags when using TF 2.0, please let us know at the [TF 2.0
|
||||
conversion. TensorRT initialization arguments are now passed wrapped in
|
||||
a named-tuple, `TrtConversionParams`, rather than as separate arguments
|
||||
as in `TrtGraphConverter`.
|
||||
* Changed API to optimize TensorRT enginges during graph optimization.
|
||||
This is now done by calling `converter.build()` where previously
|
||||
* Changed API to optimize TensorRT engines during graph optimization. This
|
||||
is now done by calling `converter.build()` where previously
|
||||
`is_dynamic_op=False` would be set.
|
||||
* `converter.convert()` no longer returns a `tf.function`. Now the
|
||||
function must be accessed from the saved model.
|
||||
@ -2528,7 +2824,7 @@ Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, A
|
||||
* [`tf.contrib.estimator.RNNEstimator`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/estimator/RNNClassifier)
|
||||
* The [distributions.Bijector](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/distributions/bijectors/Bijector)
|
||||
API supports broadcasting for Bijectors with new API changes.
|
||||
|
||||
|
||||
## Breaking Changes
|
||||
* If you're opening empty variable scopes; replace `variable_scope('', ...)` by
|
||||
`variable_scope(tf.get_variable_scope(), ...)`.
|
||||
@ -3007,7 +3303,7 @@ Samuel He, Sandeep Dcunha, sandipmgiri, Sang Han, scott, Scott Mudge, Se-Won Kim
|
||||
Simone Cirillo, Steffen Schmitz, Suvojit Manna, Sylvus, Taehoon Lee, Ted Chang, Thomas Deegan,
|
||||
Till Hoffmann, Tim, Toni Kunic, Toon Verstraelen, Tristan Rice, Urs KöSter, Utkarsh Upadhyay,
|
||||
Vish (Ishaya) Abrams, Winnie Tsang, Yan Chen, Yan Facai (颜发才), Yi Yang, Yong Tang,
|
||||
Youssef Hesham, Yuan (Terry) Tang, Zhengsheng Wei, zxcqwe4906, 张志豪, 田传武
|
||||
Youssef Hesham, Yuan (Terry) Tang, Zhengsheng Wei, zxcqwe4906, 张志豪, 田传武
|
||||
|
||||
We are also grateful to all who filed issues or helped resolve them, asked and
|
||||
answered questions, and were part of inspiring discussions.
|
||||
|
11
configure.py
11
configure.py
@ -1163,12 +1163,9 @@ def set_system_libs_flag(environ_cp):
|
||||
syslibs = ','.join(sorted(syslibs.split()))
|
||||
write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs)
|
||||
|
||||
if 'PREFIX' in environ_cp:
|
||||
write_to_bazelrc('build --define=PREFIX=%s' % environ_cp['PREFIX'])
|
||||
if 'LIBDIR' in environ_cp:
|
||||
write_to_bazelrc('build --define=LIBDIR=%s' % environ_cp['LIBDIR'])
|
||||
if 'INCLUDEDIR' in environ_cp:
|
||||
write_to_bazelrc('build --define=INCLUDEDIR=%s' % environ_cp['INCLUDEDIR'])
|
||||
for varname in ('PREFIX', 'LIBDIR', 'INCLUDEDIR', 'PROTOBUF_INCLUDE_PATH'):
|
||||
if varname in environ_cp:
|
||||
write_to_bazelrc('build --define=%s=%s' % (varname, environ_cp[varname]))
|
||||
|
||||
|
||||
def is_reduced_optimize_huge_functions_available(environ_cp):
|
||||
@ -1485,8 +1482,8 @@ def main():
|
||||
'adding "--config=<>" to your build command. See .bazelrc for more '
|
||||
'details.')
|
||||
config_info_line('mkl', 'Build with MKL support.')
|
||||
config_info_line('mkl_aarch64', 'Build with oneDNN support for Aarch64.')
|
||||
config_info_line('monolithic', 'Config for mostly static monolithic build.')
|
||||
config_info_line('ngraph', 'Build with Intel nGraph support.')
|
||||
config_info_line('numa', 'Build with NUMA support.')
|
||||
config_info_line(
|
||||
'dynamic_kernels',
|
||||
|
@ -3,6 +3,7 @@
|
||||
# learning applications.
|
||||
|
||||
load("@bazel_skylib//lib:selects.bzl", "selects")
|
||||
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
|
||||
load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
|
||||
load(
|
||||
"//tensorflow/core/platform:build_config.bzl",
|
||||
@ -22,10 +23,6 @@ load(
|
||||
"//tensorflow/python/tools/api/generator:api_init_files_v1.bzl",
|
||||
"TENSORFLOW_API_INIT_FILES_V1", # @unused
|
||||
)
|
||||
load(
|
||||
"//third_party/ngraph:build_defs.bzl",
|
||||
"if_ngraph",
|
||||
)
|
||||
load(
|
||||
"//third_party/mkl:build_defs.bzl",
|
||||
"if_mkl_ml",
|
||||
@ -238,6 +235,12 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "linux_mips64",
|
||||
values = {"cpu": "mips64"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "debug",
|
||||
values = {
|
||||
@ -465,14 +468,6 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# This flag is set from the configure step when the user selects with nGraph option.
|
||||
# By default it should be false
|
||||
config_setting(
|
||||
name = "with_ngraph_support",
|
||||
values = {"define": "with_ngraph_support=true"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# This flag specifies whether TensorFlow 2.0 API should be built instead
|
||||
# of 1.* API. Note that TensorFlow 2.0 API is currently under development.
|
||||
config_setting(
|
||||
@ -497,13 +492,20 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# This flag enables experimental MLIR bridge support.
|
||||
# This flag forcibly enables experimental MLIR bridge support.
|
||||
config_setting(
|
||||
name = "enable_mlir_bridge",
|
||||
values = {"define": "enable_mlir_bridge=true"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# This flag forcibly disables experimental MLIR bridge support.
|
||||
config_setting(
|
||||
name = "disable_mlir_bridge",
|
||||
values = {"define": "enable_mlir_bridge=false"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# This flag enables experimental TPU support
|
||||
config_setting(
|
||||
name = "with_tpu_support",
|
||||
@ -556,35 +558,50 @@ selects.config_setting_group(
|
||||
],
|
||||
)
|
||||
|
||||
# 'enable_registration_v2' opts-in to a different implementation of op and
|
||||
# kernel registration - REGISTER_OP, REGISTER_KERNEL_BUILDER, etc.
|
||||
#
|
||||
# This setting is currently experimental. The 'v2' implementation does _not_
|
||||
# correspond to a particular, finalized design; rather, it relates to
|
||||
# developing one.
|
||||
#
|
||||
# The current aim of the 'v2' implementation is to allow 'unused' ops and
|
||||
# kernels to be discarded by the linker (to the benefit of binary size).
|
||||
bool_flag(
|
||||
name = "enable_registration_v2",
|
||||
build_setting_default = False,
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "registration_v1",
|
||||
flag_values = {":enable_registration_v2": "False"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "registration_v2",
|
||||
flag_values = {":enable_registration_v2": "True"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
|
||||
# Instead, please use public APIs or public build rules TF provides.
|
||||
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = [
|
||||
"//learning/brain/distribute/...",
|
||||
"//learning/brain/swift/x10/...",
|
||||
"//perftools/accelerators/xprof/api/...",
|
||||
"//learning/lib/ami/simple_ml/...",
|
||||
"//tensorflow/...",
|
||||
"//tensorflow_estimator/python/estimator/...",
|
||||
"//tensorflow_models/official/...",
|
||||
"//third_party/py/autograph/...",
|
||||
"//third_party/swift/tensorflow/x10/...",
|
||||
"//third_party/swift/tensorflow_apis/...",
|
||||
],
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "ndarray_tensor_allow_list",
|
||||
packages = ["//learning/pathways/..."],
|
||||
)
|
||||
package_group(name = "ndarray_tensor_allow_list")
|
||||
|
||||
# Packages that use private types symbols, until they are exported.
|
||||
# TODO(b/154650521) Remove.
|
||||
package_group(
|
||||
name = "types_whitelist",
|
||||
packages = ["//learning/deepmind/tensorflow/replicator/..."],
|
||||
)
|
||||
# If this is modified, then copy.bara.sky must also be modified.
|
||||
package_group(name = "types_whitelist")
|
||||
|
||||
# Packages that use StructuredTensors.
|
||||
# TODO(b/159007891) Remove this package once StructuredTensor is exported.
|
||||
@ -610,7 +627,7 @@ bzl_library(
|
||||
"//tensorflow/core/platform/default:cuda_build_defs_bzl",
|
||||
"//third_party/mkl:build_defs_bzl",
|
||||
"//third_party/mkl_dnn:build_defs_bzl",
|
||||
"//third_party/ngraph:build_defs_bzl",
|
||||
"@bazel_skylib//rules:common_settings",
|
||||
"@local_config_cuda//cuda:build_defs_bzl",
|
||||
"@local_config_rocm//rocm:build_defs_bzl",
|
||||
"@local_config_tensorrt//:build_defs_bzl",
|
||||
@ -710,8 +727,12 @@ tf_cc_shared_object(
|
||||
soversion = VERSION,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
|
||||
"//tensorflow/c:kernels_hdrs",
|
||||
"//tensorflow/c:ops_hdrs",
|
||||
"//tensorflow/cc/saved_model:loader_lite_impl",
|
||||
"//tensorflow/core:core_cpu_impl",
|
||||
"//tensorflow/core/common_runtime:core_cpu_impl",
|
||||
"//tensorflow/core:framework_internal_impl",
|
||||
"//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
|
||||
@ -813,7 +834,7 @@ tf_cc_shared_object(
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/core:tensorflow",
|
||||
] + if_ngraph(["@ngraph_tf//:ngraph_tf"]),
|
||||
],
|
||||
)
|
||||
|
||||
# ** Targets for Windows build (start) **
|
||||
|
@ -138,12 +138,12 @@ if _running_from_pip_package():
|
||||
for _s in _site_packages_dirs:
|
||||
# Load first party dynamic kernels.
|
||||
_main_dir = _os.path.join(_s, 'tensorflow/core/kernels')
|
||||
if _fi.file_exists(_main_dir):
|
||||
if _os.path.exists(_main_dir):
|
||||
_ll.load_library(_main_dir)
|
||||
|
||||
# Load third party dynamic kernels.
|
||||
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
|
||||
if _fi.file_exists(_plugin_dir):
|
||||
if _os.path.exists(_plugin_dir):
|
||||
_ll.load_library(_plugin_dir)
|
||||
|
||||
# Add module aliases
|
||||
|
@ -148,12 +148,12 @@ if _running_from_pip_package():
|
||||
for _s in _site_packages_dirs:
|
||||
# Load first party dynamic kernels.
|
||||
_main_dir = _os.path.join(_s, 'tensorflow/core/kernels')
|
||||
if _fi.file_exists(_main_dir):
|
||||
if _os.path.exists(_main_dir):
|
||||
_ll.load_library(_main_dir)
|
||||
|
||||
# Load third party dynamic kernels.
|
||||
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
|
||||
if _fi.file_exists(_plugin_dir):
|
||||
if _os.path.exists(_plugin_dir):
|
||||
_ll.load_library(_plugin_dir)
|
||||
|
||||
# Delete modules that should be hidden from dir().
|
||||
|
@ -202,6 +202,7 @@ tf_cuda_library(
|
||||
":tf_status",
|
||||
":tf_tensor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"//tensorflow/c/experimental/filesystem:modular_filesystem",
|
||||
"//tensorflow/cc/saved_model:loader_lite",
|
||||
"//tensorflow/cc:gradients",
|
||||
"//tensorflow/cc:ops",
|
||||
@ -217,6 +218,8 @@ tf_cuda_library(
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
"//tensorflow/core/kernels:logging_ops",
|
||||
"//tensorflow/compiler/mlir/tfr:node_expansion_pass",
|
||||
"//tensorflow/compiler/mlir/tfr:graph_decompose_pass",
|
||||
],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
@ -254,6 +257,30 @@ tf_cuda_library(
|
||||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_shape",
|
||||
srcs = ["tf_shape.cc"],
|
||||
hdrs = ["tf_shape.h"],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":c_api_macros",
|
||||
":tf_shape_internal",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_shape_internal",
|
||||
hdrs = ["tf_shape_internal.h"],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":conversion_macros",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_status",
|
||||
srcs = ["tf_status.cc"],
|
||||
@ -485,6 +512,18 @@ tf_cuda_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "kernels_hdrs",
|
||||
hdrs = ["kernels.h"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":c_api_internal",
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
":tf_tensor",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "kernels",
|
||||
srcs = [
|
||||
@ -538,6 +577,16 @@ tf_cuda_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ops_hdrs",
|
||||
hdrs = ["ops.h"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tests
|
||||
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/platform.h" // NOLINT
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
|
||||
#include "tensorflow/cc/framework/gradients.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/framework/scope_internal.h"
|
||||
@ -2488,6 +2489,48 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
|
||||
TF_Status* status) {
|
||||
using tensorflow::RecordMutation;
|
||||
mutex_lock l(graph->mu);
|
||||
tensorflow::shape_inference::InferenceContext* ic =
|
||||
graph->refiner.GetContext(&new_src.oper->node);
|
||||
|
||||
if (ic->num_outputs() <= new_src.index) {
|
||||
status->status = tensorflow::errors::OutOfRange(
|
||||
"Cannot update edge. Output index [", new_src.index,
|
||||
"] is greater than the number of total outputs [", ic->num_outputs(),
|
||||
"].");
|
||||
return;
|
||||
}
|
||||
tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
|
||||
|
||||
tensorflow::shape_inference::InferenceContext* ic_dst =
|
||||
graph->refiner.GetContext(&dst.oper->node);
|
||||
if (ic_dst->num_inputs() <= dst.index) {
|
||||
status->status = tensorflow::errors::OutOfRange(
|
||||
"Cannot update edge. Input index [", dst.index,
|
||||
"] is greater than the number of total inputs [", ic_dst->num_inputs(),
|
||||
"].");
|
||||
return;
|
||||
}
|
||||
if (!ic_dst->MergeInput(dst.index, shape)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
|
||||
" and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
|
||||
return;
|
||||
}
|
||||
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
|
||||
&dst.oper->node, dst.index);
|
||||
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
// This modification only updates the destination node for
|
||||
// the purposes of running this graph in a session. Thus, we don't
|
||||
// record the source node as being modified.
|
||||
RecordMutation(graph, *dst.oper, "updating input tensor");
|
||||
}
|
||||
}
|
||||
|
||||
// TF_Server functions ----------------------------------------------
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
@ -2564,4 +2607,14 @@ void TF_RegisterLogListener(void (*listener)(const char*)) {
|
||||
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
}
|
||||
|
||||
void TF_RegisterFilesystemPlugin(const char* plugin_filename,
|
||||
TF_Status* status) {
|
||||
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
|
||||
status->status = tensorflow::errors::Unimplemented(
|
||||
"FileSystem plugin functionality is not supported on mobile");
|
||||
#else
|
||||
status->status = tensorflow::RegisterFilesystemPlugin(plugin_filename);
|
||||
#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
|
||||
}
|
||||
|
||||
} // end extern "C"
|
||||
|
@ -1524,6 +1524,10 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status);
|
||||
TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp(
|
||||
const char* name, TF_Status* status);
|
||||
|
||||
// Update edge, switch input/ output in a node
|
||||
TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src,
|
||||
TF_Input dst, TF_Status* status);
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// In-process TensorFlow server functionality, for use in distributed training.
|
||||
// A Server instance encapsulates a set of devices and a Session target that
|
||||
@ -1573,6 +1577,13 @@ TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server);
|
||||
TF_CAPI_EXPORT extern void TF_RegisterLogListener(
|
||||
void (*listener)(const char*));
|
||||
|
||||
// Register a FileSystem plugin from filename `plugin_filename`.
|
||||
//
|
||||
// On success, place OK in status.
|
||||
// On failure, place an error status in status.
|
||||
TF_CAPI_EXPORT extern void TF_RegisterFilesystemPlugin(
|
||||
const char* plugin_filename, TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -561,15 +561,15 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
|
||||
collective_executor_handle->get()->StartAbort(status->status);
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
|
||||
const char* task,
|
||||
TF_Status* status) {
|
||||
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
|
||||
TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
auto collective_executor_handle = context->GetCollectiveExecutorHandle();
|
||||
tensorflow::Notification done;
|
||||
collective_executor_handle->get()->remote_access()->CheckPeerHealth(
|
||||
task, [&done, status](const Status& s) {
|
||||
task, timeout_in_ms, [&done, status](const Status& s) {
|
||||
status->status = s;
|
||||
done.Notify();
|
||||
});
|
||||
|
@ -241,9 +241,9 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
|
||||
// Checks the health of collective ops peers. Explicit health check is needed in
|
||||
// multi worker collective ops to detect failures in the cluster. If a peer is
|
||||
// down, collective ops may hang.
|
||||
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
|
||||
const char* task,
|
||||
TF_Status* status);
|
||||
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
|
||||
TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
|
||||
TF_Status* status);
|
||||
|
||||
// Information about the shape of a Tensor and its type.
|
||||
struct TF_ShapeAndType {
|
||||
|
@ -634,6 +634,40 @@ TEST(CAPI, Graph) {
|
||||
TF_DeleteStatus(s);
|
||||
}
|
||||
|
||||
TEST(CAPI, UpdateEdge) {
|
||||
TF_Status* s = TF_NewStatus();
|
||||
TF_Graph* graph = TF_NewGraph();
|
||||
|
||||
// Make two scalar constants.
|
||||
TF_Operation* one = ScalarConst(1, graph, s, "one");
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
TF_Operation* two = ScalarConst(2, graph, s, "two");
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Add oper.
|
||||
TF_Operation* add = Add(one, two, graph, s, "add");
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Add another oper to the graph.
|
||||
TF_Operation* neg = Neg(add, graph, s, "neg");
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
NodeDef node_def_neg;
|
||||
ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
|
||||
EXPECT_EQ(string("add"), node_def_neg.input(0));
|
||||
|
||||
// update edge of neg
|
||||
TF_UpdateEdge(graph, TF_Output{one, 0}, TF_Input{neg, 0}, s);
|
||||
|
||||
ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
|
||||
EXPECT_EQ(string("one:0"), node_def_neg.input(0));
|
||||
|
||||
// Clean up
|
||||
TF_DeleteGraph(graph);
|
||||
TF_DeleteStatus(s);
|
||||
}
|
||||
|
||||
/*
|
||||
TODO(skyewm): this test currently DCHECKs, change to bad status
|
||||
|
||||
|
@ -3,13 +3,16 @@
|
||||
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"if_tpu",
|
||||
"if_libtpu",
|
||||
"tf_cc_test",
|
||||
"tf_copts",
|
||||
"tf_cuda_cc_test",
|
||||
"tf_cuda_library",
|
||||
)
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "filegroup")
|
||||
|
||||
@ -94,6 +97,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core/distributed_runtime:remote_device",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core/distributed_runtime:worker_interface",
|
||||
"//tensorflow/core:gpu_runtime",
|
||||
] + internal_tfrt_deps(),
|
||||
alwayslink = 1,
|
||||
@ -106,6 +110,7 @@ filegroup(
|
||||
"abstract_function.h",
|
||||
"abstract_operation.h",
|
||||
"abstract_tensor_handle.h",
|
||||
"c_api.h",
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
"c_api_unified_experimental.h",
|
||||
@ -116,7 +121,6 @@ filegroup(
|
||||
"immediate_execution_context.h",
|
||||
"immediate_execution_operation.h",
|
||||
"immediate_execution_tensor_handle.h",
|
||||
"mnist_gradients_testutil.h",
|
||||
"tape.h",
|
||||
"tfe_cancellation_manager_internal.h",
|
||||
"tfe_context_internal.h",
|
||||
@ -177,6 +181,24 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tracing_utils",
|
||||
srcs = ["tracing_utils.cc"],
|
||||
hdrs = [
|
||||
"tracing_utils.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":abstract_operation",
|
||||
":c_api_unified_internal",
|
||||
"//tensorflow/c/experimental/gradients/tape:tape_operation",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"//tensorflow/core/platform:errors",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gradients_internal",
|
||||
srcs = [
|
||||
@ -212,6 +234,7 @@ tf_cuda_cc_test(
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags() + ["nomac"],
|
||||
deps = [
|
||||
":abstract_context",
|
||||
":abstract_tensor_handle",
|
||||
":c_api_experimental",
|
||||
":c_api_test_util",
|
||||
@ -222,7 +245,8 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c/experimental/gradients:array_grad",
|
||||
"//tensorflow/c/experimental/gradients:math_grad",
|
||||
"//tensorflow/c/experimental/ops:array_ops",
|
||||
"//tensorflow/c/experimental/gradients/tape:tape_context",
|
||||
"//tensorflow/c/experimental/ops",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||
"//tensorflow/core:lib",
|
||||
@ -270,7 +294,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
] + if_tpu(
|
||||
] + if_libtpu(
|
||||
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
|
||||
if_true = [],
|
||||
),
|
||||
@ -294,6 +318,7 @@ cc_library(
|
||||
":gradients_internal",
|
||||
":gradients_util",
|
||||
":tape",
|
||||
"//tensorflow/c/experimental/gradients/tape:tape_context",
|
||||
"//tensorflow/c/experimental/ops:array_ops",
|
||||
"//tensorflow/c/experimental/ops:math_ops",
|
||||
"//tensorflow/c/experimental/ops:nn_ops",
|
||||
@ -334,7 +359,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
] + if_tpu(
|
||||
] + if_libtpu(
|
||||
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
|
||||
if_true = [],
|
||||
),
|
||||
@ -618,6 +643,19 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_header_only_library(
|
||||
name = "tfe_tensorhandle_internal_hdrs_only",
|
||||
extra_deps = [
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":tfe_tensorhandle_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_test_util",
|
||||
testonly = 1,
|
||||
@ -1037,6 +1075,8 @@ filegroup(
|
||||
"gradients.cc", # Uses RTTI.
|
||||
"gradients_util.cc",
|
||||
"gradients_util.h",
|
||||
"tracing_utils.h",
|
||||
"tracing_utils.cc",
|
||||
"*test*",
|
||||
"*dlpack*",
|
||||
],
|
||||
|
@ -32,7 +32,7 @@ namespace tensorflow {
|
||||
// environment, a traced representation etc.
|
||||
class AbstractContext {
|
||||
protected:
|
||||
enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt };
|
||||
enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape };
|
||||
explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
|
||||
virtual ~AbstractContext() {}
|
||||
|
||||
|
@ -30,7 +30,7 @@ namespace tensorflow {
|
||||
// tracing or immediate execution mode.
|
||||
class AbstractOperation {
|
||||
protected:
|
||||
enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt };
|
||||
enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt, kTape };
|
||||
explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
|
||||
virtual ~AbstractOperation() {}
|
||||
|
||||
|
@ -39,7 +39,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTFTPU)
|
||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
|
||||
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
|
||||
#endif
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
@ -70,6 +70,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
@ -729,7 +730,7 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
|
||||
|
||||
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
if (opts->use_tfrt) {
|
||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTFTPU)
|
||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
|
||||
return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
|
||||
#else
|
||||
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
|
||||
@ -752,8 +753,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
|
||||
/*device_mgr_owned*/ true, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator()));
|
||||
/*device_mgr_owned*/ true, r));
|
||||
}
|
||||
|
||||
void TFE_DeleteContext(TFE_Context* ctx) {
|
||||
@ -856,41 +856,42 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
// TODO(yuefengz): support partially specified `worker_name`.
|
||||
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
|
||||
status->status = context->GetClient(worker_name, &eager_client);
|
||||
if (!status->status.ok()) {
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
if (grpc_server == nullptr) {
|
||||
status->status =
|
||||
tensorflow::errors::Internal("Failed to get tensorflow::GrpcServer.");
|
||||
return false;
|
||||
}
|
||||
tensorflow::WorkerInterface* wi =
|
||||
grpc_server->master_env()->worker_cache->GetOrCreateWorker(worker_name);
|
||||
if (wi == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Unable to find worker interface corresponding to task ", worker_name);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Send a rpc request to the worker to check aliveness.
|
||||
tensorflow::eager::KeepAliveRequest request;
|
||||
request.set_context_id(context->GetContextId());
|
||||
tensorflow::eager::KeepAliveResponse response;
|
||||
|
||||
tensorflow::Status keep_alive_status;
|
||||
tensorflow::GetStatusRequest request;
|
||||
tensorflow::GetStatusResponse response;
|
||||
tensorflow::Status remote_status;
|
||||
tensorflow::Notification done;
|
||||
eager_client->KeepAliveAsync(
|
||||
&request, &response,
|
||||
[&keep_alive_status, &done](const tensorflow::Status& s) {
|
||||
keep_alive_status = s;
|
||||
done.Notify();
|
||||
});
|
||||
wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true,
|
||||
[&remote_status, &done](const tensorflow::Status& s) {
|
||||
remote_status = s;
|
||||
done.Notify();
|
||||
});
|
||||
done.WaitForNotification();
|
||||
|
||||
// We set OK status so the call does not raise any exceptions. Instead, caller
|
||||
// users the return value to tell if the remote worker is alive.
|
||||
status->status = tensorflow::Status::OK();
|
||||
|
||||
// If `context_id` doesn't exist on the remote worker, an InvalidArgument
|
||||
// error will return. But this still indicates that the remote worker is
|
||||
// alive.
|
||||
if (keep_alive_status.ok() ||
|
||||
keep_alive_status.code() == tensorflow::error::INVALID_ARGUMENT) {
|
||||
if (remote_status.ok()) {
|
||||
return true;
|
||||
} else {
|
||||
LOG(INFO) << "Remote worker " << worker_name
|
||||
<< " is not alive: " << keep_alive_status.error_message();
|
||||
return false;
|
||||
}
|
||||
LOG(INFO) << "Remote worker " << worker_name
|
||||
<< " is not alive: " << remote_status.error_message();
|
||||
return false;
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
@ -905,9 +906,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||
|
||||
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetThreadLocalDevicePlacementPolicy(
|
||||
tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy(
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
|
||||
}
|
||||
|
||||
@ -916,10 +915,8 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||
// safe to call this function from the async EagerExecutor threads.
|
||||
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
||||
TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return static_cast<TFE_ContextDevicePlacementPolicy>(
|
||||
context->GetDevicePlacementPolicy());
|
||||
tensorflow::unwrap(ctx)->GetDevicePlacementPolicy());
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
|
||||
@ -1430,21 +1427,15 @@ void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
|
||||
}
|
||||
|
||||
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return context->FindFunctionDef(name) != nullptr;
|
||||
return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr;
|
||||
}
|
||||
|
||||
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(true);
|
||||
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
|
||||
}
|
||||
|
||||
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(false);
|
||||
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
@ -1456,13 +1447,11 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
|
||||
|
||||
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->Executor().WaitForAllPendingNodes();
|
||||
auto* context = tensorflow::unwrap(ctx);
|
||||
status->status = context->AsyncWait();
|
||||
if (!status->status.ok()) return;
|
||||
tensorflow::mutex_lock ml(*context->MetadataMu());
|
||||
status->status = MessageToBuffer(*context->RunMetadataProto(), buf);
|
||||
context->ClearRunMetadata();
|
||||
auto run_metadata = context->ExportRunMetadata();
|
||||
status->status = MessageToBuffer(*run_metadata, buf);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -74,7 +74,7 @@ typedef enum TFE_ContextDevicePlacementPolicy {
|
||||
// Placement policy which silently copies int32 tensors but not other dtypes.
|
||||
TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
|
||||
} TFE_ContextDevicePlacementPolicy;
|
||||
// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h)
|
||||
// LINT.ThenChange(//tensorflow/c/eager/immediate_execution_context.h)
|
||||
|
||||
// Sets the default execution mode (sync/async). Note that this can be
|
||||
// overridden per thread using TFE_ContextSetExecutorForThread.
|
||||
|
@ -545,7 +545,9 @@ TEST(CAPI, DistributedFunctionNoError) {
|
||||
TestDistributedFunctionCancellation(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, DistributedFunctionCancelledOnError) {
|
||||
// TODO(b/170399182): Update test once an alternative to using the function
|
||||
// optimization hook is in place.
|
||||
TEST(CAPI, DISABLED_DistributedFunctionCancelledOnError) {
|
||||
TestDistributedFunctionCancellation(true);
|
||||
}
|
||||
|
||||
|
@ -49,15 +49,11 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
||||
}
|
||||
|
||||
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(true);
|
||||
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
|
||||
}
|
||||
|
||||
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(false);
|
||||
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
|
||||
}
|
||||
|
||||
uint64_t TFE_GetContextId(TFE_Context* ctx) {
|
||||
@ -544,22 +540,16 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
|
||||
}
|
||||
|
||||
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetExecutorForThread(executor->executor());
|
||||
tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor());
|
||||
}
|
||||
|
||||
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return new TFE_Executor(&context->Executor());
|
||||
return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor());
|
||||
}
|
||||
|
||||
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
|
||||
context->HostCPU()->parsed_name());
|
||||
tensorflow::unwrap(ctx)->HostCPUParsedName());
|
||||
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
|
||||
void* data = tensorflow::port::Malloc(str.length());
|
||||
str.copy(static_cast<char*>(data), str.length(), 0);
|
||||
@ -572,9 +562,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
||||
|
||||
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
||||
TF_Buffer* buf, TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
auto* function_def = context->FindFunctionDef(function_name);
|
||||
auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name);
|
||||
if (function_def == nullptr) {
|
||||
status->status = tensorflow::errors::NotFound(
|
||||
"Unable to find FunctionDef with name: ", function_name);
|
||||
@ -643,14 +631,26 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
|
||||
|
||||
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetAllowSoftPlacement(enable);
|
||||
tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable);
|
||||
}
|
||||
|
||||
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetLogDevicePlacement(enable);
|
||||
tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
|
||||
}
|
||||
|
||||
const char* TFE_TensorHandleDeviceType(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
return tensorflow::unwrap(h)->DeviceType(&status->status);
|
||||
}
|
||||
|
||||
int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return -1;
|
||||
}
|
||||
return tensorflow::unwrap(h)->DeviceId(&status->status);
|
||||
}
|
||||
|
@ -553,6 +553,14 @@ TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
|
||||
unsigned char enable,
|
||||
TF_Status* status);
|
||||
|
||||
// Returns the device type of the operation that produced `h`.
|
||||
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType(
|
||||
TFE_TensorHandle* h, TF_Status* status);
|
||||
|
||||
// Returns the device ID of the operation that produced `h`.
|
||||
TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h,
|
||||
TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -411,5 +411,109 @@ TEST(CAPI, TensorHandleOnDeviceMemory) {
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
TEST(CAPI, TensorHandleNullptr) {
|
||||
TFE_TensorHandle* h = nullptr;
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
|
||||
const char* device_type = TFE_TensorHandleDeviceType(h, status.get());
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
ASSERT_EQ(device_type, nullptr);
|
||||
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
|
||||
|
||||
TF_SetStatus(status.get(), TF_OK, "");
|
||||
|
||||
int device_id = TFE_TensorHandleDeviceID(h, status.get());
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
ASSERT_EQ(device_id, -1);
|
||||
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
|
||||
}
|
||||
|
||||
TEST(CAPI, TensorHandleDevices) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
|
||||
const char* device_type = TFE_TensorHandleDeviceType(hcpu, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
|
||||
int device_id = TFE_TensorHandleDeviceID(hcpu, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_EQ(0, device_id) << device_id;
|
||||
|
||||
// Disable the test if no GPU is present.
|
||||
string gpu_device_name;
|
||||
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
|
||||
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
|
||||
hcpu, ctx, gpu_device_name.c_str(), status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TFE_Op* shape_op = ShapeOp(ctx, hgpu);
|
||||
TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
device_type = TFE_TensorHandleDeviceType(retvals[0], status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type;
|
||||
|
||||
device_id = TFE_TensorHandleDeviceID(retvals[0], status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_EQ(0, device_id) << device_id;
|
||||
|
||||
TFE_DeleteOp(shape_op);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
TFE_DeleteTensorHandle(hgpu);
|
||||
}
|
||||
|
||||
TFE_DeleteTensorHandle(hcpu);
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST(CAPI, TensorHandleDefaults) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* h_default = TestMatrixTensorHandle(ctx);
|
||||
const char* device_type = TFE_TensorHandleDeviceType(h_default, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
|
||||
int device_id = TFE_TensorHandleDeviceID(h_default, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_EQ(0, device_id) << device_id;
|
||||
|
||||
TFE_TensorHandle* h_cpu = TFE_TensorHandleCopyToDevice(
|
||||
h_default, ctx, "/device:CPU:0", status.get());
|
||||
const char* device_type_cpu = TFE_TensorHandleDeviceType(h_cpu, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_TRUE(absl::StrContains(device_type_cpu, "CPU")) << device_type_cpu;
|
||||
int device_id_cpu = TFE_TensorHandleDeviceID(h_cpu, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_EQ(0, device_id_cpu) << device_id_cpu;
|
||||
|
||||
TFE_DeleteTensorHandle(h_default);
|
||||
TFE_DeleteTensorHandle(h_cpu);
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/tensor_float_32_utils.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -56,6 +57,9 @@ Status RegisterGradients(GradientRegistry* registry) {
|
||||
}
|
||||
|
||||
TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
|
||||
// Computing numerical gradients with TensorFloat-32 is numerically unstable
|
||||
enable_tensor_float_32_execution(false);
|
||||
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
|
@ -122,14 +122,12 @@ int64 ToId(AbstractTensorHandle* t) {
|
||||
return static_cast<int64>(reinterpret_cast<uintptr_t>(t));
|
||||
}
|
||||
|
||||
TapeTensor::TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx)
|
||||
: handle_(handle), ctx_(ctx) {
|
||||
TapeTensor::TapeTensor(AbstractTensorHandle* handle) : handle_(handle) {
|
||||
handle_->Ref();
|
||||
}
|
||||
TapeTensor::TapeTensor(const TapeTensor& other) {
|
||||
handle_ = other.handle_;
|
||||
handle_->Ref();
|
||||
ctx_ = other.ctx_;
|
||||
}
|
||||
TapeTensor::~TapeTensor() { handle_->Unref(); }
|
||||
|
||||
@ -138,33 +136,7 @@ tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); }
|
||||
tensorflow::DataType TapeTensor::GetDType() const {
|
||||
return handle_->DataType();
|
||||
}
|
||||
|
||||
AbstractTensorHandle* TapeTensor::OnesLike() const {
|
||||
AbstractOperationPtr op(ctx_->CreateOperation());
|
||||
Status s = op->Reset("OnesLike", /*raw_device_name=*/nullptr);
|
||||
if (!s.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (isa<tracing::TracingOperation>(op.get())) {
|
||||
s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
|
||||
absl::StrCat("OnesLike", ToId(handle_)).c_str());
|
||||
if (!s.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
s = op->AddInput(handle_);
|
||||
if (!s.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
int num_outputs = 1;
|
||||
// TODO(srbs): Figure out who is in charge of releasing this.
|
||||
std::vector<AbstractTensorHandle*> outputs(num_outputs);
|
||||
s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
|
||||
if (!s.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return outputs[0];
|
||||
}
|
||||
AbstractTensorHandle* TapeTensor::GetHandle() const { return handle_; }
|
||||
|
||||
AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; }
|
||||
|
||||
@ -219,6 +191,23 @@ Status TapeVSpace::CallBackwardFunction(
|
||||
&ctx, incoming_gradients, result);
|
||||
}
|
||||
|
||||
Status TapeVSpace::BuildOnesLike(const TapeTensor& t,
|
||||
AbstractTensorHandle** result) const {
|
||||
AbstractOperationPtr op(ctx_->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));
|
||||
if (isa<tracing::TracingOperation>(op.get())) {
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
|
||||
absl::StrCat("OnesLike", ToId(t.GetHandle())).c_str()));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(op->AddInput(t.GetHandle()));
|
||||
int num_outputs = 1;
|
||||
std::vector<AbstractTensorHandle*> outputs(num_outputs);
|
||||
TF_RETURN_IF_ERROR(
|
||||
op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
|
||||
*result = outputs[0];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Looks up the ID of a Gradient.
|
||||
int64 TapeVSpace::TensorId(AbstractTensorHandle* tensor) const {
|
||||
return ToId(tensor);
|
||||
@ -226,7 +215,7 @@ int64 TapeVSpace::TensorId(AbstractTensorHandle* tensor) const {
|
||||
|
||||
// Converts a Gradient to a TapeTensor.
|
||||
TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const {
|
||||
return TapeTensor(g, ctx_);
|
||||
return TapeTensor(g);
|
||||
}
|
||||
|
||||
void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {}
|
||||
@ -426,7 +415,7 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx,
|
||||
forward_op_->attrs.BuildNodeDef();
|
||||
std::vector<TapeTensor> tape_tensors;
|
||||
for (auto t : retvals) {
|
||||
tape_tensors.push_back(TapeTensor(t, ctx));
|
||||
tape_tensors.push_back(TapeTensor(t));
|
||||
}
|
||||
tape->RecordOperation(
|
||||
op_->Name(), tape_tensors, input_ids, input_dtypes,
|
||||
|
@ -80,7 +80,6 @@ struct ForwardOperation {
|
||||
std::vector<AbstractTensorHandle*> inputs;
|
||||
std::vector<AbstractTensorHandle*> outputs;
|
||||
AttrBuilder attrs;
|
||||
AbstractContext* ctx;
|
||||
};
|
||||
|
||||
// Interface for building default zeros gradients for op outputs which are
|
||||
@ -181,10 +180,6 @@ int64 ToId(AbstractTensorHandle* t);
|
||||
// allow us to trace the data dependencies between operations and hence compute
|
||||
// gradients.
|
||||
//
|
||||
// This also implements `OnesLike` to create the default
|
||||
// incoming gradients for tensors which do not already have an incoming
|
||||
// gradient.
|
||||
//
|
||||
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
|
||||
// of default zeros grads is handled by the `DefaultGradientFunction` registered
|
||||
// for each op.
|
||||
@ -193,20 +188,19 @@ int64 ToId(AbstractTensorHandle* t);
|
||||
// TODO(srbs): Should ZerosLike check-fail instead of returning nullptr?
|
||||
class TapeTensor {
|
||||
public:
|
||||
TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx);
|
||||
explicit TapeTensor(AbstractTensorHandle* handle);
|
||||
TapeTensor(const TapeTensor& other);
|
||||
~TapeTensor();
|
||||
|
||||
tensorflow::int64 GetID() const;
|
||||
tensorflow::DataType GetDType() const;
|
||||
|
||||
AbstractTensorHandle* OnesLike() const;
|
||||
AbstractTensorHandle* ZerosLike() const;
|
||||
|
||||
AbstractTensorHandle* GetHandle() const;
|
||||
|
||||
private:
|
||||
AbstractTensorHandle* handle_;
|
||||
// The context where OnesLike ops are to be created.
|
||||
AbstractContext* ctx_;
|
||||
};
|
||||
|
||||
// Vector space for actually computing gradients. Implements methods for calling
|
||||
@ -234,6 +228,10 @@ class TapeVSpace
|
||||
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
|
||||
std::vector<AbstractTensorHandle*>* result) const override;
|
||||
|
||||
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
||||
Status BuildOnesLike(const TapeTensor& t,
|
||||
AbstractTensorHandle** result) const override;
|
||||
|
||||
// Looks up the ID of a Gradient.
|
||||
int64 TensorId(AbstractTensorHandle* tensor) const override;
|
||||
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
@ -26,7 +27,9 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/gradients_internal.h"
|
||||
#include "tensorflow/c/experimental/gradients/array_grad.h"
|
||||
#include "tensorflow/c/experimental/gradients/math_grad.h"
|
||||
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
|
||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||
#include "tensorflow/c/experimental/ops/math_ops.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
@ -53,73 +56,18 @@ class CppGradients
|
||||
};
|
||||
|
||||
Status RegisterGradients(GradientRegistry* registry) {
|
||||
TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
|
||||
// TODO(srbs): Rename ops::Add to ops::AddV2 and AddRegister to
|
||||
// AddV2Registerer.
|
||||
TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Mul", MulRegisterer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Computes `inputs[0] + inputs[1]` and records it on the tape.
|
||||
Status Add(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr add_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(add_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
|
||||
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
|
||||
int num_retvals = 1;
|
||||
return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
// Computes `exp(inputs[0])` and records it on the tape.
|
||||
Status Exp(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr exp_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(exp_op.get(), "Exp", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(exp_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<TracingOperation>(exp_op.get())->SetOpName("my_exp"));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(AddInput(exp_op.get(), inputs[0], &forward_op));
|
||||
int num_retvals = 1;
|
||||
return Execute(exp_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
// Computes `IdentityN(inputs)` and records it on the tape.
|
||||
Status IdentityN(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr identity_n_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(Reset(identity_n_op.get(), "IdentityN",
|
||||
/*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(identity_n_op.get())) {
|
||||
TF_RETURN_IF_ERROR(dyn_cast<TracingOperation>(identity_n_op.get())
|
||||
->SetOpName("my_identity_n"));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(AddInputList(identity_n_op.get(), inputs, &forward_op));
|
||||
int num_retvals = outputs.size();
|
||||
return Execute(identity_n_op.get(), ctx, outputs, &num_retvals, &forward_op,
|
||||
tape, registry);
|
||||
}
|
||||
|
||||
// Computes
|
||||
// y = inputs[0] + inputs[1]
|
||||
// return grad(y, {inputs[0], inputs[1]})
|
||||
@ -128,12 +76,14 @@ Status AddGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> add_outputs(1);
|
||||
TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
|
||||
registry)); // Compute x+y.
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
|
||||
TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(add_outputs),
|
||||
"Add")); // Compute x+y.
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
@ -149,7 +99,6 @@ Status AddGradModel(AbstractContext* ctx,
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -161,11 +110,12 @@ Status ExpGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
std::vector<AbstractTensorHandle*> exp_outputs(1);
|
||||
TF_RETURN_IF_ERROR(Exp(ctx, tape, inputs, absl::MakeSpan(exp_outputs),
|
||||
registry)); // Compute x+y.
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ops::Exp(tape_ctx.get(), inputs, absl::MakeSpan(exp_outputs), "Exp"));
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
@ -179,7 +129,36 @@ Status ExpGradModel(AbstractContext* ctx,
|
||||
exp_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Computes
|
||||
// y = sqrt(inputs[0])
|
||||
// return grad(y, {inputs[0]})
|
||||
Status SqrtGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
std::vector<AbstractTensorHandle*> sqrt_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt"));
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(sqrt_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
for (auto sqrt_output : sqrt_outputs) {
|
||||
sqrt_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -192,13 +171,14 @@ Status IdentityNGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0]));
|
||||
tape->Watch(ToId(inputs[1]));
|
||||
|
||||
vector<AbstractTensorHandle*> identity_n_outputs(2);
|
||||
TF_RETURN_IF_ERROR(IdentityN(ctx, tape, inputs,
|
||||
absl::MakeSpan(identity_n_outputs), registry));
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
|
||||
TF_RETURN_IF_ERROR(ops::IdentityN(
|
||||
tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN"));
|
||||
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
@ -214,6 +194,105 @@ Status IdentityNGradModel(AbstractContext* ctx,
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Computes
|
||||
// y = - inputs[0]
|
||||
// return grad(y, {inputs[0]})
|
||||
Status NegGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0]));
|
||||
|
||||
std::vector<AbstractTensorHandle*> neg_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ops::Neg(tape_ctx.get(), inputs, absl::MakeSpan(neg_outputs), "Neg"));
|
||||
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(neg_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
for (auto neg_output : neg_outputs) {
|
||||
neg_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Computes
|
||||
// y = inputs[0] - inputs[1]
|
||||
// return grad(y, {inputs[0], inputs[1]})
|
||||
Status SubGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> sub_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
|
||||
TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(sub_outputs),
|
||||
"Sub")); // Compute x-y.
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(sub_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
for (auto sub_output : sub_outputs) {
|
||||
sub_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Computes
|
||||
// y = inputs[0] * inputs[1]
|
||||
// return grad(y, {inputs[0], inputs[1]})
|
||||
Status MulGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> mul_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(mul_outputs),
|
||||
"Mul")); // Compute x*y.
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(mul_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
for (auto mul_output : mul_outputs) {
|
||||
mul_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -452,6 +531,50 @@ TEST_P(CppGradients, TestExpGrad) {
|
||||
result_tensor = nullptr;
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestSqrtGrad) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
// y = sqrt(x)
|
||||
// outputs = tape.gradient(y, x)
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
s = RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
s = getValue(outputs[0], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_NEAR(*result_value, 0.5, 0.001);
|
||||
outputs[0]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
result_tensor = nullptr;
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestIdentityNGrad) {
|
||||
// Pseudo-code:
|
||||
//
|
||||
@ -511,6 +634,172 @@ TEST_P(CppGradients, TestIdentityNGrad) {
|
||||
result_tensor = nullptr;
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestNegGrad) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
// y = - x
|
||||
// outputs = tape.gradient(y, x)
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
s = RunModel(NegGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
s = getValue(outputs[0], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, -1.0);
|
||||
outputs[0]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
result_tensor = nullptr;
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestSubGrad) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr y;
|
||||
{
|
||||
AbstractTensorHandle* y_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
y.reset(y_raw);
|
||||
}
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
// tape.watch(y)
|
||||
// y = x - y
|
||||
// outputs = tape.gradient(y, [x, y])
|
||||
std::vector<AbstractTensorHandle*> outputs(2);
|
||||
s = RunModel(SubGradModel, ctx.get(), {x.get(), y.get()},
|
||||
absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
s = getValue(outputs[0], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, 1.0);
|
||||
outputs[0]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
result_tensor = nullptr;
|
||||
|
||||
s = getValue(outputs[1], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, -1.0);
|
||||
outputs[1]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestMulGrad) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr y;
|
||||
{
|
||||
AbstractTensorHandle* y_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
y.reset(y_raw);
|
||||
}
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
// tape.watch(y)
|
||||
// y = x * y
|
||||
// outputs = tape.gradient(y, [x, y])
|
||||
std::vector<AbstractTensorHandle*> outputs(2);
|
||||
s = RunModel(MulGradModel, ctx.get(), {x.get(), y.get()},
|
||||
absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
s = getValue(outputs[0], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, 2.0);
|
||||
outputs[0]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
result_tensor = nullptr;
|
||||
|
||||
s = getValue(outputs[1], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, 1.0);
|
||||
outputs[1]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestSetAttrString) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
@ -533,7 +822,6 @@ TEST_P(CppGradients, TestSetAttrString) {
|
||||
|
||||
AbstractOperationPtr check_numerics_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx.get();
|
||||
Status s = Reset(check_numerics_op.get(), "CheckNumerics",
|
||||
/*raw_device_name=*/nullptr, &forward_op);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
@ -551,7 +839,7 @@ TEST_P(CppGradients, TestSetAttrString) {
|
||||
int num_retvals = 1;
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
GradientRegistry registry;
|
||||
std::unique_ptr<Tape> tape(new Tape(/*persistent=*/false));
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs),
|
||||
&num_retvals, &forward_op, tape.get(), registry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
@ -29,8 +29,26 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
class EagerExecutor;
|
||||
|
||||
// LINT.IfChange
|
||||
// Note: Keep in sync with exported copy of enum in eager/c_api.h.
|
||||
enum ContextDevicePlacementPolicy {
|
||||
// Running operations with input tensors on the wrong device will fail.
|
||||
DEVICE_PLACEMENT_EXPLICIT = 0,
|
||||
// Copy the tensor to the right device but log a warning.
|
||||
DEVICE_PLACEMENT_WARN = 1,
|
||||
// Silently copy the tensor, which has a performance cost since the operation
|
||||
// will be blocked till the copy completes. This is the default policy.
|
||||
DEVICE_PLACEMENT_SILENT = 2,
|
||||
// Placement policy which silently copies int32 tensors but not other dtypes.
|
||||
DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
|
||||
};
|
||||
// LINT.ThenChange(//tensorflow/c/eager/c_api.h)
|
||||
|
||||
// Abstract interface to a context.
|
||||
//
|
||||
@ -81,14 +99,6 @@ class ImmediateExecutionContext : public AbstractContext {
|
||||
// List attributes of available devices
|
||||
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
|
||||
|
||||
virtual void ClearCachesAndThreadExecutors() = 0;
|
||||
|
||||
// Initialize the step resource container for a training step. This is used
|
||||
// in current TF runtime. For tfrt, it is used by fallback op handler.
|
||||
virtual void StartStep() = 0;
|
||||
// Destroy the step resource container for a training step.
|
||||
virtual void EndStep() = 0;
|
||||
|
||||
// Block until all pending nodes are finished.
|
||||
virtual Status AsyncWait() = 0;
|
||||
|
||||
@ -97,11 +107,56 @@ class ImmediateExecutionContext : public AbstractContext {
|
||||
// already exists.
|
||||
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
|
||||
|
||||
// Find and return a added function by its name.
|
||||
virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;
|
||||
|
||||
// Return the ParsedName of Host CPU device.
|
||||
virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
|
||||
|
||||
// Configure soft device placement policy.
|
||||
virtual void SetAllowSoftPlacement(bool enable) = 0;
|
||||
|
||||
// Configure device placement policy logging.
|
||||
virtual void SetLogDevicePlacement(bool enable) = 0;
|
||||
|
||||
// Sets the device placement policy for the current thread.
|
||||
virtual void SetThreadLocalDevicePlacementPolicy(
|
||||
ContextDevicePlacementPolicy policy) = 0;
|
||||
// Returns the device placement policy for the current thread.
|
||||
virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;
|
||||
|
||||
// Configure graph collection in RunMetadata.
|
||||
virtual void SetShouldStoreGraphs(bool value) = 0;
|
||||
|
||||
// Return the collected RunMetadata. This method will transfer the ownership
|
||||
// to the caller.
|
||||
virtual std::unique_ptr<RunMetadata> ExportRunMetadata() = 0;
|
||||
|
||||
// For LLVM style RTTI.
|
||||
static bool classof(const AbstractContext* ptr) {
|
||||
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Following are legacy features in TF Eager Runtime.
|
||||
// TODO(tf-runtime): Figure out a way to deprecate following features after
|
||||
// migrated to TFRT.
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Clear pending nodes in thread executors and kernel caches.
|
||||
virtual void ClearCachesAndThreadExecutors() = 0;
|
||||
|
||||
// Initialize the step resource container for a training step. This is used
|
||||
// in current TF runtime. For tfrt, it is used by fallback op handler.
|
||||
virtual void StartStep() = 0;
|
||||
// Destroy the step resource container for a training step.
|
||||
virtual void EndStep() = 0;
|
||||
|
||||
// Return the Eager Executor for current thread. Please note that Eager
|
||||
// Executor is only used in current TF but not in TFRT.
|
||||
virtual EagerExecutor& Executor() = 0;
|
||||
// Update the Eager Executor for current thread.
|
||||
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
|
||||
|
||||
protected:
|
||||
explicit ImmediateExecutionContext(AbstractContextKind kind)
|
||||
: AbstractContext(kind) {}
|
||||
|
@ -44,6 +44,10 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
|
||||
virtual const char* DeviceName(Status* status) const = 0;
|
||||
// Returns the device where the tensor was placed.
|
||||
virtual const char* BackingDeviceName(Status* status) const = 0;
|
||||
// Returns the device type which created the handle.
|
||||
virtual const char* DeviceType(Status* status) const = 0;
|
||||
// Returns the device ID which created the handle.
|
||||
virtual int DeviceId(Status* status) const = 0;
|
||||
// Returns a tensor for the handle. If tensor is remote, it will be copied.
|
||||
virtual AbstractTensorInterface* Resolve(Status* status) = 0;
|
||||
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/tensor_float_32_utils.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -43,6 +44,11 @@ class CppGradients
|
||||
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
|
||||
Status s = StatusFromTF_Status(status.get());
|
||||
CHECK_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Computing numerical gradients with TensorFloat-32 is numerically
|
||||
// unstable. Some forward pass tests also fail with TensorFloat-32 due to
|
||||
// low tolerances
|
||||
enable_tensor_float_32_execution(false);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -25,138 +25,18 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
#include "tensorflow/c/eager/gradients_internal.h"
|
||||
#include "tensorflow/c/eager/gradients_util.h"
|
||||
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
|
||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||
#include "tensorflow/c/experimental/ops/math_ops.h"
|
||||
#include "tensorflow/c/experimental/ops/nn_ops.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
|
||||
// ========================== Tape Ops ==============================
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
namespace internal {
|
||||
|
||||
using std::vector;
|
||||
using tensorflow::tracing::TracingOperation;
|
||||
|
||||
// Computes `inputs[0] + inputs[1]` and records it on the tape.
|
||||
Status Add(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr add_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(add_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
|
||||
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
|
||||
int num_retvals = 1;
|
||||
return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
|
||||
Status MatMul(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
bool transpose_a, bool transpose_b,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr matmul_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul",
|
||||
/*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(matmul_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<TracingOperation>(matmul_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op));
|
||||
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op));
|
||||
TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
|
||||
matmul_op.get(), "transpose_a", transpose_a, &forward_op));
|
||||
TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
|
||||
matmul_op.get(), "transpose_b", transpose_b, &forward_op));
|
||||
|
||||
int num_retvals = 1;
|
||||
return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
Status Mul(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr mul_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(mul_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op));
|
||||
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op));
|
||||
|
||||
int num_retvals = 1;
|
||||
return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
// Computes `Relu(inputs[0])` and records it on the tape.
|
||||
Status Relu(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr relu_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(relu_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<TracingOperation>(relu_op.get())->SetOpName(name));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op));
|
||||
int num_retvals = 1;
|
||||
return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
// Computes `SoftmaxLoss(scores, labels)` where labels are categorical (not
|
||||
// one-hot) and records it on the tape.
|
||||
Status SparseSoftmaxCrossEntropyWithLogits(
|
||||
AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractTensorHandle* scores = inputs[0];
|
||||
AbstractTensorHandle* labels = inputs[1];
|
||||
|
||||
AbstractOperationPtr sm_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits",
|
||||
/*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(sm_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<TracingOperation>(sm_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op));
|
||||
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op));
|
||||
|
||||
int num_retvals = 2; // returns loss values and backprop
|
||||
return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
//===================== Test Models to run =========================
|
||||
|
||||
@ -172,8 +52,9 @@ Status AddGradModel(AbstractContext* ctx,
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> add_outputs(1);
|
||||
TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
|
||||
registry)); // Compute x+y.
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
@ -205,9 +86,11 @@ Status MatMulGradModel(AbstractContext* ctx,
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
vector<AbstractTensorHandle*> mm_outputs(1);
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs),
|
||||
"matmul0", /*transpose_a=*/false,
|
||||
/*transpose_b=*/false, registry)); // Compute x*y.
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(mm_outputs), "matmul0",
|
||||
/*transpose_a=*/false,
|
||||
/*transpose_b=*/false)); // Compute x*y.
|
||||
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
@ -261,25 +144,27 @@ Status MNISTForwardModel(AbstractContext* ctx,
|
||||
tape->Watch(ToId(W2)); // Watch W2.
|
||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
|
||||
"matmul0", /*transpose_a=*/false,
|
||||
/*transpose_b=*/false, registry)); // Compute X*W1
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
|
||||
absl::MakeSpan(temp_outputs), "matmul0",
|
||||
/*transpose_a=*/false,
|
||||
/*transpose_b=*/false)); // Compute X*W1
|
||||
|
||||
TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]},
|
||||
absl::MakeSpan(temp_outputs), "relu",
|
||||
registry)); // Compute Relu(X*W1)
|
||||
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {temp_outputs[0]},
|
||||
absl::MakeSpan(temp_outputs),
|
||||
"relu")); // Compute Relu(X*W1)
|
||||
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2},
|
||||
absl::MakeSpan(temp_outputs), "matmul1",
|
||||
/*transpose_a=*/false, /*transpose_b=*/false,
|
||||
registry)); // Compute W2*Relu(X*W1)
|
||||
TF_RETURN_IF_ERROR(ops::MatMul(
|
||||
tape_ctx.get(), {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs),
|
||||
"matmul1",
|
||||
/*transpose_a=*/false, /*transpose_b=*/false)); // Compute W2*Relu(X*W1)
|
||||
|
||||
AbstractTensorHandle* scores = temp_outputs[0];
|
||||
|
||||
temp_outputs.resize(2);
|
||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
|
||||
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
|
||||
"softmax_loss", registry)); // Compute Softmax(Scores,labels)
|
||||
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
|
||||
tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
|
||||
"softmax_loss")); // Compute Softmax(Scores,labels)
|
||||
|
||||
AbstractTensorHandle* loss_vals = temp_outputs[0];
|
||||
|
||||
@ -302,9 +187,11 @@ Status MatMulTransposeModel(AbstractContext* ctx,
|
||||
tape->Watch(ToId(W1));
|
||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
|
||||
"matmul0", /*transpose_a=*/true,
|
||||
/*transpose_b=*/false, registry)); // Compute X*W1
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
|
||||
absl::MakeSpan(temp_outputs), "matmul0",
|
||||
/*transpose_a=*/true,
|
||||
/*transpose_b=*/false)); // Compute X*W1
|
||||
|
||||
outputs[0] = temp_outputs[0];
|
||||
|
||||
@ -320,8 +207,10 @@ Status ReluGradModel(AbstractContext* ctx,
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch X
|
||||
vector<AbstractTensorHandle*> relu_outputs(1);
|
||||
TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs),
|
||||
"relu0", registry)); // Relu(X)
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(relu_outputs),
|
||||
"relu0")); // Relu(X)
|
||||
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
@ -351,8 +240,9 @@ Status SoftmaxLossGradModel(AbstractContext* ctx,
|
||||
tape->Watch(ToId(inputs[0])); // Watch scores.
|
||||
tape->Watch(ToId(inputs[1])); // Watch labels.
|
||||
vector<AbstractTensorHandle*> sm_outputs(2);
|
||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
|
||||
ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry));
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
|
||||
tape_ctx.get(), inputs, absl::MakeSpan(sm_outputs), "softmax0"));
|
||||
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
@ -386,29 +276,30 @@ Status MNISTGradModel(AbstractContext* ctx,
|
||||
tape->Watch(ToId(W1)); // Watch W1.
|
||||
tape->Watch(ToId(W2)); // Watch W1.
|
||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
|
||||
"matmul0", /*transpose_a=*/false,
|
||||
/*transpose_b=*/false, registry)); // Compute X*W1
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
|
||||
absl::MakeSpan(temp_outputs), "matmul0",
|
||||
/*transpose_a=*/false,
|
||||
/*transpose_b=*/false)); // Compute X*W1
|
||||
|
||||
AbstractTensorHandle* mm = temp_outputs[0];
|
||||
|
||||
TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm},
|
||||
absl::MakeSpan(temp_outputs), // Relu(X*W1)
|
||||
"relu0", registry));
|
||||
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {mm},
|
||||
absl::MakeSpan(temp_outputs), // Relu(X*W1)
|
||||
"relu0"));
|
||||
|
||||
AbstractTensorHandle* hidden = temp_outputs[0];
|
||||
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2},
|
||||
absl::MakeSpan(temp_outputs), "matmul1",
|
||||
/*transpose_a=*/false, /*transpose_b=*/false,
|
||||
registry)); // W2*Relu(X*W1)
|
||||
TF_RETURN_IF_ERROR(ops::MatMul(
|
||||
tape_ctx.get(), {hidden, W2}, absl::MakeSpan(temp_outputs), "matmul1",
|
||||
/*transpose_a=*/false, /*transpose_b=*/false)); // W2*Relu(X*W1)
|
||||
|
||||
AbstractTensorHandle* scores = temp_outputs[0];
|
||||
|
||||
temp_outputs.resize(2);
|
||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
|
||||
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
|
||||
"softmaxloss", registry)); // W2*Relu(X*W1)
|
||||
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
|
||||
tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
|
||||
"softmaxloss")); // W2*Relu(X*W1)
|
||||
|
||||
AbstractTensorHandle* loss = temp_outputs[0];
|
||||
|
||||
@ -445,8 +336,10 @@ Status ScalarMulModel(AbstractContext* ctx,
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
|
||||
TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs),
|
||||
"scalarMul0", registry)); // Compute eta*A
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {eta, A},
|
||||
absl::MakeSpan(temp_outputs),
|
||||
"scalarMul0")); // Compute eta*A
|
||||
|
||||
outputs[0] = temp_outputs[0];
|
||||
|
||||
@ -464,9 +357,11 @@ Status MatMulModel(AbstractContext* ctx,
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
|
||||
"matmul0", /*transpose_a=*/false,
|
||||
/*transpose_b=*/false, registry)); // Compute X*W1
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
|
||||
absl::MakeSpan(temp_outputs), "matmul0",
|
||||
/*transpose_a=*/false,
|
||||
/*transpose_b=*/false)); // Compute X*W1
|
||||
|
||||
outputs[0] = temp_outputs[0];
|
||||
delete tape;
|
||||
@ -483,8 +378,10 @@ Status MulModel(AbstractContext* ctx,
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
TF_RETURN_IF_ERROR(Mul(ctx, tape, {x, y}, absl::MakeSpan(temp_outputs),
|
||||
"mul0", registry)); // Compute x*y
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {x, y},
|
||||
absl::MakeSpan(temp_outputs),
|
||||
"mul0")); // Compute x*y
|
||||
|
||||
outputs[0] = temp_outputs[0];
|
||||
delete tape;
|
||||
@ -501,9 +398,9 @@ Status SoftmaxModel(AbstractContext* ctx,
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
std::vector<AbstractTensorHandle*> temp_outputs(2);
|
||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
|
||||
ctx, tape, {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss",
|
||||
registry));
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
|
||||
tape_ctx.get(), {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss"));
|
||||
|
||||
outputs[0] = temp_outputs[0]; // loss values
|
||||
|
||||
|
@ -29,45 +29,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
// ========================== Tape Ops ==============================
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
namespace internal {
|
||||
// Computes `inputs[0] + inputs[1]` and records it on the tape.
|
||||
Status Add(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry);
|
||||
|
||||
// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
|
||||
Status MatMul(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
bool transpose_a, bool transpose_b,
|
||||
const GradientRegistry& registry);
|
||||
|
||||
// Computes `inputs[0] * inputs[1]` and records it on the tape.
|
||||
Status Mul(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
const GradientRegistry& registry);
|
||||
|
||||
// Computes `Relu(inputs[0])` and records it on the tape.
|
||||
Status Relu(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
const GradientRegistry& registry);
|
||||
|
||||
// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
|
||||
// tape.
|
||||
Status SparseSoftmaxCrossEntropyWithLogits(
|
||||
AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
const GradientRegistry& registry);
|
||||
|
||||
// ====================== End Tape Ops ============================
|
||||
|
||||
// Computes
|
||||
// y = inputs[0] + inputs[1]
|
||||
|
@ -58,7 +58,7 @@ using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
|
||||
class DeviceThread {
|
||||
public:
|
||||
// Starts a background thread waiting for `StartExecute`.
|
||||
explicit DeviceThread(const std::string& device)
|
||||
explicit DeviceThread(const std::string& device, const bool is_async)
|
||||
: status_(TF_NewStatus()),
|
||||
device_(device),
|
||||
// If the context's default exector is set to async, re-using that in
|
||||
@ -67,7 +67,7 @@ class DeviceThread {
|
||||
//
|
||||
// TODO(allenl): We should have an async API that works with the
|
||||
// parallel device.
|
||||
executor_(TFE_NewExecutor(/*is_async=*/false)),
|
||||
executor_(TFE_NewExecutor(is_async)),
|
||||
op_(nullptr),
|
||||
thread_(tensorflow::Env::Default()->StartThread(
|
||||
tensorflow::ThreadOptions(), "parallel_device_execute",
|
||||
@ -236,12 +236,13 @@ void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
|
||||
}
|
||||
}
|
||||
|
||||
ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
|
||||
ParallelDevice::ParallelDevice(const std::vector<std::string>& devices,
|
||||
const bool is_async)
|
||||
: underlying_devices_(devices) {
|
||||
device_threads_.reserve(devices.size());
|
||||
for (int device_index = 0; device_index < devices.size(); ++device_index) {
|
||||
device_threads_.emplace_back(
|
||||
new DeviceThread(devices[device_index].c_str()));
|
||||
new DeviceThread(devices[device_index].c_str(), is_async));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -49,7 +49,10 @@ class DeviceThread;
|
||||
// placed on each underlying device.
|
||||
class ParallelDevice {
|
||||
public:
|
||||
explicit ParallelDevice(const std::vector<std::string>& devices);
|
||||
// Eager async execution is only supported when remote eager is not in use
|
||||
// (b/157523095).
|
||||
explicit ParallelDevice(const std::vector<std::string>& devices,
|
||||
const bool is_async = false);
|
||||
|
||||
~ParallelDevice();
|
||||
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -98,6 +99,10 @@ class VSpace {
|
||||
gtl::ArraySlice<Gradient*> output_gradients,
|
||||
std::vector<Gradient*>* result) const = 0;
|
||||
|
||||
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
||||
virtual Status BuildOnesLike(const TapeTensor& t,
|
||||
Gradient** result) const = 0;
|
||||
|
||||
// Looks up the ID of a Gradient.
|
||||
virtual int64 TensorId(Gradient* tensor) const = 0;
|
||||
|
||||
@ -121,7 +126,7 @@ class GradientTape {
|
||||
// functions (and hence the tensors they keep alive). Instead, everything
|
||||
// is deleted in ~GradientTape. Persistent GradientTapes are useful when
|
||||
// users want to compute multiple gradients over the same tape.
|
||||
GradientTape(bool persistent) : persistent_(persistent) {}
|
||||
explicit GradientTape(bool persistent) : persistent_(persistent) {}
|
||||
~GradientTape() {
|
||||
for (const auto& pair : op_tape_) {
|
||||
pair.second.backward_function_deleter(pair.second.backward_function);
|
||||
@ -595,8 +600,10 @@ Status InitialGradients(
|
||||
for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
|
||||
if (op_it->second.output_tensor_info[j].GetID() == id) {
|
||||
found = true;
|
||||
(*result)[id].push_back(
|
||||
op_it->second.output_tensor_info[j].OnesLike());
|
||||
Gradient* ones_like = nullptr;
|
||||
TF_RETURN_IF_ERROR(vspace.BuildOnesLike(
|
||||
op_it->second.output_tensor_info[j], &ones_like));
|
||||
(*result)[id].push_back(ones_like);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -611,7 +618,10 @@ Status InitialGradients(
|
||||
// target is also a source.
|
||||
auto source_tensor = sources_that_are_targets.find(id);
|
||||
if (source_tensor != sources_that_are_targets.end()) {
|
||||
(*result)[id].push_back(source_tensor->second.OnesLike());
|
||||
Gradient* ones_like = nullptr;
|
||||
TF_RETURN_IF_ERROR(
|
||||
vspace.BuildOnesLike(source_tensor->second, &ones_like));
|
||||
(*result)[id].push_back(ones_like);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@ -934,7 +944,7 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
|
||||
// TODO(allenl): Figure out why using zeros_like everywhere causes issues
|
||||
// for some gradient functions and if there's another way to work around
|
||||
// it (e.g. conds instead of ifs). The value shouldn't really matter.
|
||||
aid = output_tensor.OnesLike();
|
||||
TF_RETURN_IF_ERROR(vspace_.BuildOnesLike(output_tensor, &aid));
|
||||
}
|
||||
if (TF_PREDICT_FALSE(aid == nullptr)) {
|
||||
return tensorflow::errors::Internal(
|
||||
|
37
tensorflow/c/eager/tracing_utils.cc
Normal file
37
tensorflow/c/eager/tracing_utils.cc
Normal file
@ -0,0 +1,37 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/eager/tracing_utils.h"
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tracing {
|
||||
|
||||
Status MaybeSetOpName(AbstractOperation* op, const char* op_name) {
|
||||
if (isa<TracingOperation>(op)) {
|
||||
TF_RETURN_IF_ERROR(dyn_cast<TracingOperation>(op)->SetOpName(op_name));
|
||||
}
|
||||
if (isa<gradients::TapeOperation>(op)) {
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(
|
||||
dyn_cast<gradients::TapeOperation>(op)->GetBackingOperation(),
|
||||
op_name));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace tracing
|
||||
} // namespace tensorflow
|
26
tensorflow/c/eager/tracing_utils.h
Normal file
26
tensorflow/c/eager/tracing_utils.h
Normal file
@ -0,0 +1,26 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_UTILS_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_UTILS_H_
|
||||
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tracing {
|
||||
Status MaybeSetOpName(AbstractOperation*, const char* op_name);
|
||||
} // namespace tracing
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_UTILS_H_
|
@ -29,6 +29,7 @@ cc_library(
|
||||
}),
|
||||
deps = [
|
||||
"//tensorflow/c:env",
|
||||
"//tensorflow/c:logging",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||
"//third_party/hadoop:hdfs",
|
||||
|
@ -22,11 +22,10 @@ limitations under the License.
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/c/env.h"
|
||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||
#include "tensorflow/c/logging.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "third_party/hadoop/hdfs.h"
|
||||
|
||||
// Implementation of a filesystem for HADOOP environments.
|
||||
// This filesystem will support `hdfs://`, `viewfs://` and `har://` URI schemes.
|
||||
@ -149,15 +148,20 @@ class LibHDFS {
|
||||
char* hdfs_home = getenv("HADOOP_HDFS_HOME");
|
||||
if (hdfs_home != nullptr) {
|
||||
auto JoinPath = [](std::string home, std::string lib) {
|
||||
#if defined(_WIN32)
|
||||
if (home.back() != '\\') home.push_back('\\');
|
||||
return home + "lib\\native\\" + lib;
|
||||
#else
|
||||
if (home.back() != '/') home.push_back('/');
|
||||
return home + "lib/native/" + lib;
|
||||
#endif
|
||||
};
|
||||
std::string path = JoinPath(hdfs_home, kLibHdfsDso);
|
||||
TryLoadAndBind(path.c_str(), &handle_, status);
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
return;
|
||||
} else {
|
||||
std::cerr << "HadoopFileSystem load error: " << TF_Message(status);
|
||||
TF_Log(TF_FATAL, "HadoopFileSystem load error: %s", TF_Message(status));
|
||||
}
|
||||
}
|
||||
|
||||
@ -169,16 +173,17 @@ class LibHDFS {
|
||||
void* handle_;
|
||||
};
|
||||
|
||||
// We rely on HDFS connection caching here. The HDFS client calls
|
||||
// org.apache.hadoop.fs.FileSystem.get(), which caches the connection
|
||||
// internally.
|
||||
hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) {
|
||||
// We implement connection caching in Tensorflow, which can significantly
|
||||
// improve performance. Fixes #43187
|
||||
hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file,
|
||||
const std::string& path, TF_Status* status) {
|
||||
auto libhdfs = hadoop_file->libhdfs;
|
||||
std::string scheme, namenode, hdfs_path;
|
||||
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
|
||||
|
||||
hdfsBuilder* builder = libhdfs->hdfsNewBuilder();
|
||||
std::string cacheKey(scheme);
|
||||
if (scheme == "file") {
|
||||
libhdfs->hdfsBuilderSetNameNode(builder, nullptr);
|
||||
namenode = "";
|
||||
} else if (scheme == "viewfs") {
|
||||
char* defaultFS = nullptr;
|
||||
libhdfs->hdfsConfGetStr("fs.defaultFS", &defaultFS);
|
||||
@ -194,21 +199,33 @@ hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) {
|
||||
// The default NameNode configuration will be used (from the XML
|
||||
// configuration files). See:
|
||||
// https://github.com/tensorflow/tensorflow/blob/v1.0.0/third_party/hadoop/hdfs.h#L259
|
||||
libhdfs->hdfsBuilderSetNameNode(builder, "default");
|
||||
namenode = "default";
|
||||
} else if (scheme == "har") {
|
||||
std::string path_har = path;
|
||||
SplitArchiveNameAndPath(&path_har, &namenode, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str());
|
||||
} else {
|
||||
libhdfs->hdfsBuilderSetNameNode(
|
||||
builder, namenode.empty() ? "default" : namenode.c_str());
|
||||
if (namenode.empty()) {
|
||||
namenode = "default";
|
||||
}
|
||||
}
|
||||
auto fs = libhdfs->hdfsBuilderConnect(builder);
|
||||
if (fs == nullptr)
|
||||
TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno));
|
||||
else
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
cacheKey += namenode;
|
||||
|
||||
absl::MutexLock l(&hadoop_file->connection_cache_lock);
|
||||
if (hadoop_file->connection_cache.find(cacheKey) ==
|
||||
hadoop_file->connection_cache.end()) {
|
||||
hdfsBuilder* builder = libhdfs->hdfsNewBuilder();
|
||||
libhdfs->hdfsBuilderSetNameNode(
|
||||
builder, namenode.empty() ? nullptr : namenode.c_str());
|
||||
auto cacheFs = libhdfs->hdfsBuilderConnect(builder);
|
||||
if (cacheFs == nullptr) {
|
||||
TF_SetStatusFromIOError(status, TF_ABORTED, strerror(errno));
|
||||
return cacheFs;
|
||||
}
|
||||
hadoop_file->connection_cache[cacheKey] = cacheFs;
|
||||
}
|
||||
auto fs = hadoop_file->connection_cache[cacheKey];
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
return fs;
|
||||
}
|
||||
|
||||
@ -222,6 +239,7 @@ typedef struct HDFSFile {
|
||||
LibHDFS* libhdfs;
|
||||
absl::Mutex mu;
|
||||
hdfsFile handle ABSL_GUARDED_BY(mu);
|
||||
bool disable_eof_retried;
|
||||
HDFSFile(std::string path, std::string hdfs_path, hdfsFS fs, LibHDFS* libhdfs,
|
||||
hdfsFile handle)
|
||||
: path(std::move(path)),
|
||||
@ -229,7 +247,15 @@ typedef struct HDFSFile {
|
||||
fs(fs),
|
||||
libhdfs(libhdfs),
|
||||
mu(),
|
||||
handle(handle) {}
|
||||
handle(handle) {
|
||||
const char* disable_eof_retried_str =
|
||||
getenv("HDFS_DISABLE_READ_EOF_RETRIED");
|
||||
if (disable_eof_retried_str && disable_eof_retried_str[0] == '1') {
|
||||
disable_eof_retried = true;
|
||||
} else {
|
||||
disable_eof_retried = false;
|
||||
}
|
||||
}
|
||||
} HDFSFile;
|
||||
|
||||
void Cleanup(TF_RandomAccessFile* file) {
|
||||
@ -253,6 +279,10 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
|
||||
|
||||
char* dst = buffer;
|
||||
bool eof_retried = false;
|
||||
if (hdfs_file->disable_eof_retried) {
|
||||
// eof_retried = true, avoid calling hdfsOpenFile in Read, Fixes #42597
|
||||
eof_retried = true;
|
||||
}
|
||||
int64_t read = 0;
|
||||
while (TF_GetCode(status) == TF_OK && n > 0) {
|
||||
// We lock inside the loop rather than outside so we don't block other
|
||||
@ -396,30 +426,36 @@ void Close(const TF_WritableFile* file, TF_Status* status) {
|
||||
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_read_only_memory_region {
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
|
||||
// Hadoop doesn't support Readonly Memory Region
|
||||
} // namespace tf_read_only_memory_region
|
||||
|
||||
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_hadoop_filesystem {
|
||||
|
||||
HadoopFile::HadoopFile(TF_Status* status)
|
||||
: libhdfs(new LibHDFS(status)),
|
||||
connection_cache_lock(),
|
||||
connection_cache() {}
|
||||
|
||||
void Init(TF_Filesystem* filesystem, TF_Status* status) {
|
||||
filesystem->plugin_filesystem = new LibHDFS(status);
|
||||
filesystem->plugin_filesystem = new HadoopFile(status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
void Cleanup(TF_Filesystem* filesystem) {
|
||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
||||
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||
auto libhdfs = hadoop_file->libhdfs;
|
||||
delete libhdfs;
|
||||
delete hadoop_file;
|
||||
}
|
||||
|
||||
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_RandomAccessFile* file, TF_Status* status) {
|
||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
||||
auto fs = Connect(libhdfs, path, status);
|
||||
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||
auto libhdfs = hadoop_file->libhdfs;
|
||||
auto fs = Connect(hadoop_file, path, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
std::string scheme, namenode, hdfs_path;
|
||||
@ -435,8 +471,9 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
|
||||
|
||||
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_WritableFile* file, TF_Status* status) {
|
||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
||||
auto fs = Connect(libhdfs, path, status);
|
||||
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||
auto libhdfs = hadoop_file->libhdfs;
|
||||
auto fs = Connect(hadoop_file, path, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
std::string scheme, namenode, hdfs_path;
|
||||
@ -452,8 +489,9 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
|
||||
|
||||
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_WritableFile* file, TF_Status* status) {
|
||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
||||
auto fs = Connect(libhdfs, path, status);
|
||||
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||
auto libhdfs = hadoop_file->libhdfs;
|
||||
auto fs = Connect(hadoop_file, path, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
std::string scheme, namenode, hdfs_path;
|
||||
@ -484,8 +522,9 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
|
||||
|
||||
void PathExists(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_Status* status) {
|
||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
||||
auto fs = Connect(libhdfs, path, status);
|
||||
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||
auto libhdfs = hadoop_file->libhdfs;
|
||||
auto fs = Connect(hadoop_file, path, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
std::string scheme, namenode, hdfs_path;
|
||||
@ -500,8 +539,9 @@ void PathExists(const TF_Filesystem* filesystem, const char* path,
|
||||
|
||||
void Stat(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_FileStatistics* stats, TF_Status* status) {
|
||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
||||
auto fs = Connect(libhdfs, path, status);
|
||||
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||
auto libhdfs = hadoop_file->libhdfs;
|
||||
auto fs = Connect(hadoop_file, path, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
std::string scheme, namenode, hdfs_path;
|
||||
@ -519,8 +559,9 @@ void Stat(const TF_Filesystem* filesystem, const char* path,
|
||||
|
||||
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_Status* status) {
|
||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
||||
auto fs = Connect(libhdfs, path, status);
|
||||
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||
auto libhdfs = hadoop_file->libhdfs;
|
||||
auto fs = Connect(hadoop_file, path, status);
|
||||
if (TF_GetCode(status) != TF_OK) return -1;
|
||||
|
||||
std::string scheme, namenode, hdfs_path;
|
||||
@ -540,8 +581,9 @@ int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
|
||||
|
||||
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_Status* status) {
|
||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
||||
auto fs = Connect(libhdfs, path, status);
|
||||
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||
auto libhdfs = hadoop_file->libhdfs;
|
||||
auto fs = Connect(hadoop_file, path, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
std::string scheme, namenode, hdfs_path;
|
||||
@ -555,8 +597,9 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path,
|
||||
|
||||
void CreateDir(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_Status* status) {
|
||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
||||
auto fs = Connect(libhdfs, path, status);
|
||||
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||
auto libhdfs = hadoop_file->libhdfs;
|
||||
auto fs = Connect(hadoop_file, path, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
std::string scheme, namenode, hdfs_path;
|
||||
@ -570,8 +613,9 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,
|
||||
|
||||
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_Status* status) {
|
||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
||||
auto fs = Connect(libhdfs, path, status);
|
||||
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||
auto libhdfs = hadoop_file->libhdfs;
|
||||
auto fs = Connect(hadoop_file, path, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
std::string scheme, namenode, hdfs_path;
|
||||
@ -606,8 +650,9 @@ void DeleteDir(const TF_Filesystem* filesystem, const char* path,
|
||||
|
||||
void RenameFile(const TF_Filesystem* filesystem, const char* src,
|
||||
const char* dst, TF_Status* status) {
|
||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
||||
auto fs = Connect(libhdfs, src, status);
|
||||
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||
auto libhdfs = hadoop_file->libhdfs;
|
||||
auto fs = Connect(hadoop_file, src, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
std::string scheme, namenode, hdfs_path_src, hdfs_path_dst;
|
||||
@ -627,8 +672,9 @@ void RenameFile(const TF_Filesystem* filesystem, const char* src,
|
||||
|
||||
int GetChildren(const TF_Filesystem* filesystem, const char* path,
|
||||
char*** entries, TF_Status* status) {
|
||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
||||
auto fs = Connect(libhdfs, path, status);
|
||||
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||
auto libhdfs = hadoop_file->libhdfs;
|
||||
auto fs = Connect(hadoop_file, path, status);
|
||||
if (TF_GetCode(status) != TF_OK) return -1;
|
||||
|
||||
std::string scheme, namenode, hdfs_path;
|
||||
@ -664,7 +710,9 @@ int GetChildren(const TF_Filesystem* filesystem, const char* path,
|
||||
return num_entries;
|
||||
}
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) {
|
||||
return strdup(uri);
|
||||
}
|
||||
|
||||
} // namespace tf_hadoop_filesystem
|
||||
|
||||
@ -672,6 +720,42 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
|
||||
const char* uri) {
|
||||
TF_SetFilesystemVersionMetadata(ops);
|
||||
ops->scheme = strdup(uri);
|
||||
|
||||
ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
|
||||
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
|
||||
ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
|
||||
ops->random_access_file_ops->read = tf_random_access_file::Read;
|
||||
|
||||
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
|
||||
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
|
||||
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
|
||||
ops->writable_file_ops->append = tf_writable_file::Append;
|
||||
ops->writable_file_ops->tell = tf_writable_file::Tell;
|
||||
ops->writable_file_ops->flush = tf_writable_file::Flush;
|
||||
ops->writable_file_ops->sync = tf_writable_file::Sync;
|
||||
ops->writable_file_ops->close = tf_writable_file::Close;
|
||||
|
||||
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
|
||||
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
|
||||
ops->filesystem_ops->init = tf_hadoop_filesystem::Init;
|
||||
ops->filesystem_ops->cleanup = tf_hadoop_filesystem::Cleanup;
|
||||
ops->filesystem_ops->new_random_access_file =
|
||||
tf_hadoop_filesystem::NewRandomAccessFile;
|
||||
ops->filesystem_ops->new_writable_file =
|
||||
tf_hadoop_filesystem::NewWritableFile;
|
||||
ops->filesystem_ops->new_appendable_file =
|
||||
tf_hadoop_filesystem::NewAppendableFile;
|
||||
ops->filesystem_ops->new_read_only_memory_region_from_file =
|
||||
tf_hadoop_filesystem::NewReadOnlyMemoryRegionFromFile;
|
||||
ops->filesystem_ops->path_exists = tf_hadoop_filesystem::PathExists;
|
||||
ops->filesystem_ops->stat = tf_hadoop_filesystem::Stat;
|
||||
ops->filesystem_ops->get_file_size = tf_hadoop_filesystem::GetFileSize;
|
||||
ops->filesystem_ops->delete_file = tf_hadoop_filesystem::DeleteFile;
|
||||
ops->filesystem_ops->create_dir = tf_hadoop_filesystem::CreateDir;
|
||||
ops->filesystem_ops->delete_dir = tf_hadoop_filesystem::DeleteDir;
|
||||
ops->filesystem_ops->rename_file = tf_hadoop_filesystem::RenameFile;
|
||||
ops->filesystem_ops->get_children = tf_hadoop_filesystem::GetChildren;
|
||||
ops->filesystem_ops->translate_name = tf_hadoop_filesystem::TranslateName;
|
||||
}
|
||||
|
||||
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
|
||||
|
@ -15,10 +15,13 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "third_party/hadoop/hdfs.h"
|
||||
|
||||
void ParseHadoopPath(const std::string& fname, std::string* scheme,
|
||||
std::string* namenode, std::string* path);
|
||||
@ -43,6 +46,14 @@ void Close(const TF_WritableFile* file, TF_Status* status);
|
||||
} // namespace tf_writable_file
|
||||
|
||||
namespace tf_hadoop_filesystem {
|
||||
typedef struct HadoopFile {
|
||||
LibHDFS* libhdfs;
|
||||
absl::Mutex connection_cache_lock;
|
||||
std::map<std::string, hdfsFS> connection_cache
|
||||
ABSL_GUARDED_BY(connection_cache_lock);
|
||||
HadoopFile(TF_Status* status);
|
||||
} HadoopFile;
|
||||
|
||||
void Init(TF_Filesystem* filesystem, TF_Status* status);
|
||||
void Cleanup(TF_Filesystem* filesystem);
|
||||
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
|
||||
|
@ -352,6 +352,48 @@ TEST_F(HadoopFileSystemTest, WriteWhileReading) {
|
||||
EXPECT_TF_OK(status_);
|
||||
}
|
||||
|
||||
TEST_F(HadoopFileSystemTest, ReadWhileOverwriting) {
|
||||
static char set_disable_var[] = "HDFS_DISABLE_READ_EOF_RETRIED=1";
|
||||
putenv(set_disable_var);
|
||||
|
||||
const std::string path = TmpDir("ReadWhileOverwriting");
|
||||
if (path.find_first_of("hdfs://") != 0) GTEST_SKIP();
|
||||
|
||||
const string content1 = "content1";
|
||||
WriteString(path, content1);
|
||||
ASSERT_TF_OK(status_);
|
||||
|
||||
auto reader = GetReader();
|
||||
tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
|
||||
reader.get(), status_);
|
||||
EXPECT_TF_OK(status_);
|
||||
|
||||
std::string result;
|
||||
result.resize(content1.size());
|
||||
auto read = tf_random_access_file::Read(reader.get(), 0, content1.size(),
|
||||
&result[0], status_);
|
||||
result.resize(read);
|
||||
EXPECT_TF_OK(status_);
|
||||
EXPECT_EQ(content1, result);
|
||||
|
||||
tf_hadoop_filesystem::DeleteFile(filesystem_, path.c_str(), status_);
|
||||
EXPECT_TF_OK(status_);
|
||||
|
||||
string content2 = "overwrite";
|
||||
WriteString(path, content1 + content2);
|
||||
ASSERT_TF_OK(status_);
|
||||
|
||||
result.resize(content2.size());
|
||||
read = tf_random_access_file::Read(reader.get(), content1.size(),
|
||||
content2.size(), &result[0], status_);
|
||||
result.resize(read);
|
||||
EXPECT_TF_OK(status_);
|
||||
EXPECT_EQ(0, result.size());
|
||||
|
||||
static char set_enable_var[] = "HDFS_DISABLE_READ_EOF_RETRIED=0";
|
||||
putenv(set_enable_var);
|
||||
}
|
||||
|
||||
TEST_F(HadoopFileSystemTest, HarSplit) {
|
||||
const std::string har_path =
|
||||
"har://hdfs-root/user/j.doe/my_archive.har/dir0/dir1/file.txt";
|
||||
|
@ -24,6 +24,8 @@ using std::vector;
|
||||
using tensorflow::ops::Conj;
|
||||
using tensorflow::ops::MatMul;
|
||||
using tensorflow::ops::Mul;
|
||||
using tensorflow::ops::Neg;
|
||||
using tensorflow::ops::SqrtGrad;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
@ -72,6 +74,25 @@ class ExpGradientFunction : public GradientFunction {
|
||||
AbstractTensorHandlePtr exp_;
|
||||
};
|
||||
|
||||
class SqrtGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) {
|
||||
sqrt->Ref();
|
||||
}
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
std::string name = "Sqrt_Grad";
|
||||
grad_outputs->resize(1);
|
||||
TF_RETURN_IF_ERROR(SqrtGrad(ctx->ctx, {sqrt_.get(), grad_inputs[0]},
|
||||
absl::MakeSpan(*grad_outputs), name.c_str()));
|
||||
return Status::OK();
|
||||
}
|
||||
~SqrtGradientFunction() override {}
|
||||
|
||||
private:
|
||||
AbstractTensorHandlePtr sqrt_;
|
||||
};
|
||||
|
||||
class MatMulGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
|
||||
@ -181,6 +202,93 @@ class MatMulGradientFunction : public GradientFunction {
|
||||
AttrBuilder forward_attrs;
|
||||
};
|
||||
|
||||
class NegGradientFunction : public GradientFunction {
|
||||
public:
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
/* Given upstream grad U and a Neg op Y = -X, the gradients are:
|
||||
*
|
||||
* dX = -U
|
||||
*
|
||||
*/
|
||||
|
||||
grad_outputs->resize(1);
|
||||
std::string name = "Neg_Grad";
|
||||
TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
|
||||
absl::MakeSpan(*grad_outputs), name.c_str()));
|
||||
return Status::OK();
|
||||
}
|
||||
~NegGradientFunction() override {}
|
||||
};
|
||||
|
||||
class SubGradientFunction : public GradientFunction {
|
||||
public:
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
/* Given upstream grad U and a Sub op A-B, the gradients are:
|
||||
*
|
||||
* dA = U
|
||||
* dB = -U
|
||||
*
|
||||
*/
|
||||
|
||||
grad_outputs->resize(2);
|
||||
|
||||
// Grad for A
|
||||
DCHECK(grad_inputs[0]);
|
||||
(*grad_outputs)[0] = grad_inputs[0];
|
||||
(*grad_outputs)[0]->Ref();
|
||||
|
||||
// Grad for B
|
||||
// negate the upstream grad
|
||||
std::vector<AbstractTensorHandle*> neg_outputs(1);
|
||||
std::string name = "Neg_Sub_Grad_B";
|
||||
TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
|
||||
absl::MakeSpan(neg_outputs), name.c_str()));
|
||||
(*grad_outputs)[1] = neg_outputs[0];
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
~SubGradientFunction() override {}
|
||||
};
|
||||
|
||||
class MulGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit MulGradientFunction(vector<AbstractTensorHandle*> f_inputs)
|
||||
: forward_inputs(f_inputs) {}
|
||||
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
/* Given upstream grad U and a mul op A*B, the gradients are:
|
||||
*
|
||||
* dA = U * B
|
||||
* dB = A * U
|
||||
*
|
||||
*/
|
||||
|
||||
AbstractTensorHandle* upstream_grad = grad_inputs[0];
|
||||
grad_outputs->resize(2);
|
||||
std::vector<AbstractTensorHandle*> mul_outputs(1);
|
||||
|
||||
// Gradient for A
|
||||
std::string name = "Mul_Grad_A";
|
||||
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {upstream_grad, forward_inputs[1]},
|
||||
absl::MakeSpan(mul_outputs), name.c_str()));
|
||||
(*grad_outputs)[0] = mul_outputs[0];
|
||||
|
||||
// Gradient for B
|
||||
name = "Mul_Grad_B";
|
||||
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {forward_inputs[0], upstream_grad},
|
||||
absl::MakeSpan(mul_outputs), name.c_str()));
|
||||
(*grad_outputs)[1] = mul_outputs[0];
|
||||
return Status::OK();
|
||||
}
|
||||
~MulGradientFunction() override {}
|
||||
|
||||
private:
|
||||
vector<AbstractTensorHandle*> forward_inputs;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
BackwardFunction* AddRegisterer(const ForwardOperation& op) {
|
||||
@ -210,5 +318,41 @@ BackwardFunction* MatMulRegisterer(const ForwardOperation& op) {
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
BackwardFunction* SqrtRegisterer(const ForwardOperation& op) {
|
||||
auto gradient_function = new SqrtGradientFunction(op.outputs[0]);
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
BackwardFunction* NegRegisterer(const ForwardOperation& op) {
|
||||
auto gradient_function = new NegGradientFunction;
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
BackwardFunction* SubRegisterer(const ForwardOperation& op) {
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto gradient_function = new SubGradientFunction;
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
BackwardFunction* MulRegisterer(const ForwardOperation& op) {
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto gradient_function = new MulGradientFunction(op.inputs);
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
@ -19,10 +19,16 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
|
||||
BackwardFunction* AddRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* NegRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* SubRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* MulRegisterer(const ForwardOperation& op);
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
|
||||
|
66
tensorflow/c/experimental/gradients/tape/BUILD
Normal file
66
tensorflow/c/experimental/gradients/tape/BUILD
Normal file
@ -0,0 +1,66 @@
|
||||
# A tape built on top of unified execution APIs.
|
||||
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tape_context",
|
||||
srcs = ["tape_context.cc"],
|
||||
hdrs = [
|
||||
"tape_context.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":tape_operation",
|
||||
"//tensorflow/c/eager:abstract_context",
|
||||
"//tensorflow/c/eager:abstract_function",
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tape_operation",
|
||||
srcs = ["tape_operation.cc"],
|
||||
hdrs = [
|
||||
"tape_operation.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:abstract_context",
|
||||
"//tensorflow/c/eager:abstract_function",
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:gradients_internal",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tape",
|
||||
hdrs = [
|
||||
"tape_context.h",
|
||||
"tape_operation.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":tape_context",
|
||||
":tape_operation",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
"tape_context.h",
|
||||
"tape_operation.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
)
|
47
tensorflow/c/experimental/gradients/tape/tape_context.cc
Normal file
47
tensorflow/c/experimental/gradients/tape/tape_context.cc
Normal file
@ -0,0 +1,47 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
|
||||
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
TapeContext::TapeContext(AbstractContext* c, Tape* tape,
|
||||
const GradientRegistry& registry)
|
||||
: AbstractContext(kTape), parent_ctx_(c), tape_(tape), registry_(registry) {
|
||||
// TODO(srbs): Make AbstractContext ref counted.
|
||||
// parent_ctx_->Ref();
|
||||
}
|
||||
void TapeContext::Release() {
|
||||
// TODO(srbs): Change to Unref()
|
||||
delete this;
|
||||
}
|
||||
TapeContext::~TapeContext() {
|
||||
// TODO(srbs): Make AbstractContext ref counted.
|
||||
// parent_ctx_->Unref();
|
||||
}
|
||||
TapeOperation* TapeContext::CreateOperation() {
|
||||
return new TapeOperation(parent_ctx_->CreateOperation(), tape_, registry_);
|
||||
}
|
||||
Status TapeContext::RegisterFunction(AbstractFunction* f) {
|
||||
return parent_ctx_->RegisterFunction(f);
|
||||
}
|
||||
Status TapeContext::RemoveFunction(const string& func) {
|
||||
return parent_ctx_->RemoveFunction(func);
|
||||
}
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
44
tensorflow/c/experimental/gradients/tape/tape_context.h
Normal file
44
tensorflow/c/experimental/gradients/tape/tape_context.h
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_
|
||||
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
class TapeContext : public AbstractContext {
|
||||
public:
|
||||
explicit TapeContext(AbstractContext*, Tape*, const GradientRegistry&);
|
||||
void Release() override;
|
||||
TapeOperation* CreateOperation() override;
|
||||
Status RegisterFunction(AbstractFunction*) override;
|
||||
Status RemoveFunction(const string& func) override;
|
||||
// For LLVM style RTTI.
|
||||
static bool classof(const AbstractContext* ptr) {
|
||||
return ptr->getKind() == kTape;
|
||||
}
|
||||
~TapeContext() override;
|
||||
|
||||
private:
|
||||
AbstractContext* parent_ctx_; // Not owned.
|
||||
Tape* tape_;
|
||||
const GradientRegistry& registry_;
|
||||
};
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_
|
238
tensorflow/c/experimental/gradients/tape/tape_operation.cc
Normal file
238
tensorflow/c/experimental/gradients/tape/tape_operation.cc
Normal file
@ -0,0 +1,238 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
|
||||
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
TapeOperation::TapeOperation(AbstractOperation* parent_op, Tape* tape,
|
||||
const GradientRegistry& registry)
|
||||
: AbstractOperation(kTape),
|
||||
parent_op_(parent_op),
|
||||
tape_(tape),
|
||||
registry_(registry) {
|
||||
// TODO(srbs): Make AbstractOperation RefCounted.
|
||||
// parent_op_->Ref();
|
||||
}
|
||||
void TapeOperation::Release() {
|
||||
// TODO(srbs): Change to Unref().
|
||||
delete this;
|
||||
}
|
||||
TapeOperation::~TapeOperation() {
|
||||
// TODO(srbs): Make AbstractOperation RefCounted.
|
||||
// parent_op->Unref();
|
||||
}
|
||||
Status TapeOperation::Reset(const char* op, const char* raw_device_name) {
|
||||
forward_op_.op_name = op;
|
||||
forward_op_.attrs.Reset(op);
|
||||
forward_op_.inputs.clear();
|
||||
forward_op_.outputs.clear();
|
||||
return parent_op_->Reset(op, raw_device_name);
|
||||
}
|
||||
const string& TapeOperation::Name() const { return parent_op_->Name(); }
|
||||
const string& TapeOperation::DeviceName() const {
|
||||
return parent_op_->DeviceName();
|
||||
}
|
||||
Status TapeOperation::SetDeviceName(const char* name) {
|
||||
return parent_op_->SetDeviceName(name);
|
||||
}
|
||||
Status TapeOperation::AddInput(AbstractTensorHandle* input) {
|
||||
TF_RETURN_IF_ERROR(parent_op_->AddInput(input));
|
||||
forward_op_.inputs.push_back(input);
|
||||
return Status::OK();
|
||||
}
|
||||
Status TapeOperation::AddInputList(
|
||||
absl::Span<AbstractTensorHandle* const> inputs) {
|
||||
TF_RETURN_IF_ERROR(parent_op_->AddInputList(inputs));
|
||||
for (auto input : inputs) {
|
||||
forward_op_.inputs.push_back(input);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
Status TapeOperation::SetAttrString(const char* attr_name, const char* data,
|
||||
size_t length) {
|
||||
forward_op_.attrs.Set(attr_name, StringPiece(data, length));
|
||||
return parent_op_->SetAttrString(attr_name, data, length);
|
||||
}
|
||||
Status TapeOperation::SetAttrInt(const char* attr_name, int64_t value) {
|
||||
forward_op_.attrs.Set(attr_name, static_cast<int64>(value));
|
||||
return parent_op_->SetAttrInt(attr_name, value);
|
||||
}
|
||||
Status TapeOperation::SetAttrFloat(const char* attr_name, float value) {
|
||||
forward_op_.attrs.Set(attr_name, value);
|
||||
return parent_op_->SetAttrFloat(attr_name, value);
|
||||
}
|
||||
Status TapeOperation::SetAttrBool(const char* attr_name, bool value) {
|
||||
forward_op_.attrs.Set(attr_name, value);
|
||||
return parent_op_->SetAttrBool(attr_name, value);
|
||||
}
|
||||
Status TapeOperation::SetAttrType(const char* attr_name, DataType value) {
|
||||
forward_op_.attrs.Set(attr_name, value);
|
||||
return parent_op_->SetAttrType(attr_name, value);
|
||||
}
|
||||
Status TapeOperation::SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||
const int num_dims) {
|
||||
if (num_dims > TensorShape::MaxDimensions()) {
|
||||
return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
|
||||
num_dims,
|
||||
" dimensions which is over the limit of ",
|
||||
TensorShape::MaxDimensions(), ".");
|
||||
}
|
||||
TensorShapeProto proto;
|
||||
if (num_dims < 0) {
|
||||
proto.set_unknown_rank(true);
|
||||
} else {
|
||||
for (int d = 0; d < num_dims; ++d) {
|
||||
proto.add_dim()->set_size(dims[d]);
|
||||
}
|
||||
}
|
||||
|
||||
forward_op_.attrs.Set(attr_name, proto);
|
||||
return parent_op_->SetAttrShape(attr_name, dims, num_dims);
|
||||
}
|
||||
Status TapeOperation::SetAttrFunction(const char* attr_name,
|
||||
const AbstractOperation* value) {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrFunction has not been implemented yet.");
|
||||
}
|
||||
Status TapeOperation::SetAttrFunctionName(const char* attr_name,
|
||||
const char* value, size_t length) {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrFunctionName has not been implemented "
|
||||
"yet.");
|
||||
}
|
||||
Status TapeOperation::SetAttrTensor(const char* attr_name,
|
||||
AbstractTensorInterface* tensor) {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrTensor has not been implemented yet.");
|
||||
}
|
||||
Status TapeOperation::SetAttrStringList(const char* attr_name,
|
||||
const void* const* values,
|
||||
const size_t* lengths, int num_values) {
|
||||
std::vector<StringPiece> v(num_values);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
|
||||
}
|
||||
forward_op_.attrs.Set(attr_name, v);
|
||||
return parent_op_->SetAttrStringList(attr_name, values, lengths, num_values);
|
||||
}
|
||||
Status TapeOperation::SetAttrFloatList(const char* attr_name,
|
||||
const float* values, int num_values) {
|
||||
forward_op_.attrs.Set(attr_name,
|
||||
gtl::ArraySlice<const float>(values, num_values));
|
||||
return parent_op_->SetAttrFloatList(attr_name, values, num_values);
|
||||
}
|
||||
Status TapeOperation::SetAttrIntList(const char* attr_name,
|
||||
const int64_t* values, int num_values) {
|
||||
forward_op_.attrs.Set(
|
||||
attr_name, gtl::ArraySlice<const int64>(
|
||||
reinterpret_cast<const int64*>(values), num_values));
|
||||
return parent_op_->SetAttrIntList(attr_name, values, num_values);
|
||||
}
|
||||
Status TapeOperation::SetAttrTypeList(const char* attr_name,
|
||||
const DataType* values, int num_values) {
|
||||
forward_op_.attrs.Set(attr_name,
|
||||
gtl::ArraySlice<const DataType>(values, num_values));
|
||||
return parent_op_->SetAttrTypeList(attr_name, values, num_values);
|
||||
}
|
||||
Status TapeOperation::SetAttrBoolList(const char* attr_name,
|
||||
const unsigned char* values,
|
||||
int num_values) {
|
||||
std::unique_ptr<bool[]> b(new bool[num_values]);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
b[i] = values[i];
|
||||
}
|
||||
forward_op_.attrs.Set(attr_name,
|
||||
gtl::ArraySlice<const bool>(b.get(), num_values));
|
||||
return parent_op_->SetAttrBoolList(attr_name, values, num_values);
|
||||
}
|
||||
Status TapeOperation::SetAttrShapeList(const char* attr_name,
|
||||
const int64_t** dims,
|
||||
const int* num_dims, int num_values) {
|
||||
std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
const auto num_dims_i = num_dims[i];
|
||||
|
||||
if (num_dims_i > TensorShape::MaxDimensions()) {
|
||||
return errors::InvalidArgument(
|
||||
strings::StrCat("Value specified for `", attr_name, "` has ",
|
||||
num_dims_i, " dimensions which is over the limit of ",
|
||||
TensorShape::MaxDimensions(), "."));
|
||||
}
|
||||
if (num_dims_i < 0) {
|
||||
proto[i].set_unknown_rank(true);
|
||||
} else {
|
||||
const int64_t* dims_i = dims[i];
|
||||
auto proto_i = &proto[i];
|
||||
for (int d = 0; d < num_dims_i; ++d) {
|
||||
proto_i->add_dim()->set_size(dims_i[d]);
|
||||
}
|
||||
}
|
||||
}
|
||||
forward_op_.attrs.Set(
|
||||
attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
|
||||
return parent_op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
|
||||
}
|
||||
Status TapeOperation::SetAttrFunctionList(
|
||||
const char* attr_name, absl::Span<const AbstractOperation*> values) {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrFunctionList has not been "
|
||||
"implemented yet.");
|
||||
}
|
||||
AbstractOperation* TapeOperation::GetBackingOperation() { return parent_op_; }
|
||||
Status TapeOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) {
|
||||
TF_RETURN_IF_ERROR(parent_op_->Execute(retvals, num_retvals));
|
||||
std::vector<int64> input_ids(forward_op_.inputs.size());
|
||||
std::vector<tensorflow::DataType> input_dtypes(forward_op_.inputs.size());
|
||||
for (int i = 0; i < forward_op_.inputs.size(); i++) {
|
||||
input_ids[i] = ToId(forward_op_.inputs[i]);
|
||||
input_dtypes[i] = forward_op_.inputs[i]->DataType();
|
||||
}
|
||||
for (int i = 0; i < *num_retvals; i++) {
|
||||
// TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
|
||||
forward_op_.outputs.push_back(retvals[i]);
|
||||
}
|
||||
// TODO(b/166669239): This is needed to support AttrBuilder::Get for string
|
||||
// attributes. Number type attrs and DataType attrs work fine without this.
|
||||
// Consider getting rid of this and making the behavior between number types
|
||||
// and string consistent.
|
||||
forward_op_.attrs.BuildNodeDef();
|
||||
std::vector<TapeTensor> tape_tensors;
|
||||
for (auto t : retvals) {
|
||||
tape_tensors.push_back(TapeTensor(t));
|
||||
}
|
||||
tape_->RecordOperation(
|
||||
parent_op_->Name(), tape_tensors, input_ids, input_dtypes,
|
||||
[this]() -> BackwardFunction* {
|
||||
std::unique_ptr<BackwardFunction> backward_fn;
|
||||
Status s = registry_.Lookup(forward_op_, &backward_fn);
|
||||
if (!s.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return backward_fn.release();
|
||||
},
|
||||
[](BackwardFunction* ptr) {
|
||||
if (ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
80
tensorflow/c/experimental/gradients/tape/tape_operation.h
Normal file
80
tensorflow/c/experimental/gradients/tape/tape_operation.h
Normal file
@ -0,0 +1,80 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_
|
||||
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
class TapeOperation : public AbstractOperation {
|
||||
public:
|
||||
explicit TapeOperation(AbstractOperation*, Tape*, const GradientRegistry&);
|
||||
void Release() override;
|
||||
Status Reset(const char* op, const char* raw_device_name) override;
|
||||
const string& Name() const override;
|
||||
const string& DeviceName() const override;
|
||||
Status SetDeviceName(const char* name) override;
|
||||
Status AddInput(AbstractTensorHandle* input) override;
|
||||
Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
|
||||
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) override;
|
||||
Status SetAttrString(const char* attr_name, const char* data,
|
||||
size_t length) override;
|
||||
Status SetAttrInt(const char* attr_name, int64_t value) override;
|
||||
Status SetAttrFloat(const char* attr_name, float value) override;
|
||||
Status SetAttrBool(const char* attr_name, bool value) override;
|
||||
Status SetAttrType(const char* attr_name, DataType value) override;
|
||||
Status SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||
const int num_dims) override;
|
||||
Status SetAttrFunction(const char* attr_name,
|
||||
const AbstractOperation* value) override;
|
||||
Status SetAttrFunctionName(const char* attr_name, const char* value,
|
||||
size_t length) override;
|
||||
Status SetAttrTensor(const char* attr_name,
|
||||
AbstractTensorInterface* tensor) override;
|
||||
Status SetAttrStringList(const char* attr_name, const void* const* values,
|
||||
const size_t* lengths, int num_values) override;
|
||||
Status SetAttrFloatList(const char* attr_name, const float* values,
|
||||
int num_values) override;
|
||||
Status SetAttrIntList(const char* attr_name, const int64_t* values,
|
||||
int num_values) override;
|
||||
Status SetAttrTypeList(const char* attr_name, const DataType* values,
|
||||
int num_values) override;
|
||||
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
|
||||
int num_values) override;
|
||||
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
|
||||
const int* num_dims, int num_values) override;
|
||||
Status SetAttrFunctionList(
|
||||
const char* attr_name,
|
||||
absl::Span<const AbstractOperation*> values) override;
|
||||
AbstractOperation* GetBackingOperation();
|
||||
// For LLVM style RTTI.
|
||||
static bool classof(const AbstractOperation* ptr) {
|
||||
return ptr->getKind() == kTape;
|
||||
}
|
||||
~TapeOperation() override;
|
||||
|
||||
private:
|
||||
AbstractOperation* parent_op_;
|
||||
ForwardOperation forward_op_;
|
||||
Tape* tape_;
|
||||
const GradientRegistry& registry_;
|
||||
};
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_
|
@ -22,7 +22,7 @@ cc_library(
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:c_api_unified_internal",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"//tensorflow/c/eager:tracing_utils",
|
||||
"//tensorflow/core/platform:errors",
|
||||
],
|
||||
)
|
||||
@ -43,8 +43,8 @@ cc_library(
|
||||
"//tensorflow/c/eager:abstract_context",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:c_api_unified_internal",
|
||||
"//tensorflow/c/eager:tracing_utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"//tensorflow/core/platform:errors",
|
||||
],
|
||||
)
|
||||
@ -64,7 +64,7 @@ cc_library(
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:c_api_unified_internal",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"//tensorflow/c/eager:tracing_utils",
|
||||
"//tensorflow/core/platform:errors",
|
||||
],
|
||||
)
|
||||
@ -87,7 +87,6 @@ cc_library(
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:c_api_unified_internal",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -14,9 +14,11 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/eager/tracing_utils.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
using tensorflow::tracing::MaybeSetOpName;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
|
||||
@ -26,24 +28,30 @@ Status Identity(AbstractContext* ctx,
|
||||
AbstractOperationPtr identity_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(
|
||||
identity_op->Reset("Identity", /*raw_device_name=*/nullptr));
|
||||
if (isa<tensorflow::tracing::TracingOperation>(identity_op.get())) {
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(identity_op.get())
|
||||
->SetOpName(name));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(identity_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0]));
|
||||
int num_retvals = 1;
|
||||
return identity_op->Execute(outputs, &num_retvals);
|
||||
}
|
||||
|
||||
Status IdentityN(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr identity_n_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(
|
||||
identity_n_op->Reset("IdentityN", /*raw_device_name=*/nullptr));
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(identity_n_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(identity_n_op->AddInputList(inputs));
|
||||
int num_retvals = inputs.size();
|
||||
return identity_n_op->Execute(outputs, &num_retvals);
|
||||
}
|
||||
|
||||
Status ZerosLike(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr z_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(z_op->Reset("ZerosLike", /*raw_device_name=*/nullptr));
|
||||
if (isa<tensorflow::tracing::TracingOperation>(z_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<tracing::TracingOperation>(z_op.get())->SetOpName(name));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(z_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(z_op->AddInput(inputs[0]));
|
||||
int num_retvals = 1;
|
||||
return z_op->Execute(outputs, &num_retvals);
|
||||
@ -54,12 +62,7 @@ Status Shape(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr shape_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(shape_op->Reset("Shape", /*raw_device_name=*/nullptr));
|
||||
|
||||
if (isa<tracing::TracingOperation>(shape_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<tracing::TracingOperation>(shape_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(shape_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(shape_op->AddInput(inputs[0])); // input
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(shape_op->Execute(outputs, &num_retvals));
|
||||
@ -71,10 +74,7 @@ Status ExpandDims(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(op->Reset("ExpandDims", /*raw_device_name=*/nullptr));
|
||||
if (isa<tensorflow::tracing::TracingOperation>(op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(name));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(op.get(), name));
|
||||
TF_RETURN_IF_ERROR(op->AddInput(inputs[0]));
|
||||
TF_RETURN_IF_ERROR(op->AddInput(inputs[1]));
|
||||
int num_retvals = 1;
|
||||
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
@ -27,6 +26,10 @@ Status Identity(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status IdentityN(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status ZerosLike(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
@ -16,22 +16,21 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/eager/tracing_utils.h"
|
||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
using tensorflow::tracing::MaybeSetOpName;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
using tensorflow::tracing::TracingOperation;
|
||||
|
||||
Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr mul_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(mul_op->Reset("Mul", /*raw_device_name=*/nullptr));
|
||||
if (isa<TracingOperation>(mul_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(mul_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[0]));
|
||||
TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[1]));
|
||||
int num_retvals = 1;
|
||||
@ -55,12 +54,7 @@ Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr add_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(add_op->Reset("AddV2", /*raw_device_name=*/nullptr));
|
||||
|
||||
if (isa<tracing::TracingOperation>(add_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<tracing::TracingOperation>(add_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(add_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(add_op->AddInput(inputs[0]));
|
||||
TF_RETURN_IF_ERROR(add_op->AddInput(inputs[1]));
|
||||
|
||||
@ -73,12 +67,7 @@ Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr sub_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(sub_op->Reset("Sub", /*raw_device_name=*/nullptr));
|
||||
|
||||
if (isa<tracing::TracingOperation>(sub_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<tracing::TracingOperation>(sub_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(sub_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[0]));
|
||||
TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[1]));
|
||||
|
||||
@ -93,12 +82,7 @@ Status MatMul(AbstractContext* ctx,
|
||||
bool transpose_a = false, bool transpose_b = false) {
|
||||
AbstractOperationPtr matmul_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(matmul_op->Reset("MatMul", /*raw_device_name=*/nullptr));
|
||||
|
||||
if (isa<tracing::TracingOperation>(matmul_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<tracing::TracingOperation>(matmul_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(matmul_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[0]));
|
||||
TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[1]));
|
||||
|
||||
@ -114,10 +98,7 @@ Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr neg_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(neg_op->Reset("Neg", /*raw_device_name=*/nullptr));
|
||||
if (isa<TracingOperation>(neg_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<TracingOperation>(neg_op.get())->SetOpName(name));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(neg_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(neg_op->AddInput(inputs[0]));
|
||||
|
||||
int num_retvals = 1;
|
||||
@ -128,12 +109,7 @@ Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr sum_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(sum_op->Reset("Sum", /*raw_device_name=*/nullptr));
|
||||
|
||||
if (isa<tracing::TracingOperation>(sum_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<tracing::TracingOperation>(sum_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(sum_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[0])); // input_vals
|
||||
TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[1])); // reduction_indices
|
||||
|
||||
@ -147,12 +123,7 @@ Status DivNoNan(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr div_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(div_op->Reset("DivNoNan", /*raw_device_name=*/nullptr));
|
||||
|
||||
if (isa<tracing::TracingOperation>(div_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<tracing::TracingOperation>(div_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(div_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(div_op->AddInput(inputs[0])); // x
|
||||
TF_RETURN_IF_ERROR(div_op->AddInput(inputs[1])); // y
|
||||
|
||||
@ -162,5 +133,44 @@ Status DivNoNan(AbstractContext* ctx,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr exp_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(exp_op->Reset("Exp", /*raw_device_name=*/nullptr));
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(exp_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(exp_op->AddInput(inputs[0]));
|
||||
|
||||
int num_retvals = 1;
|
||||
return exp_op->Execute(outputs, &num_retvals);
|
||||
}
|
||||
|
||||
Status Sqrt(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr sqrt_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(sqrt_op->Reset("Sqrt", /*raw_device_name=*/nullptr));
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(sqrt_op->AddInput(inputs[0]));
|
||||
|
||||
int num_retvals = 1;
|
||||
Status s = sqrt_op->Execute(outputs, &num_retvals);
|
||||
return s;
|
||||
}
|
||||
|
||||
Status SqrtGrad(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr sqrt_grad_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(
|
||||
sqrt_grad_op->Reset("SqrtGrad", /*raw_device_name=*/nullptr));
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_grad_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[0]));
|
||||
TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[1]));
|
||||
|
||||
int num_retvals = 1;
|
||||
Status s = sqrt_grad_op->Execute(outputs, &num_retvals);
|
||||
return s;
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
@ -47,6 +47,18 @@ Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
Status DivNoNan(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status Sqrt(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status SqrtGrad(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -15,8 +15,11 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/ops/nn_ops.h"
|
||||
|
||||
#include "tensorflow/c/eager/tracing_utils.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
using tensorflow::tracing::MaybeSetOpName;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
|
||||
@ -27,12 +30,7 @@ Status SparseSoftmaxCrossEntropyWithLogits(
|
||||
AbstractOperationPtr sm_loss_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(sm_loss_op->Reset("SparseSoftmaxCrossEntropyWithLogits",
|
||||
/*raw_device_name=*/nullptr));
|
||||
|
||||
if (isa<tracing::TracingOperation>(sm_loss_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<tracing::TracingOperation>(sm_loss_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(sm_loss_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[0])); // input scores
|
||||
TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[1])); // labels
|
||||
|
||||
@ -49,12 +47,7 @@ Status ReluGrad(AbstractContext* ctx,
|
||||
AbstractOperationPtr relugrad_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(
|
||||
relugrad_op->Reset("ReluGrad", /*raw_device_name=*/nullptr));
|
||||
|
||||
if (isa<tracing::TracingOperation>(relugrad_op.get())) {
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(relugrad_op.get())
|
||||
->SetOpName(name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(relugrad_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[0])); // upstream grads
|
||||
TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[1])); // relu inputs
|
||||
|
||||
@ -68,12 +61,7 @@ Status Relu(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr relu_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(relu_op->Reset("Relu", /*raw_device_name=*/nullptr));
|
||||
|
||||
if (isa<tracing::TracingOperation>(relu_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<tracing::TracingOperation>(relu_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(relu_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(relu_op->AddInput(inputs[0]));
|
||||
|
||||
int num_retvals = 1;
|
||||
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
|
@ -66,12 +66,18 @@ cc_library(
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:asset",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:partially_revived_objects",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:restored_resource_revival_state",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function_revival_state",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_signature_def_function_revival_state",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
|
||||
"//tensorflow/cc/saved_model:loader_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
@ -85,15 +91,24 @@ cc_library(
|
||||
":signature_def_function_metadata",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "signature_def_function_metadata",
|
||||
srcs = [
|
||||
"signature_def_function_metadata.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"signature_def_function_metadata.h",
|
||||
],
|
||||
deps = [
|
||||
":tensor_spec",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -146,12 +161,14 @@ cc_library(
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/experimental/saved_model/core/ops:restore_ops",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:flat_tensor_function",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:partially_revived_objects",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:revived_objects",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
|
||||
"//tensorflow/cc/saved_model:bundle_v2",
|
||||
"//tensorflow/cc/saved_model:constants",
|
||||
"//tensorflow/cc/saved_model:loader_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -216,6 +233,7 @@ tf_cc_test(
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:core",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
@ -259,6 +277,20 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_spec",
|
||||
srcs = [
|
||||
"tensor_spec.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"tensor_spec.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tf_concrete_function_loading_test",
|
||||
srcs = [
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
@ -300,80 +301,70 @@ nodes {
|
||||
|
||||
TEST(ObjectGraphTraversalTest, Success) {
|
||||
SavedObjectGraph object_graph = ParseSavedObjectGraph(kSingleChildFoo);
|
||||
const SavedObject* obj = internal::FindNodeAtPath("foo", object_graph);
|
||||
ASSERT_NE(nullptr, obj);
|
||||
EXPECT_EQ(obj->kind_case(), SavedObject::kUserObject);
|
||||
EXPECT_EQ(obj->user_object().identifier(), "_generic_user_object");
|
||||
absl::optional<int> node = internal::FindNodeAtPath("foo", object_graph);
|
||||
ASSERT_TRUE(node.has_value());
|
||||
EXPECT_EQ(*node, 1);
|
||||
}
|
||||
|
||||
TEST(ObjectGraphTraversalTest, ObjectNotFound) {
|
||||
SavedObjectGraph object_graph = ParseSavedObjectGraph(kSingleChildFoo);
|
||||
const SavedObject* obj = internal::FindNodeAtPath("bar", object_graph);
|
||||
EXPECT_EQ(nullptr, obj);
|
||||
absl::optional<int> node = internal::FindNodeAtPath("bar", object_graph);
|
||||
EXPECT_FALSE(node.has_value());
|
||||
}
|
||||
|
||||
TEST(ObjectGraphTraversalTest, CaseSensitiveMismatch) {
|
||||
SavedObjectGraph object_graph = ParseSavedObjectGraph(kSingleChildFoo);
|
||||
const SavedObject* obj = internal::FindNodeAtPath("FOO", object_graph);
|
||||
EXPECT_EQ(nullptr, obj);
|
||||
absl::optional<int> node = internal::FindNodeAtPath("FOO", object_graph);
|
||||
EXPECT_FALSE(node.has_value());
|
||||
}
|
||||
|
||||
TEST(ObjectGraphTraversalTest, NestedObjectFound) {
|
||||
SavedObjectGraph object_graph =
|
||||
ParseSavedObjectGraph(kSingleChildFooWithFuncBar);
|
||||
const SavedObject* obj = internal::FindNodeAtPath("foo.bar", object_graph);
|
||||
ASSERT_NE(nullptr, obj);
|
||||
EXPECT_EQ(obj->kind_case(), SavedObject::kFunction);
|
||||
EXPECT_EQ(obj->function().concrete_functions_size(), 1);
|
||||
EXPECT_EQ(obj->function().concrete_functions(0), "__inference_my_func_5");
|
||||
absl::optional<int> node = internal::FindNodeAtPath("foo.bar", object_graph);
|
||||
ASSERT_TRUE(node.has_value());
|
||||
EXPECT_EQ(*node, 2);
|
||||
}
|
||||
|
||||
TEST(ObjectGraphTraversalTest, MultiplePathsAliasSameObject) {
|
||||
SavedObjectGraph object_graph = ParseSavedObjectGraph(kMultiplePathsToChild);
|
||||
const SavedObject* foo_baz =
|
||||
absl::optional<int> foo_baz_node =
|
||||
internal::FindNodeAtPath("foo.baz", object_graph);
|
||||
ASSERT_NE(nullptr, foo_baz);
|
||||
EXPECT_EQ(foo_baz->kind_case(), SavedObject::kUserObject);
|
||||
EXPECT_EQ(foo_baz->user_object().identifier(), "_generic_user_object");
|
||||
ASSERT_TRUE(foo_baz_node.has_value());
|
||||
EXPECT_EQ(*foo_baz_node, 4);
|
||||
|
||||
const SavedObject* bar_wombat =
|
||||
absl::optional<int> bar_wombat_node =
|
||||
internal::FindNodeAtPath("bar.wombat", object_graph);
|
||||
ASSERT_NE(nullptr, bar_wombat);
|
||||
EXPECT_EQ(bar_wombat->kind_case(), SavedObject::kUserObject);
|
||||
EXPECT_EQ(bar_wombat->user_object().identifier(), "_generic_user_object");
|
||||
ASSERT_TRUE(bar_wombat_node.has_value());
|
||||
EXPECT_EQ(*bar_wombat_node, 4);
|
||||
|
||||
EXPECT_EQ(foo_baz, bar_wombat);
|
||||
EXPECT_EQ(*foo_baz_node, *bar_wombat_node);
|
||||
}
|
||||
|
||||
TEST(ObjectGraphTraversalTest, CyclesAreOK) {
|
||||
SavedObjectGraph object_graph =
|
||||
ParseSavedObjectGraph(kCycleBetweenParentAndChild);
|
||||
const SavedObject* foo = internal::FindNodeAtPath("foo", object_graph);
|
||||
ASSERT_NE(nullptr, foo);
|
||||
EXPECT_EQ(foo->kind_case(), SavedObject::kUserObject);
|
||||
EXPECT_EQ(foo->user_object().identifier(), "_generic_user_object");
|
||||
absl::optional<int> foo = internal::FindNodeAtPath("foo", object_graph);
|
||||
ASSERT_TRUE(foo.has_value());
|
||||
EXPECT_EQ(*foo, 1);
|
||||
|
||||
const SavedObject* foo_bar =
|
||||
absl::optional<int> foo_bar =
|
||||
internal::FindNodeAtPath("foo.bar", object_graph);
|
||||
ASSERT_NE(nullptr, foo_bar);
|
||||
EXPECT_EQ(foo_bar->kind_case(), SavedObject::kUserObject);
|
||||
EXPECT_EQ(foo_bar->user_object().identifier(), "_generic_user_object");
|
||||
ASSERT_TRUE(foo_bar.has_value());
|
||||
EXPECT_EQ(*foo_bar, 3);
|
||||
|
||||
const SavedObject* foo_bar_parent =
|
||||
absl::optional<int> foo_bar_parent =
|
||||
internal::FindNodeAtPath("foo.bar.parent", object_graph);
|
||||
ASSERT_NE(nullptr, foo_bar_parent);
|
||||
EXPECT_EQ(foo_bar_parent->kind_case(), SavedObject::kUserObject);
|
||||
EXPECT_EQ(foo_bar_parent->user_object().identifier(), "_generic_user_object");
|
||||
ASSERT_TRUE(foo_bar_parent.has_value());
|
||||
EXPECT_EQ(*foo_bar_parent, 1);
|
||||
|
||||
const SavedObject* foo_bar_parent_bar =
|
||||
absl::optional<int> foo_bar_parent_bar =
|
||||
internal::FindNodeAtPath("foo.bar.parent.bar", object_graph);
|
||||
ASSERT_NE(nullptr, foo_bar_parent_bar);
|
||||
EXPECT_EQ(foo_bar_parent_bar->kind_case(), SavedObject::kUserObject);
|
||||
EXPECT_EQ(foo_bar_parent_bar->user_object().identifier(),
|
||||
"_generic_user_object");
|
||||
ASSERT_TRUE(foo_bar_parent_bar.has_value());
|
||||
EXPECT_EQ(*foo_bar_parent_bar, 3);
|
||||
|
||||
EXPECT_EQ(foo, foo_bar_parent);
|
||||
EXPECT_EQ(foo_bar, foo_bar_parent_bar);
|
||||
EXPECT_EQ(*foo, *foo_bar_parent);
|
||||
EXPECT_EQ(*foo_bar, *foo_bar_parent_bar);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -69,6 +69,86 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "partially_revived_objects",
|
||||
srcs = [
|
||||
"partially_revived_objects.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"partially_revived_objects.h",
|
||||
],
|
||||
deps = [
|
||||
":asset",
|
||||
":constant",
|
||||
":restored_resource",
|
||||
":restored_resource_revival_state",
|
||||
":revived_objects",
|
||||
":tf_concrete_function",
|
||||
":tf_concrete_function_revival_state",
|
||||
":tf_signature_def_function",
|
||||
":tf_signature_def_function_revival_state",
|
||||
":variable",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||
"//tensorflow/c/experimental/saved_model/core:tensor_spec",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "restored_resource",
|
||||
srcs = [
|
||||
"restored_resource.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"restored_resource.h",
|
||||
],
|
||||
deps = [
|
||||
":tensorhandle_convertible",
|
||||
":tf_concrete_function",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "restored_resource_revival_state",
|
||||
hdrs = [
|
||||
"restored_resource_revival_state.h",
|
||||
],
|
||||
deps = [
|
||||
":tf_concrete_function_revival_state",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "revived_objects",
|
||||
hdrs = [
|
||||
"revived_objects.h",
|
||||
],
|
||||
deps = [
|
||||
":asset",
|
||||
":constant",
|
||||
":restored_resource",
|
||||
":tf_concrete_function",
|
||||
":tf_signature_def_function",
|
||||
":variable",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "variable",
|
||||
srcs = [
|
||||
@ -86,6 +166,8 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
@ -123,6 +205,21 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_concrete_function_revival_state",
|
||||
hdrs = [
|
||||
"tf_concrete_function_revival_state.h",
|
||||
],
|
||||
deps = [
|
||||
":tf_concrete_function",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_signature_def_function",
|
||||
srcs = [
|
||||
@ -145,3 +242,17 @@ cc_library(
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_signature_def_function_revival_state",
|
||||
hdrs = [
|
||||
"tf_signature_def_function_revival_state.h",
|
||||
],
|
||||
deps = [
|
||||
":tf_signature_def_function",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
@ -33,8 +33,7 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
FlatTensorFunction::FlatTensorFunction(
|
||||
const std::string& name,
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||
const std::string& name, std::vector<ImmediateTensorHandlePtr> captures,
|
||||
ImmediateExecutionContext* ctx)
|
||||
: name_(name), captures_(std::move(captures)), ctx_(ctx) {}
|
||||
|
||||
@ -51,8 +50,15 @@ Status FlatTensorFunction::Create(
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||
ImmediateExecutionContext* ctx, std::unique_ptr<FlatTensorFunction>* out) {
|
||||
TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def));
|
||||
std::vector<ImmediateTensorHandlePtr> owned_captures;
|
||||
owned_captures.reserve(captures.size());
|
||||
for (ImmediateExecutionTensorHandle* capture : captures) {
|
||||
capture->Ref();
|
||||
owned_captures.push_back(ImmediateTensorHandlePtr(capture));
|
||||
}
|
||||
|
||||
out->reset(new FlatTensorFunction(function_def->signature().name(),
|
||||
std::move(captures), ctx));
|
||||
std::move(owned_captures), ctx));
|
||||
return Status();
|
||||
}
|
||||
|
||||
|
@ -50,7 +50,9 @@ class FlatTensorFunction {
|
||||
// destruction. function_def must be non-null, but
|
||||
// otherwise has no lifetime requirements.
|
||||
// captures - The captured TensorHandles associated with this
|
||||
// FlatTensorFunction.
|
||||
// FlatTensorFunction. FlatTensorFunction will participate in
|
||||
// ownership of the handles (it explicitly increments the refcount
|
||||
// of each handle, and will decrement them on destruction).
|
||||
// ctx - A handle to the Tensorflow runtime. This MUST be non-null and
|
||||
// outlive TFConcreteFunction.
|
||||
// out - The output FlatTensorFunction.
|
||||
@ -67,7 +69,7 @@ class FlatTensorFunction {
|
||||
|
||||
private:
|
||||
FlatTensorFunction(const std::string& name,
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||
std::vector<ImmediateTensorHandlePtr> captures,
|
||||
ImmediateExecutionContext* ctx);
|
||||
|
||||
FlatTensorFunction(const FlatTensorFunction&) = delete;
|
||||
@ -75,7 +77,7 @@ class FlatTensorFunction {
|
||||
|
||||
// Name of the FunctionDef corresponding to this TFConcreteFunction
|
||||
std::string name_;
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures_;
|
||||
std::vector<ImmediateTensorHandlePtr> captures_;
|
||||
ImmediateExecutionContext* ctx_;
|
||||
};
|
||||
|
||||
|
@ -0,0 +1,543 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
using StructuredValueDictEntry =
|
||||
protobuf::MapPair<std::string, StructuredValue>;
|
||||
|
||||
using NamedParamMap =
|
||||
gtl::FlatMap<StringPiece, const TensorSpecProto*, StringPieceHasher>;
|
||||
|
||||
Status AssertAllCreateResourceFunctionsHaveNoCaptures(
|
||||
const PartiallyRevivedObjects& objects) {
|
||||
for (const auto& id_and_resource : objects.restored_resources) {
|
||||
int node_id = id_and_resource.first;
|
||||
const RestoredResourceRevivalState& resource = id_and_resource.second;
|
||||
const TFConcreteFunctionRevivalState* create_resource_fn =
|
||||
resource.create_resource;
|
||||
if (create_resource_fn == nullptr) {
|
||||
return errors::FailedPrecondition(
|
||||
"Resource at node ", node_id,
|
||||
" did not have a create_resource() function");
|
||||
}
|
||||
const SavedConcreteFunction* saved_create_resource_fn =
|
||||
create_resource_fn->saved_concrete_func;
|
||||
if (!saved_create_resource_fn->bound_inputs().empty()) {
|
||||
// TODO(b/124045874): Support loading resource functions via a top sort
|
||||
return errors::Unimplemented(
|
||||
"Create Resource functions with captures are currently unsupported.");
|
||||
}
|
||||
}
|
||||
return Status();
|
||||
}
|
||||
|
||||
// Retrieves the TensorHandle associated with `node_id` from `obj_graph`, and
|
||||
// set `*handle` to point to it.
|
||||
Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph,
|
||||
const PartiallyRevivedObjects& objects,
|
||||
ImmediateExecutionTensorHandle** handle) {
|
||||
const SavedObject& node = obj_graph.nodes(node_id);
|
||||
SavedObject::KindCase kind = node.kind_case();
|
||||
switch (kind) {
|
||||
case SavedObject::kVariable: {
|
||||
const auto& variables_iter = objects.variables.find(node_id);
|
||||
if (variables_iter == objects.variables.end()) {
|
||||
return errors::FailedPrecondition(
|
||||
"Tried to convert node id ", node_id,
|
||||
" of type variable to tensor but the variable wasn't initialized");
|
||||
}
|
||||
*handle = variables_iter->second->handle();
|
||||
return Status();
|
||||
}
|
||||
case SavedObject::kConstant: {
|
||||
const auto& constants_iter = objects.constants.find(node_id);
|
||||
if (constants_iter == objects.constants.end()) {
|
||||
return errors::FailedPrecondition("Tried to convert node id ", node_id,
|
||||
" of type constant to tensor but the "
|
||||
"constant wasn't initialized");
|
||||
}
|
||||
*handle = constants_iter->second->handle();
|
||||
return Status();
|
||||
}
|
||||
case SavedObject::kAsset: {
|
||||
const auto& assets_iter = objects.assets.find(node_id);
|
||||
if (assets_iter == objects.assets.end()) {
|
||||
return errors::FailedPrecondition(
|
||||
"Tried to convert node id ", node_id,
|
||||
" of type asset to tensor but the asset wasn't initialized");
|
||||
}
|
||||
*handle = assets_iter->second->handle();
|
||||
return Status();
|
||||
}
|
||||
case SavedObject::kResource: {
|
||||
const auto& resource_iter = objects.restored_resources.find(node_id);
|
||||
if (resource_iter == objects.restored_resources.end()) {
|
||||
return errors::FailedPrecondition(
|
||||
"Tried to convert node id ", node_id,
|
||||
" of type Resource to tensor but the Resource wasn't initialized");
|
||||
}
|
||||
const RestoredResourceRevivalState& resource = resource_iter->second;
|
||||
if (resource.resource_handle == nullptr) {
|
||||
return errors::FailedPrecondition(
|
||||
"Resource with node id ", node_id,
|
||||
" should have its resource_handle created, but was nullptr.");
|
||||
}
|
||||
*handle = resource.resource_handle.get();
|
||||
return Status();
|
||||
}
|
||||
default: {
|
||||
return errors::FailedPrecondition(
|
||||
"Only objects of type variable, constant, asset, and resources have "
|
||||
"capturable tensorhandles. Encountered object of kind ",
|
||||
node.kind_case(), " at node id: ", node_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<SignatureDefParam> SignatureDefParamsFromNamedParamMap(
|
||||
const NamedParamMap& params) {
|
||||
// The underlying functiondef associated with the SignatureDef has
|
||||
// nest.flattened inputs and outputs, which are sorted by string key.
|
||||
std::vector<SignatureDefParam> result;
|
||||
result.reserve(params.size());
|
||||
for (const auto& named_param : params) {
|
||||
result.push_back(SignatureDefParam(std::string(named_param.first),
|
||||
TensorSpec(*named_param.second)));
|
||||
}
|
||||
std::sort(result.begin(), result.end(),
|
||||
[](const SignatureDefParam& x, const SignatureDefParam& y) {
|
||||
return x.name() < y.name();
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// SignatureDefArgsFromInputs takes the "canonicalized_input_signature"
|
||||
// field of a SavedConcreteFunction, ensures it conforms to the structure of
|
||||
// tuple(tuple(), dict<string,TensorSpec>()), and "returns" a list of
|
||||
// SignatureDefParams of the SignatureDefFunction's arguments.
|
||||
Status SignatureDefArgsFromInputs(
|
||||
const StructuredValue& canonicalized_input_signature,
|
||||
std::vector<SignatureDefParam>* out) {
|
||||
// Note(bmzhao): canonicalized_input_signature should be a tuple of
|
||||
// (args, kwargs), where args is an empty tuple, and kwargs is a dictionary of
|
||||
// string keys to TensorSpecs.
|
||||
if (!canonicalized_input_signature.has_tuple_value()) {
|
||||
return errors::FailedPrecondition(
|
||||
"SignatureDefFunction's canonicalized_input_signature should be "
|
||||
"of form tuple(tuple(), dict()), but was instead: \n",
|
||||
canonicalized_input_signature.DebugString());
|
||||
}
|
||||
|
||||
const TupleValue& args_kwargs_tuple =
|
||||
canonicalized_input_signature.tuple_value();
|
||||
if (args_kwargs_tuple.values_size() != 2) {
|
||||
return errors::FailedPrecondition(
|
||||
"SignatureDefFunction's canonicalized_input_signature should be "
|
||||
"a tuple of two elements (args, kwargs), but was instead: \n",
|
||||
args_kwargs_tuple.DebugString());
|
||||
}
|
||||
|
||||
const StructuredValue& args = args_kwargs_tuple.values(0);
|
||||
if (!args.has_tuple_value() || !args.tuple_value().values().empty()) {
|
||||
return errors::FailedPrecondition(
|
||||
"SignatureDefFunction's canonicalized_input_signature's args"
|
||||
"should be an empty tuple, but instead got: \n",
|
||||
args.DebugString());
|
||||
}
|
||||
|
||||
const StructuredValue& kwargs = args_kwargs_tuple.values(1);
|
||||
if (!kwargs.has_dict_value()) {
|
||||
return errors::FailedPrecondition(
|
||||
"SignatureDefFunction's canonicalized_input_signature's kwargs"
|
||||
"should be a dictionary, but instead got: \n",
|
||||
kwargs.DebugString());
|
||||
}
|
||||
|
||||
const DictValue& kwargs_dict = kwargs.dict_value();
|
||||
NamedParamMap result;
|
||||
result.reserve(kwargs_dict.fields_size());
|
||||
|
||||
for (const auto& key_value : kwargs_dict.fields()) {
|
||||
const std::string& key = key_value.first;
|
||||
const StructuredValue& value = key_value.second;
|
||||
if (!value.has_tensor_spec_value()) {
|
||||
return errors::FailedPrecondition(
|
||||
"SignatureDefFunction's canonicalized_input_signature's kwargs"
|
||||
"dictionary contained a non-tensorspec value for key-value pair: \n",
|
||||
"Key: ", key, "Value: \n", value.DebugString());
|
||||
}
|
||||
result[key] = &value.tensor_spec_value();
|
||||
}
|
||||
|
||||
*out = SignatureDefParamsFromNamedParamMap(result);
|
||||
|
||||
return Status();
|
||||
}
|
||||
|
||||
// SignatureDefReturnsFromOutputs takes the "output_signature" field of a
|
||||
// SavedConcreteFunction, ensures it conforms to the structure of
|
||||
// dict<string,TensorSpec>(), and "returns" a list of SignatureDefParams of the
|
||||
// SignatureDefFunction's returns.
|
||||
Status SignatureDefReturnsFromOutputs(const StructuredValue& output_signature,
|
||||
std::vector<SignatureDefParam>* out) {
|
||||
if (!output_signature.has_dict_value()) {
|
||||
return errors::FailedPrecondition(
|
||||
"SignatureDefFunction's output_signature must be a dictionary, but "
|
||||
"instead got: ",
|
||||
output_signature.DebugString());
|
||||
}
|
||||
|
||||
const DictValue& output_dict = output_signature.dict_value();
|
||||
NamedParamMap result;
|
||||
result.reserve(output_dict.fields_size());
|
||||
|
||||
for (const auto& key_value : output_dict.fields()) {
|
||||
const std::string& key = key_value.first;
|
||||
const StructuredValue& value = key_value.second;
|
||||
if (!value.has_tensor_spec_value()) {
|
||||
return errors::FailedPrecondition(
|
||||
"SignatureDefFunction's output_signature dictionary contained a "
|
||||
"non-tensorspec value for key-value pair: \n",
|
||||
"Key: ", key, "Value: \n", value.DebugString());
|
||||
}
|
||||
result[key] = &value.tensor_spec_value();
|
||||
}
|
||||
*out = SignatureDefParamsFromNamedParamMap(result);
|
||||
|
||||
return Status();
|
||||
}
|
||||
|
||||
// The implementation takes advantage of the fact that SignatureDefFunction's
|
||||
// "traced" Signature wrapper function always has inputs/outputs of dictionaries
|
||||
// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L119-L126
|
||||
// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L153-L178
|
||||
// Additionally, we take advantage of the fact that the SignatureDefFunction's
|
||||
// associated functiondef has lexicographically ordered inputs/outputs due to
|
||||
// nest.flatten.
|
||||
Status LoadSignatureDefFunctionMetadata(
|
||||
const SavedConcreteFunction& saved_concrete_function,
|
||||
SignatureDefFunctionMetadata* out) {
|
||||
std::vector<SignatureDefParam> args;
|
||||
TF_RETURN_IF_ERROR(SignatureDefArgsFromInputs(
|
||||
saved_concrete_function.canonicalized_input_signature(), &args));
|
||||
|
||||
std::vector<SignatureDefParam> rets;
|
||||
TF_RETURN_IF_ERROR(SignatureDefReturnsFromOutputs(
|
||||
saved_concrete_function.output_signature(), &rets));
|
||||
|
||||
*out = SignatureDefFunctionMetadata(std::move(args), std::move(rets));
|
||||
return Status();
|
||||
}
|
||||
|
||||
// This function finds the necessary captures, then forwards to the builder
|
||||
// method
|
||||
Status CreateConcreteFunction(ImmediateExecutionContext* ctx,
|
||||
const TFConcreteFunctionRevivalState& builder,
|
||||
const SavedObjectGraph& obj_graph,
|
||||
const PartiallyRevivedObjects& objects,
|
||||
std::unique_ptr<TFConcreteFunction>* out) {
|
||||
const auto& capture_node_ids = builder.saved_concrete_func->bound_inputs();
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures;
|
||||
captures.reserve(capture_node_ids.size());
|
||||
for (int capture_node_id : capture_node_ids) {
|
||||
ImmediateExecutionTensorHandle* capture_handle;
|
||||
TF_RETURN_IF_ERROR(TensorHandleFromNode(capture_node_id, obj_graph, objects,
|
||||
&capture_handle));
|
||||
captures.push_back(capture_handle);
|
||||
}
|
||||
// TODO(bmzhao): Create Metadata here
|
||||
return TFConcreteFunction::Create(/*function_def=*/builder.fdef,
|
||||
/*captures=*/std::move(captures),
|
||||
/*metadata=*/{},
|
||||
/*ctx=*/ctx,
|
||||
/*out=*/out);
|
||||
}
|
||||
|
||||
Status CreateSignatureDefFunction(
|
||||
ImmediateExecutionContext* ctx,
|
||||
const TFSignatureDefFunctionRevivalState& builder,
|
||||
const SavedObjectGraph& obj_graph, const PartiallyRevivedObjects& objects,
|
||||
std::unique_ptr<TFSignatureDefFunction>* out) {
|
||||
const auto& capture_node_ids = builder.saved_concrete_func->bound_inputs();
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures;
|
||||
captures.reserve(capture_node_ids.size());
|
||||
for (int capture_node_id : capture_node_ids) {
|
||||
ImmediateExecutionTensorHandle* capture_handle;
|
||||
TF_RETURN_IF_ERROR(TensorHandleFromNode(capture_node_id, obj_graph, objects,
|
||||
&capture_handle));
|
||||
captures.push_back(capture_handle);
|
||||
}
|
||||
|
||||
SignatureDefFunctionMetadata metadata;
|
||||
TF_RETURN_IF_ERROR(LoadSignatureDefFunctionMetadata(
|
||||
*builder.saved_concrete_func, &metadata));
|
||||
|
||||
return TFSignatureDefFunction::Create(/*function_def=*/builder.fdef,
|
||||
/*captures=*/std::move(captures),
|
||||
/*metadata=*/std::move(metadata),
|
||||
/*ctx=*/ctx,
|
||||
/*out=*/out);
|
||||
}
|
||||
|
||||
Status InitializeCreateResourceFunctions(ImmediateExecutionContext* ctx,
|
||||
const SavedObjectGraph& obj_graph,
|
||||
const PartiallyRevivedObjects& objects,
|
||||
RevivedObjects* revived) {
|
||||
for (const auto& id_and_resource : objects.restored_resources) {
|
||||
const RestoredResourceRevivalState& resource = id_and_resource.second;
|
||||
const TFConcreteFunctionRevivalState* create_resource_fn =
|
||||
resource.create_resource;
|
||||
|
||||
const SavedConcreteFunction* saved_create_resource_fn =
|
||||
create_resource_fn->saved_concrete_func;
|
||||
if (!saved_create_resource_fn->bound_inputs().empty()) {
|
||||
// TODO(b/124045874): Load resource functions via a topological sort
|
||||
return errors::Unimplemented(
|
||||
"Create Resource functions with captures are currently unsupported.");
|
||||
}
|
||||
std::unique_ptr<TFConcreteFunction> out;
|
||||
TF_RETURN_IF_ERROR(CreateConcreteFunction(ctx, *create_resource_fn,
|
||||
obj_graph, objects, &out));
|
||||
revived->concrete_functions[create_resource_fn->node_id] = std::move(out);
|
||||
}
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status InitializeAllFunctions(ImmediateExecutionContext* ctx,
|
||||
const SavedObjectGraph& obj_graph,
|
||||
const PartiallyRevivedObjects& objects,
|
||||
RevivedObjects* revived) {
|
||||
gtl::FlatMap<int, std::unique_ptr<TFConcreteFunction>>* destination_func_map =
|
||||
&revived->concrete_functions;
|
||||
gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>>*
|
||||
destination_sig_map = &revived->signature_def_functions;
|
||||
|
||||
for (const auto& id_and_func : objects.concrete_functions) {
|
||||
int node_id = id_and_func.first;
|
||||
const TFConcreteFunctionRevivalState& func = id_and_func.second;
|
||||
|
||||
if (destination_func_map->find(node_id) != destination_func_map->end()) {
|
||||
// The function has already been initialized in the destination_map,
|
||||
// so we can skip this node. This can occur because we initialize
|
||||
// CreateResource functions before calling this function.
|
||||
continue;
|
||||
}
|
||||
|
||||
std::unique_ptr<TFConcreteFunction> out;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateConcreteFunction(ctx, func, obj_graph, objects, &out));
|
||||
(*destination_func_map)[node_id] = std::move(out);
|
||||
}
|
||||
|
||||
for (const auto& id_and_func : objects.signature_def_functions) {
|
||||
int node_id = id_and_func.first;
|
||||
const TFSignatureDefFunctionRevivalState& func = id_and_func.second;
|
||||
|
||||
if (destination_sig_map->find(node_id) != destination_sig_map->end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::unique_ptr<TFSignatureDefFunction> out;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateSignatureDefFunction(ctx, func, obj_graph, objects, &out));
|
||||
(*destination_sig_map)[node_id] = std::move(out);
|
||||
}
|
||||
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status CreateAllResourceHandles(ImmediateExecutionContext* ctx,
|
||||
const SavedObjectGraph& obj_graph,
|
||||
PartiallyRevivedObjects* objects,
|
||||
RevivedObjects* revived) {
|
||||
for (auto& id_and_resource : objects->restored_resources) {
|
||||
RestoredResourceRevivalState& resource = id_and_resource.second;
|
||||
int create_resource_fn_node = resource.create_resource->node_id;
|
||||
const gtl::FlatMap<int, std::unique_ptr<TFConcreteFunction>>&
|
||||
revived_functions = revived->concrete_functions;
|
||||
|
||||
const auto& revived_functions_iter =
|
||||
revived_functions.find(create_resource_fn_node);
|
||||
if (revived_functions_iter == revived_functions.end()) {
|
||||
return errors::FailedPrecondition(
|
||||
"ConcreteFunction at node ", create_resource_fn_node,
|
||||
" should have been initialized prior to being called.");
|
||||
}
|
||||
const TFConcreteFunction& create_resource_fn =
|
||||
*revived_functions_iter->second;
|
||||
ImmediateOpPtr function_op;
|
||||
TF_RETURN_IF_ERROR(create_resource_fn.MakeCallOp({}, &function_op));
|
||||
TF_RETURN_IF_ERROR(function_op->SetDeviceName(resource.device.c_str()));
|
||||
|
||||
AbstractTensorHandle* resource_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(function_op->Execute(
|
||||
absl::MakeSpan(&resource_handle, num_retvals), &num_retvals));
|
||||
AbstractTensorHandlePtr owned_resource_handle(resource_handle);
|
||||
if (!tensorflow::isa<ImmediateExecutionTensorHandle>(
|
||||
owned_resource_handle.get())) {
|
||||
return errors::Internal("Unexpected tensor handle kind.");
|
||||
}
|
||||
ImmediateTensorHandlePtr result(
|
||||
reinterpret_cast<ImmediateExecutionTensorHandle*>(
|
||||
owned_resource_handle.release()));
|
||||
resource.resource_handle = std::move(result);
|
||||
}
|
||||
return Status();
|
||||
}
|
||||
|
||||
// Finds a ConcreteFunction with node id `node` in `objects`, and sets *out to
|
||||
// point to it. If node doesn't exist in `objects`, out is untouched, and an
|
||||
// error status is returned.
|
||||
Status FindConcreteFunction(int node, RevivedObjects* objects,
|
||||
TFConcreteFunction** out) {
|
||||
auto func_iter = objects->concrete_functions.find(node);
|
||||
if (func_iter == objects->concrete_functions.end()) {
|
||||
return errors::FailedPrecondition(
|
||||
"Failed to find ConcreteFunction with node id ", node,
|
||||
" in revived objects");
|
||||
}
|
||||
*out = func_iter->second.get();
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status BuildResources(ImmediateExecutionContext* ctx,
|
||||
const SavedObjectGraph& obj_graph,
|
||||
PartiallyRevivedObjects* objects,
|
||||
RevivedObjects* revived) {
|
||||
for (auto& id_and_resource : objects->restored_resources) {
|
||||
int node_id = id_and_resource.first;
|
||||
RestoredResourceRevivalState& resource_revival_state =
|
||||
id_and_resource.second;
|
||||
|
||||
TFConcreteFunction* create_resource = nullptr;
|
||||
|
||||
// Check all the functions associated with the resource have already been
|
||||
// initialized in `revived`
|
||||
if (resource_revival_state.create_resource != nullptr) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
FindConcreteFunction(resource_revival_state.create_resource->node_id,
|
||||
revived, &create_resource));
|
||||
}
|
||||
|
||||
TFConcreteFunction* initialize = nullptr;
|
||||
if (resource_revival_state.initialize != nullptr) {
|
||||
TF_RETURN_IF_ERROR(FindConcreteFunction(
|
||||
resource_revival_state.initialize->node_id, revived, &initialize));
|
||||
}
|
||||
|
||||
TFConcreteFunction* destroy_resource = nullptr;
|
||||
if (resource_revival_state.destroy_resource != nullptr) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
FindConcreteFunction(resource_revival_state.destroy_resource->node_id,
|
||||
revived, &destroy_resource));
|
||||
}
|
||||
|
||||
if (resource_revival_state.resource_handle == nullptr) {
|
||||
return errors::FailedPrecondition("Resource at node id ", node_id,
|
||||
" does not have a resource handle.");
|
||||
}
|
||||
|
||||
revived->restored_resources.emplace(
|
||||
node_id, RestoredResource(
|
||||
/*device=*/resource_revival_state.device,
|
||||
/*create_resource=*/create_resource,
|
||||
/*initialize=*/initialize,
|
||||
/*destroy_resource=*/destroy_resource,
|
||||
/*resource_handle=*/
|
||||
std::move(resource_revival_state.resource_handle)));
|
||||
}
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status PartiallyRevivedObjects::Build(ImmediateExecutionContext* ctx,
|
||||
const SavedObjectGraph& obj_graph,
|
||||
RevivedObjects* revived) {
|
||||
// Step 1: We would like to initialize all functions; this requires setting up
|
||||
// their captured tensorhandles, which may come from variables, assets,
|
||||
// constants, or resources. The first three are trivial; However,
|
||||
// tensorhandles that correspond to resources must be created by invoking
|
||||
// their "create_resource" function.
|
||||
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/saved_model/load.py#L240
|
||||
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/training/tracking/tracking.py#L233
|
||||
// For now, we assert that all create_resource functions must have no
|
||||
// captures. This aligns with the current behavior in python.
|
||||
// https://github.com/tensorflow/tensorflow/blob/50eac986bf7a0ad12594e080f083181f277e0b49/tensorflow/python/saved_model/load.py#L152-L155
|
||||
// TODO(bmzhao): We should do a topological sort instead.
|
||||
|
||||
// 1a. Make sure all CreateResource functions have no captures.
|
||||
TF_RETURN_IF_ERROR(AssertAllCreateResourceFunctionsHaveNoCaptures(*this));
|
||||
|
||||
// 1b. Initialize all CreateResource functions, storing them in `revived`
|
||||
TF_RETURN_IF_ERROR(
|
||||
InitializeCreateResourceFunctions(ctx, obj_graph, *this, revived));
|
||||
|
||||
// 1c. Invoke all "CreateResource" functions and store their ResourceHandles
|
||||
// https://github.com/tensorflow/tensorflow/blob/3b6b41b68a95dc70c26dc816b29d359bfb88c116/tensorflow/python/training/tracking/tracking.py#L241-L247
|
||||
// in *this->resources.
|
||||
// TODO(bmzhao): Maybe store them separately, not in *this?
|
||||
TF_RETURN_IF_ERROR(CreateAllResourceHandles(ctx, obj_graph, this, revived));
|
||||
|
||||
// 2. Initialize all the rest of the functions
|
||||
TF_RETURN_IF_ERROR(InitializeAllFunctions(ctx, obj_graph, *this, revived));
|
||||
|
||||
// 3a. Move over all non-function, non-resource objects
|
||||
revived->variables = std::move(variables);
|
||||
revived->assets = std::move(assets);
|
||||
revived->constants = std::move(constants);
|
||||
revived->signatures_map = std::move(signatures_map);
|
||||
|
||||
// 3b. Move over resources.
|
||||
TF_RETURN_IF_ERROR(BuildResources(ctx, obj_graph, this, revived));
|
||||
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,62 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Container for objects during the revival step in SavedModel's loading.
|
||||
// Notably, resources and functions can be in a state where they reference
|
||||
// other resources/functions that have not been constructed yet. We collect
|
||||
// *all* objects in a partially valid state here, then properly initialize
|
||||
// resources and functions. Implementation-wise, PartiallyRevivedObjects
|
||||
// contains maps keyed by the node number of the SavedObjectGraph, and map to an
|
||||
// object of the corresponding type. So, if node 2 in the object graph is a
|
||||
// variable, PartiallyRevivedObjects.variables[2] exists, and corresponds to a
|
||||
// tensorflow::Variable object. The only exception to this is the
|
||||
// "signatures_map", which is keyed by the "signature" key
|
||||
// (https://github.com/tensorflow/tensorflow/blob/372918decee7f558b3c194b04f77c20dcc679a31/tensorflow/core/protobuf/meta_graph.proto#L89),
|
||||
// and maps to the SignatureDefFunction node in the SavedObjectGraph.
|
||||
struct PartiallyRevivedObjects {
|
||||
gtl::FlatMap<int, std::unique_ptr<Variable>> variables;
|
||||
gtl::FlatMap<int, std::unique_ptr<Asset>> assets;
|
||||
gtl::FlatMap<int, std::unique_ptr<Constant>> constants;
|
||||
gtl::FlatMap<int, TFConcreteFunctionRevivalState> concrete_functions;
|
||||
gtl::FlatMap<int, TFSignatureDefFunctionRevivalState> signature_def_functions;
|
||||
gtl::FlatMap<int, RestoredResourceRevivalState> restored_resources;
|
||||
gtl::FlatMap<std::string, int> signatures_map;
|
||||
|
||||
Status Build(ImmediateExecutionContext* ctx,
|
||||
const SavedObjectGraph& obj_graph, RevivedObjects* revived);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_
|
@ -0,0 +1,76 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h"
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
Status ExecuteNoArgDummyReturnFunction(TFConcreteFunction* func) {
|
||||
ImmediateOpPtr function_op;
|
||||
TF_RETURN_IF_ERROR(func->MakeCallOp({}, &function_op));
|
||||
|
||||
AbstractTensorHandle* dummy_output = nullptr;
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(function_op->Execute(
|
||||
absl::MakeSpan(&dummy_output, num_retvals), &num_retvals));
|
||||
AbstractTensorHandlePtr owned_dummy_output(dummy_output);
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
RestoredResource::RestoredResource(const std::string& device,
|
||||
TFConcreteFunction* create_resource,
|
||||
TFConcreteFunction* initialize,
|
||||
TFConcreteFunction* destroy_resource,
|
||||
ImmediateTensorHandlePtr resource_handle)
|
||||
: TensorHandleConvertible(std::move(resource_handle)),
|
||||
device_(device),
|
||||
create_resource_(create_resource),
|
||||
initialize_(initialize),
|
||||
destroy_resource_(destroy_resource) {}
|
||||
|
||||
Status RestoredResource::Initialize() const {
|
||||
return ExecuteNoArgDummyReturnFunction(initialize_);
|
||||
}
|
||||
|
||||
RestoredResource::~RestoredResource() {
|
||||
// Note(bmzhao): SavedModels saved before
|
||||
// https://github.com/tensorflow/tensorflow/commit/3c806101f57768e479f8646e7518bbdff1632ca3
|
||||
// did not have their destroy_resource function saved, meaning they will
|
||||
// leak resources.
|
||||
if (destroy_resource_ != nullptr) {
|
||||
Status status = ExecuteNoArgDummyReturnFunction(destroy_resource_);
|
||||
if (!status.ok()) {
|
||||
LOG(WARNING)
|
||||
<< "Failed executing destroy_resource function for RestoredResource: "
|
||||
<< status.error_message();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,87 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// RestoredResource represents a TF2 "Resource" object loaded from a savedmodel,
|
||||
// analogous to the Python _RestoredResource object:
|
||||
// https://github.com/tensorflow/tensorflow/blob/fda326e542ca67534e8411edb180e8760a4828b7/tensorflow/python/saved_model/load.py#L481
|
||||
// TF2 resource objects typically extend TrackableResource:
|
||||
// https://github.com/tensorflow/tensorflow/blob/fda326e542ca67534e8411edb180e8760a4828b7/tensorflow/python/training/tracking/tracking.py#L285
|
||||
// and are expected to implement "_create_resource", "_initialize", and
|
||||
// "_destroy_resource" functions:
|
||||
// https://github.com/tensorflow/tensorflow/blob/139ba9c5284799beafdd1d7f895127cf00e7c48f/tensorflow/python/training/tracking/tracking.py#L262-L281
|
||||
class RestoredResource : TensorHandleConvertible {
|
||||
public:
|
||||
// Note(bmzhao): RestoredResource stores non-owning pointers to its associated
|
||||
// functions because SavedModel internally owns all functions and objects in
|
||||
// the RevivedObjects struct (which owns all functions). One alternative would
|
||||
// be to have RevivedObjects store shared_ptr<TFConcreteFunction> instead, and
|
||||
// change RestoredResource's constructor take shared_ptr<TFConcreteFunction>.
|
||||
// To keep things simple, I've stuck to raw pointers for now.
|
||||
//
|
||||
// Params:
|
||||
// device - The device string associated with the SavedResource
|
||||
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/saved_object_graph.proto#L182
|
||||
// Conceptually, this is the same device used in CapturableResource:
|
||||
// https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/python/training/tracking/tracking.py#L222-L225
|
||||
// Implementation-wise, it is device used when invoking the
|
||||
// create_resource function to produce the resource_handle
|
||||
// associated with the object:
|
||||
// https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/python/training/tracking/tracking.py#L246-L247
|
||||
// create_resource - Non owning pointer to the create_resource function
|
||||
// associated with this object. Must be NON-NULL.
|
||||
// initialize - Non owning pointer to the initialize function associated with
|
||||
// this object. Must be NON-NULL.
|
||||
// destroy_resource - Non owning pointer to the destroy_resource function
|
||||
// associated with this object. Ideally this should be
|
||||
// NON-NULL, but in order to support models saved prior to
|
||||
// https://github.com/tensorflow/tensorflow/commit/3c806101f57768e479f8646e7518bbdff1632ca3
|
||||
// we allow null here. This will, however, leak resources.
|
||||
RestoredResource(const std::string& device,
|
||||
TFConcreteFunction* create_resource,
|
||||
TFConcreteFunction* initialize,
|
||||
TFConcreteFunction* destroy_resource,
|
||||
ImmediateTensorHandlePtr resource_handle);
|
||||
|
||||
Status Initialize() const;
|
||||
|
||||
// RestoredResource is movable, but not copyable.
|
||||
RestoredResource(RestoredResource&& other) = default;
|
||||
RestoredResource& operator=(RestoredResource&& other) = default;
|
||||
|
||||
~RestoredResource() override;
|
||||
|
||||
private:
|
||||
std::string device_;
|
||||
TFConcreteFunction* create_resource_;
|
||||
TFConcreteFunction* initialize_;
|
||||
TFConcreteFunction* destroy_resource_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_
|
@ -0,0 +1,38 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// All "Resources" should have these 3 saved functions:
|
||||
// https://github.com/tensorflow/tensorflow/blob/86dc281333d7d277ddc1882f2bca4b17e7ec40e5/tensorflow/python/training/tracking/tracking.py#L277-L281
|
||||
struct RestoredResourceRevivalState {
|
||||
std::string device;
|
||||
TFConcreteFunctionRevivalState* create_resource = nullptr;
|
||||
TFConcreteFunctionRevivalState* initialize = nullptr;
|
||||
TFConcreteFunctionRevivalState* destroy_resource = nullptr;
|
||||
ImmediateTensorHandlePtr resource_handle = nullptr;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_
|
@ -0,0 +1,52 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// RevivedObjects is mainly used as a container for all the "state" owned by
|
||||
// SavedModel. It stores all non-"user object" nodes from a SavedModel
|
||||
// (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L57-L62)
|
||||
// in a "fully constructed" state. It is effectively a strongly typed map, where
|
||||
// each member is a map from the node id in the SavedObjectGraph's nodes
|
||||
// (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L25-L29)
|
||||
// to the revived object of the corresponding type.
|
||||
struct RevivedObjects {
|
||||
gtl::FlatMap<int, std::unique_ptr<Variable>> variables;
|
||||
gtl::FlatMap<int, std::unique_ptr<Asset>> assets;
|
||||
gtl::FlatMap<int, std::unique_ptr<Constant>> constants;
|
||||
gtl::FlatMap<int, std::unique_ptr<TFConcreteFunction>> concrete_functions;
|
||||
gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>>
|
||||
signature_def_functions;
|
||||
gtl::FlatMap<int, RestoredResource> restored_resources;
|
||||
gtl::FlatMap<std::string, int> signatures_map;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_
|
@ -0,0 +1,61 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// TFConcreteFunctionRevivalState wraps the state needed for building a
|
||||
// TF_ConcreteFunction. This is mainly used in PartiallyRevivedObjects, which
|
||||
// wraps partially constructed Function and Resource objects.
|
||||
struct TFConcreteFunctionRevivalState {
|
||||
// Index of the node in the SavedObjectGraph it was loaded from.
|
||||
int node_id;
|
||||
|
||||
// Pointer to the original functiondef. fdef_ is guaranteed to be
|
||||
// non-null.
|
||||
const FunctionDef* fdef;
|
||||
|
||||
// TensorHandle captures for this funtion
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures;
|
||||
|
||||
// SavedConcreteFunction contains much of the metadata of the expected "types"
|
||||
// of the inputs and outputs of a function.
|
||||
// Note(bmzhao): saved_concrete_func_ is guaranteed to be non-null.
|
||||
const SavedConcreteFunction* saved_concrete_func;
|
||||
|
||||
// This field is only present on TF2 ConcreteFunctions, and is useful for
|
||||
// determining the original argument *names* of the function, (since the
|
||||
// "canonicalized_input_signature" may append extra uniquifying integers).
|
||||
// However, SavedBareConcreteFunctions do not have a FunctionSpec.
|
||||
// Note(bmzhao): if function_spec_.has_value(), *function_spec_ is guaranteed
|
||||
// to be non-null.
|
||||
absl::optional<const FunctionSpec*> function_spec;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_
|
@ -0,0 +1,55 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// FunctionBuilder wraps the state needed for building a SignatureDefFunction.
|
||||
// This is mainly used in PartiallyRevivedObjects, which wraps partially
|
||||
// constructed Function and Resource objects.
|
||||
struct TFSignatureDefFunctionRevivalState {
|
||||
// Index of the node in the SavedObjectGraph it was loaded from.
|
||||
int node_id = 0;
|
||||
|
||||
// Pointer to the original functiondef. fdef_ is guaranteed to be
|
||||
// non-null.
|
||||
const FunctionDef* fdef = nullptr;
|
||||
|
||||
// SavedConcreteFunction contains much of the metadata of the expected "types"
|
||||
// of the inputs and outputs of a function.
|
||||
// Note(bmzhao): saved_concrete_func_ is guaranteed to be non-null.
|
||||
const SavedConcreteFunction* saved_concrete_func = nullptr;
|
||||
|
||||
// The name of the SignatureDef key.
|
||||
std::string signature_key;
|
||||
|
||||
// TensorHandle captures for this funtion
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_
|
@ -20,8 +20,10 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
@ -62,15 +64,53 @@ Status Variable::ReadValue(ImmediateTensorHandlePtr* out) {
|
||||
return internal::ReadVariable(ctx_, handle_.get(), dtype_, out);
|
||||
}
|
||||
|
||||
Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx,
|
||||
DataType dtype, TensorShape shape,
|
||||
absl::optional<std::string> name,
|
||||
const char* raw_device_name,
|
||||
std::unique_ptr<Variable>* output) {
|
||||
Status Variable::CreateUninitialized(
|
||||
ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
|
||||
absl::optional<std::string> name, const char* raw_device_name,
|
||||
const std::vector<std::string>& component_devices,
|
||||
std::unique_ptr<Variable>* output) {
|
||||
ImmediateTensorHandlePtr handle;
|
||||
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
|
||||
ctx, dtype, shape, raw_device_name, &handle));
|
||||
|
||||
if (component_devices.empty()) {
|
||||
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
|
||||
ctx, dtype, shape, raw_device_name, &handle));
|
||||
output->reset(
|
||||
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
|
||||
return Status();
|
||||
}
|
||||
|
||||
if (!tensorflow::isa<EagerContext>(ctx)) {
|
||||
return errors::InvalidArgument(
|
||||
"Can only load distributed variables with EagerContext.");
|
||||
}
|
||||
|
||||
EagerContext* eager_ctx = reinterpret_cast<EagerContext*>(ctx);
|
||||
|
||||
std::vector<TensorHandle*> handles;
|
||||
for (const auto& device : component_devices) {
|
||||
ImmediateTensorHandlePtr handlePtr;
|
||||
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
|
||||
ctx, dtype, shape, device.empty() ? nullptr : device.c_str(),
|
||||
&handlePtr));
|
||||
if (!tensorflow::isa<TensorHandle>(handlePtr.get())) {
|
||||
return errors::Internal("Returned replica handle has unsupported type.");
|
||||
}
|
||||
handles.push_back(reinterpret_cast<TensorHandle*>(handlePtr.release()));
|
||||
}
|
||||
TensorHandle* packed_handle;
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
|
||||
std::move(handles), eager_ctx, &packed_handle));
|
||||
// The call to `CreatePackedHandle` incremented the handles' reference count,
|
||||
// which we must now decrement to make the packed handle the owner of those
|
||||
// handles. We can't loop through the `handles` vector because it was
|
||||
// `std::move`d in the call above.
|
||||
for (int i = 0; i != packed_handle->NumPackedHandles(); ++i) {
|
||||
TensorHandle* component;
|
||||
TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &component));
|
||||
component->Unref();
|
||||
}
|
||||
|
||||
handle.reset(packed_handle);
|
||||
output->reset(
|
||||
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
|
||||
return Status();
|
||||
|
@ -34,11 +34,11 @@ class Variable : public TensorHandleConvertible {
|
||||
public:
|
||||
// Creates an uninitialized resource variable. Note that a caller must
|
||||
// call "assign" to associate a value with the variable.
|
||||
static Status CreateUninitialized(ImmediateExecutionContext* ctx,
|
||||
DataType dtype, TensorShape shape,
|
||||
absl::optional<std::string> name,
|
||||
const char* raw_device_name,
|
||||
std::unique_ptr<Variable>* output);
|
||||
static Status CreateUninitialized(
|
||||
ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
|
||||
absl::optional<std::string> name, const char* raw_device_name,
|
||||
const std::vector<std::string>& component_devices,
|
||||
std::unique_ptr<Variable>* output);
|
||||
|
||||
// The dtype of the underlying variable.
|
||||
DataType dtype();
|
||||
|
@ -17,14 +17,22 @@ limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#include "tensorflow/cc/saved_model/loader_util.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
@ -40,6 +48,83 @@ namespace {
|
||||
using StructuredValueDictEntry =
|
||||
protobuf::MapPair<std::string, StructuredValue>;
|
||||
|
||||
// Maps from a Nodedef's name to its corresponding AttrValues, for a given
|
||||
// Graphdef
|
||||
using NodeAttrMap =
|
||||
gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher>;
|
||||
|
||||
// Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary
|
||||
using FunctionDefMap = gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*,
|
||||
StringPieceHasher>;
|
||||
|
||||
// Looks up a SavedConstant's associated tensorproto from the NodeAttrMap and
|
||||
// returns a tensorflow::Constant.
|
||||
Status ConstantFromSavedConstant(
|
||||
ImmediateExecutionContext* ctx,
|
||||
const tensorflow::SavedConstant& saved_constant,
|
||||
const NodeAttrMap& node_attr_map, std::unique_ptr<Constant>* output) {
|
||||
const std::string& const_op_name = saved_constant.operation();
|
||||
const auto& node_name_and_attrs = node_attr_map.find(const_op_name);
|
||||
if (node_name_and_attrs == node_attr_map.end()) {
|
||||
return errors::FailedPrecondition(
|
||||
"Unable to find Const operation with name'", const_op_name,
|
||||
"' in SavedModel graphdef");
|
||||
}
|
||||
const AttrValueMap* attrs = node_name_and_attrs->second;
|
||||
const auto& attr_name_and_value = attrs->find("value");
|
||||
if (attr_name_and_value == attrs->end()) {
|
||||
return errors::FailedPrecondition("Unable to find Const operation '",
|
||||
const_op_name, "'s value attribute");
|
||||
}
|
||||
const TensorProto& tensor_proto = attr_name_and_value->second.tensor();
|
||||
return internal::TensorProtoToConstant(ctx, tensor_proto, output);
|
||||
}
|
||||
|
||||
// Finds the "signatures" object in the object graph, and fills a mapping of
|
||||
// each signature's name to the corresponding function's node in the object
|
||||
// graph.
|
||||
Status GetSignaturesMap(const SavedObjectGraph& saved_objects,
|
||||
gtl::FlatMap<std::string, int>* signatures_map) {
|
||||
if (saved_objects.nodes().empty()) {
|
||||
return errors::FailedPrecondition("Saved Object Graph was empty.");
|
||||
}
|
||||
const SavedObject& root = saved_objects.nodes(0);
|
||||
const SavedObject* signatures = nullptr;
|
||||
for (const auto& child : root.children()) {
|
||||
if (child.local_name() == "signatures") {
|
||||
if (child.node_id() >= saved_objects.nodes().size()) {
|
||||
return errors::FailedPrecondition(
|
||||
"Signature object had child node id ", child.node_id(),
|
||||
" which exceeds the size of the set of nodes");
|
||||
}
|
||||
signatures = &saved_objects.nodes(child.node_id());
|
||||
}
|
||||
}
|
||||
|
||||
// Some basic sanity checks that this object is actually our "signatures" map
|
||||
if (signatures == nullptr) {
|
||||
// This is where the "signatures" attribute is always set:
|
||||
// https://github.com/tensorflow/tensorflow/blob/a2c542a0d83227568f9214a2af9a38ae3625976f/tensorflow/python/saved_model/save.py#L1106-L1109
|
||||
return errors::FailedPrecondition(
|
||||
"SavedObjectGraph's root object must have a child 'signatures' object");
|
||||
}
|
||||
if (signatures->kind_case() != SavedObject::kUserObject) {
|
||||
return errors::FailedPrecondition(
|
||||
"Signatures must be a SavedObject of type UserObject.");
|
||||
}
|
||||
if (signatures->user_object().identifier() != "signature_map") {
|
||||
// This is where the string comes from:
|
||||
// https://github.com/tensorflow/tensorflow/blob/c59af2913aaec235d883f50428efef1086f4c0e6/tensorflow/python/saved_model/signature_serialization.py#L220
|
||||
return errors::FailedPrecondition(
|
||||
"Signatures SavedObject must have identifier 'signature_map'.");
|
||||
}
|
||||
|
||||
for (const auto& child : signatures->children()) {
|
||||
(*signatures_map)[child.local_name()] = child.node_id();
|
||||
}
|
||||
return Status();
|
||||
}
|
||||
|
||||
// Perform some basic sanity checks on SavedConcreteFunction's input and
|
||||
// output signatures with respect to the corresponding FunctionDef's input
|
||||
// and output args.
|
||||
@ -98,6 +183,21 @@ Status ValidateSavedFunctionCompatibleWithFunctionDef(
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status ValidateSingleConcreteFunction(const SavedFunction& saved_function) {
|
||||
// We only allow loading functions that have an annotated input signature,
|
||||
// which means there is 1:1 correspondence between tf.function
|
||||
// <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. This is
|
||||
// the same restriction that MLIR has:
|
||||
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707
|
||||
if (saved_function.concrete_functions_size() != 1) {
|
||||
return errors::FailedPrecondition(
|
||||
"Only tf.functions annotated with an input signature are supported "
|
||||
"by SavedModelAPI. This means that there should only be a single "
|
||||
"ConcreteFunction per tf.function");
|
||||
}
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset,
|
||||
@ -135,10 +235,17 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx,
|
||||
const std::string& name = variable.name();
|
||||
tensorflow::TensorShape shape(variable.shape());
|
||||
tensorflow::DataType dtype = variable.dtype();
|
||||
std::vector<std::string> component_devices;
|
||||
|
||||
for (const auto& component :
|
||||
variable.experimental_distributed_variable_components()) {
|
||||
component_devices.push_back(component.device());
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(Variable::CreateUninitialized(
|
||||
ctx, dtype, shape, name,
|
||||
variable.device().empty() ? nullptr : variable.device().c_str(), output));
|
||||
variable.device().empty() ? nullptr : variable.device().c_str(),
|
||||
component_devices, output));
|
||||
return Status();
|
||||
}
|
||||
|
||||
@ -224,17 +331,17 @@ Status FlattenSignature(const StructuredValue& signature,
|
||||
}
|
||||
}
|
||||
|
||||
const SavedObject* FindNodeAtPath(StringPiece path,
|
||||
const SavedObjectGraph& object_graph,
|
||||
int* node_id) {
|
||||
absl::optional<int> FindNodeAtPath(StringPiece path,
|
||||
const SavedObjectGraph& object_graph) {
|
||||
const auto& nodes = object_graph.nodes();
|
||||
if (nodes.empty()) {
|
||||
return nullptr;
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
// Starting from the root, iterate through the saved object graph, matching
|
||||
// object names as we go.
|
||||
const SavedObject* current_node = &nodes.Get(0);
|
||||
int node_id = 0;
|
||||
const SavedObject* current_node = &nodes.Get(node_id);
|
||||
|
||||
for (absl::string_view object_name : absl::StrSplit(path, '.')) {
|
||||
auto child_node_iter = std::find_if(
|
||||
@ -244,32 +351,28 @@ const SavedObject* FindNodeAtPath(StringPiece path,
|
||||
return object_name == obj.local_name();
|
||||
});
|
||||
if (child_node_iter == current_node->children().end()) {
|
||||
return nullptr;
|
||||
return absl::nullopt;
|
||||
}
|
||||
if (node_id) {
|
||||
*node_id = child_node_iter->node_id();
|
||||
}
|
||||
current_node = &nodes.Get(child_node_iter->node_id());
|
||||
|
||||
node_id = child_node_iter->node_id();
|
||||
current_node = &nodes.Get(node_id);
|
||||
}
|
||||
|
||||
return current_node;
|
||||
return node_id;
|
||||
}
|
||||
|
||||
std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher>
|
||||
NodeToAttrMap(const tensorflow::GraphDef& graphdef) {
|
||||
std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher>
|
||||
result;
|
||||
gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher> NodeToAttrMap(
|
||||
const tensorflow::GraphDef& graphdef) {
|
||||
gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher> result;
|
||||
for (const tensorflow::NodeDef& node : graphdef.node()) {
|
||||
result[node.name()] = &node.attr();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::unordered_map<StringPiece, const tensorflow::FunctionDef*,
|
||||
StringPieceHasher>
|
||||
gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, StringPieceHasher>
|
||||
FunctionNameToFunctionDefMap(const FunctionDefLibrary& library) {
|
||||
std::unordered_map<StringPiece, const tensorflow::FunctionDef*,
|
||||
StringPieceHasher>
|
||||
gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, StringPieceHasher>
|
||||
result;
|
||||
for (const FunctionDef& function_def : library.function()) {
|
||||
result[function_def.signature().name()] = &function_def;
|
||||
@ -277,5 +380,156 @@ FunctionNameToFunctionDefMap(const FunctionDefLibrary& library) {
|
||||
return result;
|
||||
}
|
||||
|
||||
Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph,
|
||||
ImmediateExecutionContext* context,
|
||||
const std::string& directory,
|
||||
PartiallyRevivedObjects* objects) {
|
||||
// This is needed to restore "Constant" nodes by looking up their
|
||||
// "Value" attribute.
|
||||
NodeAttrMap node_attr_map = NodeToAttrMap(metagraph.graph_def());
|
||||
|
||||
// These are needed for creating "Assets", by looking up their filenames.
|
||||
std::vector<AssetFileDef> assets;
|
||||
TF_RETURN_IF_ERROR(GetAssetFileDefs(metagraph, &assets));
|
||||
|
||||
// Signatures are needed for determining whether a function is a
|
||||
// SignatureDefFunction or not.
|
||||
gtl::FlatMap<std::string, int> signatures_map;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetSignaturesMap(metagraph.object_graph_def(), &signatures_map));
|
||||
|
||||
gtl::FlatMap<int, std::string> reversed_signatures_map;
|
||||
reversed_signatures_map.reserve(signatures_map.size());
|
||||
for (const auto& signature_key_and_node : signatures_map) {
|
||||
reversed_signatures_map.emplace(signature_key_and_node.second,
|
||||
signature_key_and_node.first);
|
||||
}
|
||||
|
||||
// FunctionDefs are needed to help construct
|
||||
// TFConcreteFunction/SignatureDefFunctions
|
||||
const FunctionDefMap function_def_map =
|
||||
internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library());
|
||||
|
||||
// Iterate through all the saved objects, restoring objects (if we can) as we
|
||||
// go. For objects that dependencies on other objects (resources/functions),
|
||||
// we partially initialize "builders" that correspond to their currently known
|
||||
// state, and gradually fill them out in subsequent passes.
|
||||
for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) {
|
||||
const SavedObject& node = metagraph.object_graph_def().nodes(i);
|
||||
if (node.kind_case() == SavedObject::kVariable) {
|
||||
std::unique_ptr<Variable> variable;
|
||||
TF_RETURN_IF_ERROR(
|
||||
LoadSavedVariable(context, node.variable(), &variable));
|
||||
objects->variables[i] = std::move(variable);
|
||||
} else if (node.kind_case() == SavedObject::kConstant) {
|
||||
std::unique_ptr<Constant> constant;
|
||||
TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(),
|
||||
node_attr_map, &constant));
|
||||
objects->constants[i] = std::move(constant);
|
||||
} else if (node.kind_case() == SavedObject::kAsset) {
|
||||
std::unique_ptr<Asset> asset;
|
||||
TF_RETURN_IF_ERROR(
|
||||
LoadSavedAsset(context, node.asset(), directory, assets, &asset));
|
||||
objects->assets[i] = std::move(asset);
|
||||
} else if (node.kind_case() == SavedObject::kResource) {
|
||||
RestoredResourceRevivalState resource_revival_state;
|
||||
// We'll set the resource's functions in a subsequent pass, once we get
|
||||
// all functions in a partially revived state.
|
||||
resource_revival_state.device = node.resource().device();
|
||||
objects->restored_resources[i] = std::move(resource_revival_state);
|
||||
} else if (node.kind_case() == SavedObject::kFunction) {
|
||||
// Get the SavedFunction node and validate it has a single concrete func.
|
||||
const SavedFunction& saved_function = node.function();
|
||||
TF_RETURN_IF_ERROR(ValidateSingleConcreteFunction(saved_function));
|
||||
|
||||
// Retrieve related function information.
|
||||
const std::string& function_name = saved_function.concrete_functions(0);
|
||||
const FunctionDef* function_def = function_def_map.at(function_name);
|
||||
const SavedConcreteFunction& saved_concrete_func =
|
||||
metagraph.object_graph_def().concrete_functions().at(function_name);
|
||||
const FunctionSpec& function_spec = saved_function.function_spec();
|
||||
|
||||
// Construct either a SignatureDefFunctionBuilder or a
|
||||
// ConcreteFunctionBuilder, depending on whether this node was a child
|
||||
// of the "signatures" attribute from root object.
|
||||
auto reverse_signature_iter = reversed_signatures_map.find(i);
|
||||
if (reverse_signature_iter != reversed_signatures_map.end()) {
|
||||
TFSignatureDefFunctionRevivalState func_revival_state;
|
||||
func_revival_state.node_id = i;
|
||||
func_revival_state.fdef = function_def;
|
||||
func_revival_state.saved_concrete_func = &saved_concrete_func;
|
||||
func_revival_state.signature_key = reverse_signature_iter->second;
|
||||
objects->signature_def_functions[i] = std::move(func_revival_state);
|
||||
} else {
|
||||
TFConcreteFunctionRevivalState func_revival_state;
|
||||
func_revival_state.node_id = i;
|
||||
func_revival_state.fdef = function_def;
|
||||
func_revival_state.saved_concrete_func = &saved_concrete_func;
|
||||
func_revival_state.function_spec = &function_spec;
|
||||
objects->concrete_functions[i] = std::move(func_revival_state);
|
||||
}
|
||||
} else if (node.kind_case() == SavedObject::kBareConcreteFunction) {
|
||||
const SavedBareConcreteFunction& bare_cf = node.bare_concrete_function();
|
||||
|
||||
// Retrieve related function information.
|
||||
const std::string& function_name = bare_cf.concrete_function_name();
|
||||
const FunctionDef* function_def = function_def_map.at(function_name);
|
||||
const SavedConcreteFunction& saved_concrete_func =
|
||||
metagraph.object_graph_def().concrete_functions().at(function_name);
|
||||
|
||||
// Check whether this is a SignatureDefFunction, or not.
|
||||
auto reverse_signature_iter = reversed_signatures_map.find(i);
|
||||
if (reverse_signature_iter != reversed_signatures_map.end()) {
|
||||
TFSignatureDefFunctionRevivalState func_revival_state;
|
||||
func_revival_state.node_id = i;
|
||||
func_revival_state.fdef = function_def;
|
||||
func_revival_state.saved_concrete_func = &saved_concrete_func;
|
||||
func_revival_state.signature_key = reverse_signature_iter->second;
|
||||
objects->signature_def_functions[i] = std::move(func_revival_state);
|
||||
} else {
|
||||
TFConcreteFunctionRevivalState func_revival_state;
|
||||
func_revival_state.node_id = i;
|
||||
func_revival_state.fdef = function_def;
|
||||
func_revival_state.saved_concrete_func = &saved_concrete_func;
|
||||
objects->concrete_functions[i] = std::move(func_revival_state);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now that we've partially restored all functions, we can have resources
|
||||
// point to them
|
||||
for (auto& node_and_resource_revival_state : objects->restored_resources) {
|
||||
int node_id = node_and_resource_revival_state.first;
|
||||
const SavedObjectGraph& obj_graph = metagraph.object_graph_def();
|
||||
const SavedObject& node = obj_graph.nodes(node_id);
|
||||
RestoredResourceRevivalState& resource =
|
||||
node_and_resource_revival_state.second;
|
||||
for (const TrackableObjectGraph::TrackableObject::ObjectReference& child :
|
||||
node.children()) {
|
||||
int child_node_id = child.node_id();
|
||||
// Note(bmzhao): The expected functions saved by a resource object are:
|
||||
// "_create_resource", "_initialize", and "_destroy_resource".
|
||||
// https://github.com/tensorflow/tensorflow/blob/ad66f588c1666ade8051feb42811fa27b285271c/tensorflow/python/training/tracking/tracking.py#L277-L281
|
||||
if (child.local_name() == "_create_resource" &&
|
||||
obj_graph.nodes(child.node_id()).kind_case() ==
|
||||
SavedObject::kFunction) {
|
||||
resource.create_resource = &objects->concrete_functions[child_node_id];
|
||||
} else if (child.local_name() == "_initialize" &&
|
||||
obj_graph.nodes(child.node_id()).kind_case() ==
|
||||
SavedObject::kFunction) {
|
||||
resource.initialize = &objects->concrete_functions[child_node_id];
|
||||
} else if (child.local_name() == "_destroy_resource" &&
|
||||
obj_graph.nodes(child.node_id()).kind_case() ==
|
||||
SavedObject::kFunction) {
|
||||
resource.destroy_resource = &objects->concrete_functions[child_node_id];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
objects->signatures_map = std::move(signatures_map);
|
||||
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
@ -22,14 +22,17 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
@ -75,26 +78,30 @@ Status LoadTFConcreteFunction(
|
||||
Status FlattenSignature(const StructuredValue& signature,
|
||||
std::vector<const TensorSpecProto*>* flattened_specs);
|
||||
|
||||
// Find the SavedObject in `object_graph` at location `path`. `path` must be
|
||||
// Find the node id in `object_graph` at location `path`. `path` must be
|
||||
// a dot-delimited string of object names relative to the root object. If no
|
||||
// object is found, returns nullptr. Callers must ensure `object_graph`
|
||||
// outlives the returned pointer. If not `nullptr`, `node_id` will contain the
|
||||
// index of the returned object in the `SavedObjectGraph.nodes` array.
|
||||
const SavedObject* FindNodeAtPath(StringPiece path,
|
||||
const SavedObjectGraph& object_graph,
|
||||
int* node_id = nullptr);
|
||||
// object is found, returns absl::nullopt.
|
||||
absl::optional<int> FindNodeAtPath(StringPiece path,
|
||||
const SavedObjectGraph& object_graph);
|
||||
|
||||
// Maps each node in `graphdef` to its corresponding Attribute Map.
|
||||
// Callers must ensure that `graphdef` outlives the returned map.
|
||||
std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher>
|
||||
NodeToAttrMap(const tensorflow::GraphDef& graphdef);
|
||||
gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher> NodeToAttrMap(
|
||||
const tensorflow::GraphDef& graphdef);
|
||||
|
||||
// Maps the name of each FunctionDef in `library` to its corresponding
|
||||
// FunctionDef. Callers must ensure `library` outlives the returned map.
|
||||
std::unordered_map<StringPiece, const tensorflow::FunctionDef*,
|
||||
StringPieceHasher>
|
||||
gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, StringPieceHasher>
|
||||
FunctionNameToFunctionDefMap(const FunctionDefLibrary& library);
|
||||
|
||||
// Walks through the SavedObjectGraph in metagraph, and restores all nodes
|
||||
// (except "UserDefinedObjects") with their corresponding type in
|
||||
// "PartiallyRevivedObjects".
|
||||
Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph,
|
||||
ImmediateExecutionContext* context,
|
||||
const std::string& directory,
|
||||
PartiallyRevivedObjects* objects);
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -119,7 +119,7 @@ TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) {
|
||||
Status status;
|
||||
std::unique_ptr<Variable> var;
|
||||
TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape,
|
||||
absl::nullopt, nullptr, &var));
|
||||
absl::nullopt, nullptr, {}, &var));
|
||||
|
||||
// Create a TensorHandle
|
||||
ImmediateTensorHandlePtr expected_handle =
|
||||
|
@ -0,0 +1,42 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
SignatureDefParam::SignatureDefParam(std::string name, TensorSpec spec)
|
||||
: name_(std::move(name)), spec_(std::move(spec)) {}
|
||||
|
||||
const std::string& SignatureDefParam::name() const { return name_; }
|
||||
|
||||
const TensorSpec& SignatureDefParam::spec() const { return spec_; }
|
||||
|
||||
SignatureDefFunctionMetadata::SignatureDefFunctionMetadata(
|
||||
std::vector<SignatureDefParam> arguments,
|
||||
std::vector<SignatureDefParam> returns)
|
||||
: arguments_(std::move(arguments)), returns_(std::move(returns)) {}
|
||||
|
||||
const std::vector<SignatureDefParam>& SignatureDefFunctionMetadata::arguments()
|
||||
const {
|
||||
return arguments_;
|
||||
}
|
||||
|
||||
const std::vector<SignatureDefParam>& SignatureDefFunctionMetadata::returns()
|
||||
const {
|
||||
return returns_;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -16,10 +16,42 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// SignatureDefParam represents a named Tensor input or output to a
|
||||
// SignatureDefFunction.
|
||||
class SignatureDefParam {
|
||||
public:
|
||||
SignatureDefParam(std::string name, TensorSpec spec);
|
||||
|
||||
const std::string& name() const;
|
||||
|
||||
const TensorSpec& spec() const;
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
TensorSpec spec_;
|
||||
};
|
||||
|
||||
class SignatureDefFunctionMetadata {
|
||||
// TODO(bmzhao): Fill in with fields as necessary
|
||||
public:
|
||||
SignatureDefFunctionMetadata() = default;
|
||||
SignatureDefFunctionMetadata(std::vector<SignatureDefParam> arguments,
|
||||
std::vector<SignatureDefParam> returns);
|
||||
|
||||
const std::vector<SignatureDefParam>& arguments() const;
|
||||
const std::vector<SignatureDefParam>& returns() const;
|
||||
|
||||
private:
|
||||
std::vector<SignatureDefParam> arguments_;
|
||||
std::vector<SignatureDefParam> returns_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
38
tensorflow/c/experimental/saved_model/core/tensor_spec.cc
Normal file
38
tensorflow/c/experimental/saved_model/core/tensor_spec.cc
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
|
||||
|
||||
#include <initializer_list>
|
||||
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
TensorSpec::TensorSpec()
|
||||
: shape_(std::initializer_list<int64>()), dtype_(DT_FLOAT) {}
|
||||
|
||||
TensorSpec::TensorSpec(PartialTensorShape shape, DataType dtype)
|
||||
: shape_(std::move(shape)), dtype_(dtype) {}
|
||||
|
||||
TensorSpec::TensorSpec(const TensorSpecProto& proto)
|
||||
: shape_(proto.shape()), dtype_(proto.dtype()) {}
|
||||
|
||||
const PartialTensorShape& TensorSpec::shape() const { return shape_; }
|
||||
|
||||
DataType TensorSpec::dtype() const { return dtype_; }
|
||||
|
||||
} // namespace tensorflow
|
51
tensorflow/c/experimental/saved_model/core/tensor_spec.h
Normal file
51
tensorflow/c/experimental/saved_model/core/tensor_spec.h
Normal file
@ -0,0 +1,51 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_
|
||||
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Note(bmzhao): TensorSpec deliberately does not store the "name" from a
|
||||
// TensorSpecProto. From edloper@, "Names should really be associated with
|
||||
// parameters, not the tensors inside those parameters. This would be
|
||||
// inconsistent with the corresponding Python class, but I don't think that's
|
||||
// necessarily a problem. If it turns out later that we really need a name
|
||||
// attribute here, we can always add it back in; but let's see how far we can
|
||||
// get without it."
|
||||
class TensorSpec {
|
||||
public:
|
||||
// Constructs a scalar, DT_FLOAT TensorSpec
|
||||
TensorSpec();
|
||||
|
||||
TensorSpec(PartialTensorShape shape, DataType dtype);
|
||||
|
||||
explicit TensorSpec(const TensorSpecProto& proto);
|
||||
|
||||
const PartialTensorShape& shape() const;
|
||||
DataType dtype() const;
|
||||
|
||||
private:
|
||||
PartialTensorShape shape_;
|
||||
DataType dtype_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_
|
@ -48,7 +48,6 @@ EagerContextPtr CreateTestingEagerContext(DeviceMgr* device_mgr) {
|
||||
/* async= */ false,
|
||||
/* lazy_copy_function_remote_inputs= */ false, device_mgr,
|
||||
/* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
|
||||
/* custom_kernel_creator= */ nullptr,
|
||||
/* cluster_flr= */ nullptr));
|
||||
}
|
||||
|
||||
|
@ -17,7 +17,6 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
@ -30,6 +29,9 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||
@ -37,7 +39,6 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
|
||||
#include "tensorflow/cc/saved_model/bundle_v2.h"
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
#include "tensorflow/cc/saved_model/loader_util.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
@ -46,6 +47,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
@ -62,142 +64,15 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
// Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary
|
||||
using FunctionDefMap =
|
||||
std::unordered_map<StringPiece, const tensorflow::FunctionDef*,
|
||||
StringPieceHasher>;
|
||||
|
||||
// Maps from a Nodedef's name to its corresponding AttrValues, for a given
|
||||
// Graphdef
|
||||
using NodeAttrMap =
|
||||
std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher>;
|
||||
|
||||
// Maps from Node ID to an "Revived Object" implementing
|
||||
// "TensorHandleConvertible"
|
||||
using RevivedObjectMap =
|
||||
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>;
|
||||
using FunctionDefMap = gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*,
|
||||
StringPieceHasher>;
|
||||
|
||||
// Maps from a functiondef's name to the corresponding "TFConcreteFunction"
|
||||
using ConcreteFunctionMap =
|
||||
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>;
|
||||
using FlatTensorFunctionMap =
|
||||
gtl::FlatMap<std::string, std::unique_ptr<FlatTensorFunction>>;
|
||||
|
||||
namespace {
|
||||
|
||||
Status ConstantFromSavedConstant(
|
||||
ImmediateExecutionContext* ctx,
|
||||
const tensorflow::SavedConstant& saved_constant,
|
||||
const NodeAttrMap& node_attr_map, std::unique_ptr<Constant>* output) {
|
||||
const std::string& const_op_name = saved_constant.operation();
|
||||
const auto& node_name_and_attrs = node_attr_map.find(const_op_name);
|
||||
if (node_name_and_attrs == node_attr_map.end()) {
|
||||
return errors::FailedPrecondition(
|
||||
"Unable to find Const operation with name'", const_op_name,
|
||||
"' in SavedModel graphdef");
|
||||
}
|
||||
const AttrValueMap* attrs = node_name_and_attrs->second;
|
||||
const auto& attr_name_and_value = attrs->find("value");
|
||||
if (attr_name_and_value == attrs->end()) {
|
||||
return errors::FailedPrecondition("Unable to find Const operation '",
|
||||
const_op_name, "'s value attribute");
|
||||
}
|
||||
const TensorProto& tensor_proto = attr_name_and_value->second.tensor();
|
||||
return internal::TensorProtoToConstant(ctx, tensor_proto, output);
|
||||
}
|
||||
|
||||
// Restores all non-function objects in the SavedModel's object graph.
|
||||
// This function walks through the metagraph's saved object graph, and
|
||||
// constructs revived versions of SavedVariable, SavedConstant, SavedAsset, and
|
||||
// SavedResources. These are returned via the `out` parameter.
|
||||
Status ReviveObjects(
|
||||
const MetaGraphDef& metagraph, ImmediateExecutionContext* context,
|
||||
const std::string& directory,
|
||||
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>*
|
||||
revived_objects) {
|
||||
// This is needed to restore "Constant" nodes by looking up their
|
||||
// "Value" attribute.
|
||||
NodeAttrMap node_attr_map = internal::NodeToAttrMap(metagraph.graph_def());
|
||||
|
||||
// These are needed for creating "Assets", by looking up their filenames.
|
||||
std::vector<AssetFileDef> assets;
|
||||
TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(metagraph, &assets));
|
||||
|
||||
// Iterate through all the saved objects, restoring objects as we go.
|
||||
// We don't recreate functions until all other objects have been created.
|
||||
for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) {
|
||||
const SavedObject& node = metagraph.object_graph_def().nodes(i);
|
||||
if (node.kind_case() == SavedObject::kVariable) {
|
||||
std::unique_ptr<Variable> variable;
|
||||
TF_RETURN_IF_ERROR(
|
||||
internal::LoadSavedVariable(context, node.variable(), &variable));
|
||||
(*revived_objects)[i] = std::move(variable);
|
||||
} else if (node.kind_case() == SavedObject::kConstant) {
|
||||
std::unique_ptr<Constant> constant;
|
||||
TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(),
|
||||
node_attr_map, &constant));
|
||||
(*revived_objects)[i] = std::move(constant);
|
||||
} else if (node.kind_case() == SavedObject::kAsset) {
|
||||
std::unique_ptr<Asset> asset;
|
||||
TF_RETURN_IF_ERROR(internal::LoadSavedAsset(context, node.asset(),
|
||||
directory, assets, &asset));
|
||||
(*revived_objects)[i] = std::move(asset);
|
||||
} else if (node.kind_case() == SavedObject::kResource) {
|
||||
// TODO(bmzhao): Figure out how resource loading works and implement it
|
||||
return errors::Unimplemented(
|
||||
"SavedResource loading is not implemented yet");
|
||||
}
|
||||
}
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status ReviveFunctions(const MetaGraphDef& metagraph,
|
||||
const RevivedObjectMap& revived_objects,
|
||||
ImmediateExecutionContext* context,
|
||||
ConcreteFunctionMap* restored_functions) {
|
||||
const FunctionDefMap function_def_map =
|
||||
internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library());
|
||||
|
||||
// Iterate through all objects, only examining functions.
|
||||
for (const SavedObject& node : metagraph.object_graph_def().nodes()) {
|
||||
if (node.kind_case() == SavedObject::kBareConcreteFunction) {
|
||||
const std::string& function_name =
|
||||
node.bare_concrete_function().concrete_function_name();
|
||||
|
||||
const SavedConcreteFunction& saved_concrete_function =
|
||||
metagraph.object_graph_def().concrete_functions().at(function_name);
|
||||
|
||||
const FunctionDef* function_def = function_def_map.at(function_name);
|
||||
std::unique_ptr<TFConcreteFunction> concrete_function;
|
||||
TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction(
|
||||
saved_concrete_function, function_def, revived_objects, context,
|
||||
&concrete_function));
|
||||
(*restored_functions)[function_name] = std::move(concrete_function);
|
||||
} else if (node.kind_case() == SavedObject::kFunction) {
|
||||
// We only allow loading functions that have an annotated input signature,
|
||||
// which means there is 1:1 correspondence between tf.function
|
||||
// <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. This is
|
||||
// the same restriction that MLIR has:
|
||||
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707
|
||||
const SavedFunction& saved_function = node.function();
|
||||
if (saved_function.concrete_functions_size() != 1) {
|
||||
return errors::FailedPrecondition(
|
||||
"Only tf.functions annotated with an input signature are supported "
|
||||
"by SavedModelAPI. This means that there should only be a single "
|
||||
"ConcreteFunction per tf.function");
|
||||
}
|
||||
const std::string& function_name = saved_function.concrete_functions(0);
|
||||
const SavedConcreteFunction& saved_concrete_function =
|
||||
metagraph.object_graph_def().concrete_functions().at(function_name);
|
||||
|
||||
const FunctionDef* function_def = function_def_map.at(function_name);
|
||||
|
||||
std::unique_ptr<TFConcreteFunction> concrete_function;
|
||||
TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction(
|
||||
saved_concrete_function, function_def, revived_objects, context,
|
||||
&concrete_function));
|
||||
(*restored_functions)[function_name] = std::move(concrete_function);
|
||||
}
|
||||
}
|
||||
return Status();
|
||||
}
|
||||
|
||||
const TrackableObjectGraph::TrackableObject::SerializedTensor*
|
||||
FindSerializedTensorInTrackable(
|
||||
@ -234,7 +109,7 @@ FindSerializedTensorInTrackable(
|
||||
// overridden "restore" method:
|
||||
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L85
|
||||
Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
|
||||
const RevivedObjectMap& revived_objects,
|
||||
const RevivedObjects& revived_objects,
|
||||
const std::string& directory,
|
||||
ImmediateExecutionContext* context) {
|
||||
// TODO(bmzhao): Batch up all the restores into a single restore op per
|
||||
@ -254,8 +129,7 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Variable* variable =
|
||||
down_cast<Variable*>(revived_objects.at(node).get());
|
||||
Variable* variable = revived_objects.variables.at(node).get();
|
||||
|
||||
// Restore the tensor's value from the checkpoint
|
||||
const TrackableObjectGraph::TrackableObject::SerializedTensor*
|
||||
@ -289,43 +163,58 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status InitializeAllResources(const RevivedObjects& revived) {
|
||||
for (const auto& node_and_resource : revived.restored_resources) {
|
||||
const RestoredResource& resource = node_and_resource.second;
|
||||
TF_RETURN_IF_ERROR(resource.Initialize());
|
||||
}
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status TFSavedModelAPI::GetFunction(const std::string& function_path,
|
||||
ConcreteFunction** function) {
|
||||
const SavedObject* object =
|
||||
absl::optional<int> node =
|
||||
internal::FindNodeAtPath(function_path, bundle_.saved_object_graph());
|
||||
if (object == nullptr) {
|
||||
if (!node.has_value()) {
|
||||
return errors::NotFound("No saved object found at path ", function_path);
|
||||
}
|
||||
|
||||
if (object->kind_case() == SavedObject::kBareConcreteFunction) {
|
||||
*function =
|
||||
concrete_functions_
|
||||
.at(object->bare_concrete_function().concrete_function_name())
|
||||
.get();
|
||||
} else if (object->kind_case() == SavedObject::kFunction) {
|
||||
*function =
|
||||
concrete_functions_.at(object->function().concrete_functions(0)).get();
|
||||
} else {
|
||||
return errors::InvalidArgument(function_path,
|
||||
" is not a path to a Function.");
|
||||
auto function_iter = revived_objects_.concrete_functions.find(*node);
|
||||
if (function_iter == revived_objects_.concrete_functions.end()) {
|
||||
return errors::NotFound("No function found at path ", function_path);
|
||||
}
|
||||
|
||||
*function = function_iter->second.get();
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status TFSavedModelAPI::GetSignatureDefFunction(
|
||||
const std::string& signature_def_key, SignatureDefFunction** function) {
|
||||
// TODO(bmzhao): Add support for retrieving a signaturedef function.
|
||||
return errors::Unimplemented(
|
||||
"Retrieving SignatureDef functions is unimplemented currently");
|
||||
auto signatures_iter =
|
||||
revived_objects_.signatures_map.find(signature_def_key);
|
||||
if (signatures_iter == revived_objects_.signatures_map.end()) {
|
||||
return errors::NotFound("No signature with key ", signature_def_key,
|
||||
" was found");
|
||||
}
|
||||
int node = signatures_iter->second;
|
||||
|
||||
auto function_iter = revived_objects_.signature_def_functions.find(node);
|
||||
if (function_iter == revived_objects_.signature_def_functions.end()) {
|
||||
return errors::Internal(
|
||||
"Unable to find SignatureDefFunction associated with key ",
|
||||
signature_def_key, " despite key being valid.");
|
||||
}
|
||||
|
||||
*function = function_iter->second.get();
|
||||
return Status();
|
||||
}
|
||||
|
||||
std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
|
||||
std::vector<ConcreteFunction*> result;
|
||||
result.reserve(concrete_functions_.size());
|
||||
for (auto& index_and_function : concrete_functions_) {
|
||||
result.reserve(revived_objects_.concrete_functions.size());
|
||||
for (auto& index_and_function : revived_objects_.concrete_functions) {
|
||||
result.push_back(index_and_function.second.get());
|
||||
}
|
||||
return result;
|
||||
@ -333,39 +222,27 @@ std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
|
||||
|
||||
Status TFSavedModelAPI::GetVariable(const std::string& variable_path,
|
||||
Variable** variable) {
|
||||
int node_id;
|
||||
const SavedObject* object = internal::FindNodeAtPath(
|
||||
variable_path, bundle_.saved_object_graph(), &node_id);
|
||||
if (object == nullptr) {
|
||||
absl::optional<int> node =
|
||||
internal::FindNodeAtPath(variable_path, bundle_.saved_object_graph());
|
||||
if (!node.has_value()) {
|
||||
return errors::NotFound("No saved object found at path ", variable_path);
|
||||
}
|
||||
|
||||
if (object->kind_case() == SavedObject::kVariable) {
|
||||
auto iter = revived_objects_.find(node_id);
|
||||
if (iter == revived_objects_.end()) {
|
||||
return errors::Internal("Variable ", variable_path,
|
||||
" was not properly revived.");
|
||||
}
|
||||
*variable = static_cast<Variable*>(iter->second.get());
|
||||
return Status();
|
||||
auto variables_iter = revived_objects_.variables.find(*node);
|
||||
if (variables_iter == revived_objects_.variables.end()) {
|
||||
return errors::NotFound("No variable found at path ", variable_path);
|
||||
}
|
||||
|
||||
*variable = nullptr;
|
||||
return errors::InvalidArgument(
|
||||
variable_path, " is not a path to a Variable (kind=", object->kind_case(),
|
||||
")");
|
||||
*variable = variables_iter->second.get();
|
||||
return Status();
|
||||
}
|
||||
|
||||
TFSavedModelAPI::TFSavedModelAPI(
|
||||
const std::string& directory, SavedModelV2Bundle bundle,
|
||||
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
|
||||
revived_objects,
|
||||
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
|
||||
concrete_functions)
|
||||
TFSavedModelAPI::TFSavedModelAPI(const std::string& directory,
|
||||
SavedModelV2Bundle bundle,
|
||||
RevivedObjects revived_objects)
|
||||
: directory_(directory),
|
||||
bundle_(std::move(bundle)),
|
||||
revived_objects_(std::move(revived_objects)),
|
||||
concrete_functions_(std::move(concrete_functions)) {}
|
||||
revived_objects_(std::move(revived_objects)) {}
|
||||
|
||||
Status TFSavedModelAPI::Load(
|
||||
const std::string& directory,
|
||||
@ -386,28 +263,25 @@ Status TFSavedModelAPI::Load(
|
||||
// This occurs in python here:
|
||||
// https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454
|
||||
|
||||
RevivedObjectMap revived_objects;
|
||||
TF_RETURN_IF_ERROR(ReviveObjects(bundle.meta_graph_def(), context, directory,
|
||||
&revived_objects));
|
||||
// Step 1: For each node in the graph, we should initialize an object of the
|
||||
// corresponding type. For objects that depend on the initialization of other
|
||||
// objects (like functions which capture resources), we will initialize them
|
||||
// in step 2.
|
||||
PartiallyRevivedObjects partially_revived_objects;
|
||||
TF_RETURN_IF_ERROR(internal::PartiallyReviveSavedModelObjects(
|
||||
bundle.meta_graph_def(), context, directory, &partially_revived_objects));
|
||||
|
||||
// TODO(bmzhao): When we later add support for loading resources, we need to
|
||||
// handle the case where materializing a function's captures requires invoking
|
||||
// other functions. This occurs when retrieving the resource handle for a
|
||||
// TrackableResource:
|
||||
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/saved_model/load.py#L240
|
||||
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/training/tracking/tracking.py#L233
|
||||
// This requires restoring functions in a topological sort order by capture
|
||||
// dependencies.
|
||||
ConcreteFunctionMap function_map;
|
||||
TF_RETURN_IF_ERROR(ReviveFunctions(bundle.meta_graph_def(), revived_objects,
|
||||
context, &function_map));
|
||||
RevivedObjects revived_objects;
|
||||
TF_RETURN_IF_ERROR(partially_revived_objects.Build(
|
||||
context, bundle.saved_object_graph(), &revived_objects));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
RestoreCheckpoint(&bundle, revived_objects, directory, context));
|
||||
|
||||
TF_RETURN_IF_ERROR(InitializeAllResources(revived_objects));
|
||||
|
||||
out->reset(new TFSavedModelAPI(directory, std::move(bundle),
|
||||
std::move(revived_objects),
|
||||
std::move(function_map)));
|
||||
std::move(revived_objects)));
|
||||
return Status();
|
||||
}
|
||||
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||
@ -72,19 +73,12 @@ class TFSavedModelAPI : public SavedModelAPI {
|
||||
Status GetVariable(const std::string& variable_path, Variable** variable);
|
||||
|
||||
private:
|
||||
TFSavedModelAPI(
|
||||
const std::string& directory, SavedModelV2Bundle bundle,
|
||||
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
|
||||
revived_objects,
|
||||
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
|
||||
concrete_functions);
|
||||
TFSavedModelAPI(const std::string& directory, SavedModelV2Bundle bundle,
|
||||
RevivedObjects revived_objects);
|
||||
|
||||
std::string directory_;
|
||||
SavedModelV2Bundle bundle_;
|
||||
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
|
||||
revived_objects_;
|
||||
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
|
||||
concrete_functions_;
|
||||
RevivedObjects revived_objects_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -224,6 +224,8 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":signature_def_function_metadata_type",
|
||||
":signature_def_param_list",
|
||||
":signature_def_param_list_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||
],
|
||||
@ -240,6 +242,104 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "signature_def_param",
|
||||
srcs = [
|
||||
"signature_def_param.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//tensorflow/c/experimental/saved_model/public:signature_def_param.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":signature_def_param_type",
|
||||
":tensor_spec",
|
||||
":tensor_spec_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c:tf_shape_internal",
|
||||
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "signature_def_param_type",
|
||||
hdrs = [
|
||||
"signature_def_param_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "signature_def_param_list",
|
||||
srcs = [
|
||||
"signature_def_param_list.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//tensorflow/c/experimental/saved_model/public:signature_def_param_list.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":signature_def_param",
|
||||
":signature_def_param_list_type",
|
||||
":signature_def_param_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "signature_def_param_list_type",
|
||||
hdrs = [
|
||||
"signature_def_param_list_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_spec",
|
||||
srcs = [
|
||||
"tensor_spec.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//tensorflow/c/experimental/saved_model/public:tensor_spec.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":tensor_spec_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/c:tf_shape",
|
||||
"//tensorflow/c:tf_shape_internal",
|
||||
"//tensorflow/c/experimental/saved_model/core:tensor_spec",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_spec_type",
|
||||
hdrs = [
|
||||
"tensor_spec_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c:tf_shape_internal",
|
||||
"//tensorflow/c/experimental/saved_model/core:tensor_spec",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "saved_model_api_test",
|
||||
size = "small",
|
||||
@ -252,6 +352,8 @@ tf_cc_test(
|
||||
],
|
||||
deps = [
|
||||
":saved_model_api_type",
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/c:tf_shape",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_tensor",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
@ -260,6 +362,11 @@ tf_cc_test(
|
||||
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_api",
|
||||
"//tensorflow/c/experimental/saved_model/public:concrete_function",
|
||||
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
|
||||
"//tensorflow/c/experimental/saved_model/public:signature_def_function",
|
||||
"//tensorflow/c/experimental/saved_model/public:signature_def_function_metadata",
|
||||
"//tensorflow/c/experimental/saved_model/public:signature_def_param",
|
||||
"//tensorflow/c/experimental/saved_model/public:signature_def_param_list",
|
||||
"//tensorflow/c/experimental/saved_model/public:tensor_spec",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
|
@ -24,9 +24,17 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_shape.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
@ -142,6 +150,146 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) {
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
// This tests running the "serving_default" SignatureDefFunction from the
|
||||
// VarsAndArithmeticObjectGraph savedmodel. Here's what the signature_defs
|
||||
// protobuf in the metagraph looks like:
|
||||
// signature_def: {
|
||||
// key : "serving_default"
|
||||
// value: {
|
||||
// inputs: {
|
||||
// key : "a"
|
||||
// value: {
|
||||
// name : "serving_default_a:0"
|
||||
// dtype: DT_FLOAT
|
||||
// tensor_shape: {
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// inputs: {
|
||||
// key : "b"
|
||||
// value: {
|
||||
// name : "serving_default_b:0"
|
||||
// dtype: DT_FLOAT
|
||||
// tensor_shape: {
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// outputs: {
|
||||
// key : "output_0"
|
||||
// value: {
|
||||
// name : "StatefulPartitionedCall:0"
|
||||
// dtype: DT_FLOAT
|
||||
// tensor_shape: {
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// method_name: "tensorflow/serving/predict"
|
||||
// }
|
||||
// }
|
||||
TEST_P(CSavedModelAPITest, RunsSignatureDefFunction) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TF_DeleteStatus(status);
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||
|
||||
TF_SavedModel* saved_model =
|
||||
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
|
||||
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TF_SignatureDefFunction* serving_default =
|
||||
TF_GetSavedModelSignatureDefFunction(saved_model, "serving_default",
|
||||
status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TF_SignatureDefFunctionMetadata* metadata =
|
||||
TF_SignatureDefFunctionGetMetadata(serving_default);
|
||||
|
||||
const TF_SignatureDefParamList* args =
|
||||
TF_SignatureDefFunctionMetadataArgs(metadata);
|
||||
const TF_SignatureDefParamList* returns =
|
||||
TF_SignatureDefFunctionMetadataReturns(metadata);
|
||||
|
||||
EXPECT_EQ(TF_SignatureDefParamListSize(args), 2);
|
||||
const TF_SignatureDefParam* param_a = TF_SignatureDefParamListGet(args, 0);
|
||||
const TF_TensorSpec* tensor_spec_a = TF_SignatureDefParamTensorSpec(param_a);
|
||||
const TF_Shape* shape_a = TF_TensorSpecShape(tensor_spec_a);
|
||||
|
||||
// Input "a" is a scalar, float32 tensor
|
||||
EXPECT_EQ("a", std::string(TF_SignatureDefParamName(param_a)));
|
||||
EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_a));
|
||||
EXPECT_EQ(0, TF_ShapeDims(shape_a));
|
||||
|
||||
const TF_SignatureDefParam* param_b = TF_SignatureDefParamListGet(args, 1);
|
||||
const TF_TensorSpec* tensor_spec_b = TF_SignatureDefParamTensorSpec(param_b);
|
||||
const TF_Shape* shape_b = TF_TensorSpecShape(tensor_spec_b);
|
||||
|
||||
// Input "b" is a scalar, float32 tensor
|
||||
EXPECT_EQ("b", std::string(TF_SignatureDefParamName(param_b)));
|
||||
EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_b));
|
||||
EXPECT_EQ(0, TF_ShapeDims(shape_b));
|
||||
|
||||
EXPECT_EQ(TF_SignatureDefParamListSize(returns), 1);
|
||||
|
||||
const TF_SignatureDefParam* param_out =
|
||||
TF_SignatureDefParamListGet(returns, 0);
|
||||
const TF_TensorSpec* tensor_spec_out =
|
||||
TF_SignatureDefParamTensorSpec(param_out);
|
||||
const TF_Shape* shape_out = TF_TensorSpecShape(tensor_spec_out);
|
||||
|
||||
// Output "output_0" is a scalar, float32 tensor
|
||||
EXPECT_EQ("output_0", std::string(TF_SignatureDefParamName(param_out)));
|
||||
EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_out));
|
||||
EXPECT_EQ(0, TF_ShapeDims(shape_out));
|
||||
|
||||
std::vector<TFE_TensorHandle*> compute_fn_inputs;
|
||||
TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f);
|
||||
TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f);
|
||||
compute_fn_inputs.push_back(input_a);
|
||||
compute_fn_inputs.push_back(input_b);
|
||||
|
||||
TFE_Op* serving_default_op = TF_SignatureDefFunctionMakeCallOp(
|
||||
serving_default, compute_fn_inputs.data(), compute_fn_inputs.size(),
|
||||
status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
std::vector<TFE_TensorHandle*> compute_fn_outputs(
|
||||
TF_SignatureDefParamListSize(returns));
|
||||
int num_retvals = TF_SignatureDefParamListSize(returns);
|
||||
|
||||
TFE_Execute(serving_default_op, compute_fn_outputs.data(), &num_retvals,
|
||||
status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
EXPECT_EQ(TF_NumDims(result), 0);
|
||||
float output_value = *static_cast<float*>(TF_TensorData(result));
|
||||
// (1 + 2) * (2 + 1) / 3 + 5 should be 8
|
||||
EXPECT_FLOAT_EQ(output_value, 8.0);
|
||||
|
||||
TF_DeleteTensor(result);
|
||||
TFE_DeleteTensorHandle(compute_fn_outputs[0]);
|
||||
TFE_DeleteTensorHandle(input_a);
|
||||
TFE_DeleteTensorHandle(input_b);
|
||||
TFE_DeleteOp(serving_default_op);
|
||||
TF_DeleteSavedModel(saved_model);
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -186,7 +334,8 @@ TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
|
||||
EXPECT_EQ(TF_NumDims(result), 0);
|
||||
tensorflow::tstring* output_value =
|
||||
static_cast<tensorflow::tstring*>(TF_TensorData(result));
|
||||
EXPECT_EQ(std::string(*output_value), "TEST ASSET FILE CONTENTS\n");
|
||||
std::string file_contents(*output_value);
|
||||
EXPECT_NE(file_contents.find("TEST ASSET FILE CONTENTS"), std::string::npos);
|
||||
|
||||
TF_DeleteTensor(result);
|
||||
TFE_DeleteTensorHandle(read_file_fn_outputs[0]);
|
||||
@ -196,6 +345,142 @@ TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST_P(CSavedModelAPITest, LoadsStaticHashtableSavedModel) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TF_DeleteStatus(status);
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
std::string model_dir = SavedModelPath("StaticHashTableModule");
|
||||
|
||||
TF_SavedModel* saved_model =
|
||||
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
|
||||
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TF_ConcreteFunction* lookup_fn =
|
||||
TF_GetSavedModelConcreteFunction(saved_model, "lookup", status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// Note(bmzhao): Based on static_hashtable_asset.txt, we expect the following
|
||||
// mapping:
|
||||
// "foo" -> 0
|
||||
// "bar" -> 1
|
||||
// "baz" -> 2
|
||||
// "wombat" -> 3
|
||||
// all other strings -> -1
|
||||
|
||||
// Call lookup function with input "foo", expecting an output of 0
|
||||
{
|
||||
std::vector<TFE_TensorHandle*> lookup_fn_inputs;
|
||||
TFE_TensorHandle* input_foo = TestScalarTensorHandle(ctx, tstring("foo"));
|
||||
lookup_fn_inputs.push_back(input_foo);
|
||||
|
||||
TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp(
|
||||
lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
|
||||
// inputs + outputs a function has.
|
||||
TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
|
||||
TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
EXPECT_EQ(TF_NumDims(result), 0);
|
||||
tensorflow::int64* output_value =
|
||||
static_cast<tensorflow::int64*>(TF_TensorData(result));
|
||||
EXPECT_EQ(*output_value, 0);
|
||||
|
||||
TF_DeleteTensor(result);
|
||||
TFE_DeleteTensorHandle(input_foo);
|
||||
TFE_DeleteTensorHandle(lookup_fn_outputs[0]);
|
||||
TFE_DeleteOp(lookup_op);
|
||||
}
|
||||
|
||||
// Call lookup function with input "baz", expecting an output of 2
|
||||
{
|
||||
std::vector<TFE_TensorHandle*> lookup_fn_inputs;
|
||||
TFE_TensorHandle* input_foo = TestScalarTensorHandle(ctx, tstring("baz"));
|
||||
lookup_fn_inputs.push_back(input_foo);
|
||||
|
||||
TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp(
|
||||
lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
|
||||
// inputs + outputs a function has.
|
||||
TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
|
||||
TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
EXPECT_EQ(TF_NumDims(result), 0);
|
||||
tensorflow::int64* output_value =
|
||||
static_cast<tensorflow::int64*>(TF_TensorData(result));
|
||||
EXPECT_EQ(*output_value, 2);
|
||||
|
||||
TF_DeleteTensor(result);
|
||||
TFE_DeleteTensorHandle(input_foo);
|
||||
TFE_DeleteTensorHandle(lookup_fn_outputs[0]);
|
||||
TFE_DeleteOp(lookup_op);
|
||||
}
|
||||
|
||||
// Call lookup function w/input "NON-EXISTENT-KEY", expecting an output of -1
|
||||
{
|
||||
std::vector<TFE_TensorHandle*> lookup_fn_inputs;
|
||||
TFE_TensorHandle* input_foo =
|
||||
TestScalarTensorHandle(ctx, tstring("NON-EXISTENT-KEY"));
|
||||
lookup_fn_inputs.push_back(input_foo);
|
||||
|
||||
TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp(
|
||||
lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
|
||||
// inputs + outputs a function has.
|
||||
TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
|
||||
TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
EXPECT_EQ(TF_NumDims(result), 0);
|
||||
tensorflow::int64* output_value =
|
||||
static_cast<tensorflow::int64*>(TF_TensorData(result));
|
||||
EXPECT_EQ(*output_value, -1);
|
||||
|
||||
TF_DeleteTensor(result);
|
||||
TFE_DeleteTensorHandle(input_foo);
|
||||
TFE_DeleteTensorHandle(lookup_fn_outputs[0]);
|
||||
TFE_DeleteOp(lookup_op);
|
||||
}
|
||||
|
||||
TF_DeleteSavedModel(saved_model);
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST_P(CSavedModelAPITest, LoadSavedModelWithUninitializedVariable) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
|
@ -16,5 +16,18 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h"
|
||||
|
||||
// TODO(bmzhao): Add getter functions here as necessary.
|
||||
extern "C" {
|
||||
|
||||
extern const TF_SignatureDefParamList* TF_SignatureDefFunctionMetadataArgs(
|
||||
const TF_SignatureDefFunctionMetadata* list) {
|
||||
return tensorflow::wrap(&tensorflow::unwrap(list)->arguments());
|
||||
}
|
||||
|
||||
extern const TF_SignatureDefParamList* TF_SignatureDefFunctionMetadataReturns(
|
||||
const TF_SignatureDefFunctionMetadata* list) {
|
||||
return tensorflow::wrap(&tensorflow::unwrap(list)->returns());
|
||||
}
|
||||
|
||||
} // end extern "C"
|
||||
|
@ -0,0 +1,33 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
extern const char* TF_SignatureDefParamName(const TF_SignatureDefParam* param) {
|
||||
return tensorflow::unwrap(param)->name().c_str();
|
||||
}
|
||||
|
||||
extern const TF_TensorSpec* TF_SignatureDefParamTensorSpec(
|
||||
const TF_SignatureDefParam* param) {
|
||||
return tensorflow::wrap(&tensorflow::unwrap(param)->spec());
|
||||
}
|
||||
|
||||
} // end extern "C"
|
@ -0,0 +1,33 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
extern size_t TF_SignatureDefParamListSize(
|
||||
const TF_SignatureDefParamList* list) {
|
||||
return tensorflow::unwrap(list)->size();
|
||||
}
|
||||
|
||||
extern const TF_SignatureDefParam* TF_SignatureDefParamListGet(
|
||||
const TF_SignatureDefParamList* list, int i) {
|
||||
return tensorflow::wrap(&tensorflow::unwrap(list)->at(i));
|
||||
}
|
||||
|
||||
} // end extern "C"
|
@ -0,0 +1,33 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
|
||||
|
||||
typedef struct TF_SignatureDefParamList TF_SignatureDefParamList;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(std::vector<SignatureDefParam>,
|
||||
TF_SignatureDefParamList)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_
|
@ -0,0 +1,30 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
|
||||
|
||||
typedef struct TF_SignatureDefParam TF_SignatureDefParam;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::SignatureDefParam, TF_SignatureDefParam)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_
|
@ -0,0 +1,32 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h"
|
||||
#include "tensorflow/c/tf_shape_internal.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
TF_DataType TF_TensorSpecDataType(const TF_TensorSpec* spec) {
|
||||
return static_cast<TF_DataType>(tensorflow::unwrap(spec)->dtype());
|
||||
}
|
||||
|
||||
const TF_Shape* TF_TensorSpecShape(const TF_TensorSpec* spec) {
|
||||
return tensorflow::wrap(&tensorflow::unwrap(spec)->shape());
|
||||
}
|
||||
|
||||
} // end extern "C"
|
@ -0,0 +1,30 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
|
||||
|
||||
typedef struct TF_TensorSpec TF_TensorSpec;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::TensorSpec, TF_TensorSpec)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_
|
@ -1,3 +1,4 @@
|
||||
load("//tensorflow:tensorflow.bzl", "filegroup")
|
||||
load("//tensorflow:tensorflow.bzl", "py_strict_binary")
|
||||
|
||||
package(
|
||||
|
@ -28,6 +28,9 @@ exports_files(
|
||||
"saved_model_api.h",
|
||||
"signature_def_function.h",
|
||||
"signature_def_function_metadata.h",
|
||||
"signature_def_param.h",
|
||||
"signature_def_param_list.h",
|
||||
"tensor_spec.h",
|
||||
],
|
||||
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
|
||||
)
|
||||
@ -45,6 +48,9 @@ cc_library(
|
||||
":saved_model_api",
|
||||
":signature_def_function",
|
||||
":signature_def_function_metadata",
|
||||
":signature_def_param",
|
||||
":signature_def_param_list",
|
||||
":tensor_spec",
|
||||
],
|
||||
)
|
||||
|
||||
@ -77,3 +83,18 @@ alias(
|
||||
name = "signature_def_function_metadata",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_function_metadata",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "signature_def_param",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_param",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "signature_def_param_list",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_param_list",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "tensor_spec",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:tensor_spec",
|
||||
)
|
||||
|
@ -23,6 +23,9 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
|
||||
|
@ -16,6 +16,9 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
@ -24,6 +27,18 @@ extern "C" {
|
||||
// SavedModel.
|
||||
typedef struct TF_SignatureDefFunctionMetadata TF_SignatureDefFunctionMetadata;
|
||||
|
||||
// Retrieves the arguments of the SignatureDefFunction. The caller is not
|
||||
// responsible for freeing the returned pointer.
|
||||
TF_CAPI_EXPORT extern const TF_SignatureDefParamList*
|
||||
TF_SignatureDefFunctionMetadataArgs(
|
||||
const TF_SignatureDefFunctionMetadata* list);
|
||||
|
||||
// Retrieves the returns of the SignatureDefFunction. The caller is not
|
||||
// responsible for freeing the returned pointer.
|
||||
TF_CAPI_EXPORT extern const TF_SignatureDefParamList*
|
||||
TF_SignatureDefFunctionMetadataReturns(
|
||||
const TF_SignatureDefFunctionMetadata* list);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
@ -0,0 +1,44 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type that containing metadata of an input/output of a
|
||||
// TF_SignatureDefFunction loaded from a SavedModel.
|
||||
typedef struct TF_SignatureDefParam TF_SignatureDefParam;
|
||||
|
||||
// Returns the name of the given parameter. The caller is not responsible for
|
||||
// freeing the returned char*.
|
||||
TF_CAPI_EXPORT extern const char* TF_SignatureDefParamName(
|
||||
const TF_SignatureDefParam* param);
|
||||
|
||||
// Returns the TensorSpec associated with the given parameter. The caller is
|
||||
// not reponsible for freeing the returned TF_TensorSpec*.
|
||||
TF_CAPI_EXPORT extern const TF_TensorSpec* TF_SignatureDefParamTensorSpec(
|
||||
const TF_SignatureDefParam* param);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_
|
@ -0,0 +1,44 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type that containing metadata of an input/output of a
|
||||
// ConcreteFunction loaded from a SavedModel.
|
||||
typedef struct TF_SignatureDefParamList TF_SignatureDefParamList;
|
||||
|
||||
// Returns the size of `list`.
|
||||
TF_CAPI_EXPORT extern size_t TF_SignatureDefParamListSize(
|
||||
const TF_SignatureDefParamList* list);
|
||||
|
||||
// Returns the `i`th TF_SignatureDefParam in the list.
|
||||
TF_CAPI_EXPORT extern const TF_SignatureDefParam* TF_SignatureDefParamListGet(
|
||||
const TF_SignatureDefParamList* list, int i);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_
|
46
tensorflow/c/experimental/saved_model/public/tensor_spec.h
Normal file
46
tensorflow/c/experimental/saved_model/public/tensor_spec.h
Normal file
@ -0,0 +1,46 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_shape.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type corresponding to TensorSpec
|
||||
typedef struct TF_TensorSpec TF_TensorSpec;
|
||||
|
||||
// Returns the dtype associated with the TensorSpec.
|
||||
TF_CAPI_EXPORT extern TF_DataType TF_TensorSpecDataType(
|
||||
const TF_TensorSpec* spec);
|
||||
|
||||
// Returns the shape associated with the TensorSpec. The returned Shape is not
|
||||
// owned by the caller. Caller must not call TF_DeleteShape on the returned
|
||||
// shape.
|
||||
TF_CAPI_EXPORT extern const TF_Shape* TF_TensorSpecShape(
|
||||
const TF_TensorSpec* spec);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_
|
@ -11,17 +11,29 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "stream_executor_hdrs",
|
||||
hdrs = ["stream_executor.h"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c:tf_status_headers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "stream_executor",
|
||||
srcs = ["stream_executor.cc"],
|
||||
hdrs = ["stream_executor.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":stream_executor_internal",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:regexp",
|
||||
"//tensorflow/core/platform:strcat",
|
||||
"//tensorflow/stream_executor:executor_cache",
|
||||
"//tensorflow/stream_executor:multi_platform_manager",
|
||||
"//tensorflow/stream_executor:platform",
|
||||
|
@ -28,7 +28,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/stream_executor/executor_cache.h"
|
||||
#include "tensorflow/stream_executor/multi_platform_manager.h"
|
||||
#include "tensorflow/stream_executor/platform.h"
|
||||
@ -40,6 +43,8 @@ limitations under the License.
|
||||
using tensorflow::StatusFromTF_Status;
|
||||
|
||||
namespace stream_executor {
|
||||
using tensorflow::StringPiece;
|
||||
|
||||
namespace {
|
||||
|
||||
#define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \
|
||||
@ -59,10 +64,35 @@ namespace {
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
port::Status ValidateDeviceType(StringPiece type) {
|
||||
// Validate device type. Device type must start with a capital letter and
|
||||
// consist of capital letters and underscores. Reasoning behind this decision:
|
||||
// * At the minimum we want to disallow '/' and ':' since
|
||||
// these characters are used in device spec, for e.g.
|
||||
// /job:foo/replica:12/device:GPU:1.
|
||||
// * Underscores seem useful, for e.g. XLA_GPU uses underscores.
|
||||
// * Allowing lowercase might get confusing. For example, say someone
|
||||
// registers a new type called "Gpu". It might be confusing for users that
|
||||
// "Gpu" is not the same device type as "GPU".
|
||||
// Note that lowercase "cpu" and "gpu" are currently supported only for
|
||||
// legacy reasons:
|
||||
// https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/python/framework/device_spec.py;l=46;drc=d3a378f9665d8eee827c74cb9ecbee81e4c288dd
|
||||
static const LazyRE2 kTfDeviceTypeRegEx = {"[A-Z][A-Z_]*"};
|
||||
bool matches = RE2::FullMatch(type, *kTfDeviceTypeRegEx);
|
||||
if (!matches) {
|
||||
return port::FailedPreconditionError(
|
||||
tensorflow::strings::StrCat("Device name/type '", type, "' must match ",
|
||||
kTfDeviceTypeRegEx->pattern(), "."));
|
||||
}
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
port::Status ValidateSPPlatform(const SP_Platform& platform) {
|
||||
VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE);
|
||||
VALIDATE_MEMBER(SP_Platform, platform, name);
|
||||
VALIDATE_MEMBER(SP_Platform, platform, type);
|
||||
TF_RETURN_IF_ERROR(ValidateDeviceType(platform.name));
|
||||
TF_RETURN_IF_ERROR(ValidateDeviceType(platform.type));
|
||||
// `visible_device_count` could be 0 at initialization time.
|
||||
return port::Status::OK();
|
||||
}
|
||||
@ -76,6 +106,8 @@ port::Status ValidateSPPlatformFns(const SP_PlatformFns& platform_fns) {
|
||||
VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_stream_executor);
|
||||
VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_timer_fns);
|
||||
VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_timer_fns);
|
||||
VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_device_fns);
|
||||
VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_device_fns);
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
@ -104,6 +136,12 @@ port::Status ValidateSPDevice(const SP_Device& device) {
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
port::Status ValidateSPDeviceFns(const SP_DeviceFns& device_fns) {
|
||||
VALIDATE_STRUCT_SIZE(SP_DeviceFns, device_fns, SP_DEVICE_FNS_STRUCT_SIZE);
|
||||
// All other fields could theoretically be zero/null.
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
port::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se,
|
||||
const SP_Platform& platform) {
|
||||
VALIDATE_STRUCT_SIZE(SP_StreamExecutor, se, SP_STREAM_EXECUTOR_STRUCT_SIZE);
|
||||
@ -311,11 +349,13 @@ void HostCallbackTrampoline(void* ctx, TF_Status* status) {
|
||||
|
||||
class CStreamExecutor : public internal::StreamExecutorInterface {
|
||||
public:
|
||||
explicit CStreamExecutor(SP_Device device, SP_StreamExecutor* stream_executor,
|
||||
explicit CStreamExecutor(SP_Device device, SP_DeviceFns* device_fns,
|
||||
SP_StreamExecutor* stream_executor,
|
||||
SP_Platform* platform, SP_PlatformFns* platform_fns,
|
||||
SP_TimerFns* timer_fns, const std::string& name,
|
||||
int visible_device_count)
|
||||
: device_(std::move(device)),
|
||||
device_fns_(device_fns),
|
||||
stream_executor_(stream_executor),
|
||||
platform_(platform),
|
||||
platform_fns_(platform_fns),
|
||||
@ -678,10 +718,35 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
|
||||
// Ownership is transferred to the caller.
|
||||
port::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
|
||||
const override {
|
||||
// TODO(annarev): Figure out if we need to support more description fields.
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
|
||||
internal::DeviceDescriptionBuilder builder;
|
||||
builder.set_name(platform_name_);
|
||||
// TODO(annarev): `Also supports_unified_memory` in DeviceDescription.
|
||||
if (device_.hardware_name != nullptr) {
|
||||
builder.set_name(device_.hardware_name);
|
||||
}
|
||||
if (device_.device_vendor != nullptr) {
|
||||
builder.set_device_vendor(device_.device_vendor);
|
||||
}
|
||||
if (device_.pci_bus_id != nullptr) {
|
||||
builder.set_pci_bus_id(device_.pci_bus_id);
|
||||
}
|
||||
|
||||
if (device_fns_->get_numa_node != nullptr) {
|
||||
int32_t numa_node = device_fns_->get_numa_node(&device_);
|
||||
if (numa_node >= 0) {
|
||||
builder.set_numa_node(numa_node);
|
||||
}
|
||||
}
|
||||
|
||||
if (device_fns_->get_memory_bandwidth != nullptr) {
|
||||
int64_t memory_bandwidth = device_fns_->get_memory_bandwidth(&device_);
|
||||
if (memory_bandwidth >= 0) {
|
||||
builder.set_memory_bandwidth(memory_bandwidth);
|
||||
}
|
||||
}
|
||||
// TODO(annarev): Add gflops field in DeviceDescription and set it here.
|
||||
// TODO(annarev): Perhaps add `supports_unified_memory` in
|
||||
// DeviceDescription.
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
@ -709,6 +774,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
|
||||
|
||||
private:
|
||||
SP_Device device_;
|
||||
SP_DeviceFns* device_fns_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Platform* platform_;
|
||||
SP_PlatformFns* platform_fns_;
|
||||
@ -722,17 +788,20 @@ CPlatform::CPlatform(SP_Platform platform,
|
||||
void (*destroy_platform)(SP_Platform*),
|
||||
SP_PlatformFns platform_fns,
|
||||
void (*destroy_platform_fns)(SP_PlatformFns*),
|
||||
SP_StreamExecutor stream_executor, SP_TimerFns timer_fns)
|
||||
SP_DeviceFns device_fns, SP_StreamExecutor stream_executor,
|
||||
SP_TimerFns timer_fns)
|
||||
: platform_(std::move(platform)),
|
||||
destroy_platform_(destroy_platform),
|
||||
platform_fns_(std::move(platform_fns)),
|
||||
destroy_platform_fns_(destroy_platform_fns),
|
||||
device_fns_(std::move(device_fns)),
|
||||
stream_executor_(std::move(stream_executor)),
|
||||
timer_fns_(std::move(timer_fns)),
|
||||
name_(platform.name) {}
|
||||
|
||||
CPlatform::~CPlatform() {
|
||||
executor_cache_.DestroyAllExecutors();
|
||||
platform_fns_.destroy_device_fns(&platform_, &device_fns_);
|
||||
platform_fns_.destroy_stream_executor(&platform_, &stream_executor_);
|
||||
platform_fns_.destroy_timer_fns(&platform_, &timer_fns_);
|
||||
destroy_platform_(&platform_);
|
||||
@ -781,8 +850,8 @@ port::StatusOr<std::unique_ptr<StreamExecutor>> CPlatform::GetUncachedExecutor(
|
||||
TF_RETURN_IF_ERROR(ValidateSPDevice(device));
|
||||
|
||||
auto executor = absl::make_unique<CStreamExecutor>(
|
||||
std::move(device), &stream_executor_, &platform_, &platform_fns_,
|
||||
&timer_fns_, name_, platform_.visible_device_count);
|
||||
std::move(device), &device_fns_, &stream_executor_, &platform_,
|
||||
&platform_fns_, &timer_fns_, name_, platform_.visible_device_count);
|
||||
auto result = absl::make_unique<StreamExecutor>(this, std::move(executor),
|
||||
config.ordinal);
|
||||
return result;
|
||||
@ -819,6 +888,17 @@ port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn) {
|
||||
TF_RETURN_IF_ERROR(ValidateSPPlatform(platform));
|
||||
TF_RETURN_IF_ERROR(ValidateSPPlatformFns(platform_fns));
|
||||
|
||||
// Fill SP_DeviceFns creation params
|
||||
SE_CreateDeviceFnsParams device_fns_params{
|
||||
SE_CREATE_DEVICE_FNS_PARAMS_STRUCT_SIZE};
|
||||
SP_DeviceFns device_fns{SP_DEVICE_FNS_STRUCT_SIZE};
|
||||
device_fns_params.device_fns = &device_fns;
|
||||
|
||||
// Create StreamExecutor
|
||||
platform_fns.create_device_fns(&platform, &device_fns_params, c_status.get());
|
||||
TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
|
||||
TF_RETURN_IF_ERROR(ValidateSPDeviceFns(device_fns));
|
||||
|
||||
// Fill stream executor creation params
|
||||
SE_CreateStreamExecutorParams se_params{
|
||||
SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE};
|
||||
@ -844,7 +924,8 @@ port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn) {
|
||||
std::unique_ptr<stream_executor::CPlatform> cplatform(
|
||||
new stream_executor::CPlatform(
|
||||
std::move(platform), params.destroy_platform, std::move(platform_fns),
|
||||
params.destroy_platform_fns, std::move(se), std::move(timer_fns)));
|
||||
params.destroy_platform_fns, std::move(device_fns), std::move(se),
|
||||
std::move(timer_fns)));
|
||||
SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
|
||||
std::move(cplatform)));
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user