resolve conflicts
This commit is contained in:
commit
21135f9a88
12
.bazelrc
12
.bazelrc
@ -49,7 +49,6 @@
|
||||
# rocm: Build with AMD GPU support (rocm).
|
||||
# mkl: Enable full mkl support.
|
||||
# tensorrt: Enable Tensorrt support.
|
||||
# ngraph: Enable ngraph support.
|
||||
# numa: Enable numa using hwloc.
|
||||
# noaws: Disable AWS S3 storage support
|
||||
# nogcp: Disable GCS support.
|
||||
@ -159,6 +158,7 @@ build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain
|
||||
# environment variable "TF_MKL_ROOT" every time before build.
|
||||
build:mkl --define=build_with_mkl=true --define=enable_mkl=true
|
||||
build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
|
||||
build:mkl --define=build_with_openmp=true
|
||||
build:mkl -c opt
|
||||
|
||||
# config to build OneDNN backend with a user specified threadpool.
|
||||
@ -172,6 +172,7 @@ build:mkl_threadpool -c opt
|
||||
build:mkl_opensource_only --define=build_with_mkl=true --define=enable_mkl=true
|
||||
build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0
|
||||
build:mkl_opensource_only --define=build_with_mkl_opensource=true
|
||||
build:mkl_opensource_only --define=build_with_openmp=true
|
||||
build:mkl_opensource_only -c opt
|
||||
|
||||
# Config setting to build with oneDNN for Arm.
|
||||
@ -218,7 +219,6 @@ build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
|
||||
build:rocm --action_env TF_NEED_ROCM=1
|
||||
|
||||
# Options extracted from configure script
|
||||
build:ngraph --define=with_ngraph_support=true
|
||||
build:numa --define=with_numa_support=true
|
||||
|
||||
# Options to disable default on features
|
||||
@ -283,7 +283,7 @@ build:ios --copt=-w
|
||||
build:linux --copt=-w
|
||||
build:linux --host_copt=-w
|
||||
build:macos --copt=-w
|
||||
build:windows --copt=/w
|
||||
build:windows --copt=/W0
|
||||
|
||||
# Tensorflow uses M_* math constants that only get defined by MSVC headers if
|
||||
# _USE_MATH_DEFINES is defined.
|
||||
@ -294,9 +294,11 @@ build:windows --host_copt=/D_USE_MATH_DEFINES
|
||||
build:linux --define=PREFIX=/usr
|
||||
build:linux --define=LIBDIR=$(PREFIX)/lib
|
||||
build:linux --define=INCLUDEDIR=$(PREFIX)/include
|
||||
build:linux --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include
|
||||
build:macos --define=PREFIX=/usr
|
||||
build:macos --define=LIBDIR=$(PREFIX)/lib
|
||||
build:macos --define=INCLUDEDIR=$(PREFIX)/include
|
||||
build:macos --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include
|
||||
# TF_SYSTEM_LIBS do not work on windows.
|
||||
|
||||
# By default, build TF in C++ 14 mode.
|
||||
@ -600,6 +602,10 @@ build:release_windows_common --config=release_common
|
||||
build:release_windows_common --define=no_tensorflow_py_deps=true
|
||||
build:release_windows_common --announce_rc
|
||||
|
||||
# First available in VS 16.4. Speeds Windows compile times by a lot. See
|
||||
# https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
|
||||
build:release_windows_common --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions
|
||||
|
||||
build:release_cpu_windows --config=release_windows_common
|
||||
|
||||
build:release_gpu_windows --config=release_windows_common
|
||||
|
10
ISSUES.md
10
ISSUES.md
@ -1,7 +1,9 @@
|
||||
If you open a GitHub Issue, here is our policy: 1. It must be a bug/performance
|
||||
issue or a feature request or a build issue or a documentation issue (for small
|
||||
doc fixes please send a PR instead). 2. Make sure the Issue Template is filled
|
||||
out. 3. The issue should be related to the repo it is created in.
|
||||
If you open a GitHub Issue, here is our policy:
|
||||
|
||||
1. It must be a bug/performance issue or a feature request or a build issue or
|
||||
a documentation issue (for small doc fixes please send a PR instead).
|
||||
1. Make sure the Issue Template is filled out.
|
||||
1. The issue should be related to the repo it is created in.
|
||||
|
||||
**Here's why we have this policy:** We want to focus on the work that benefits
|
||||
the whole community, e.g., fixing bugs and adding features. Individual support
|
||||
|
50
README.md
50
README.md
@ -5,7 +5,6 @@
|
||||
[](https://badge.fury.io/py/tensorflow)
|
||||
[](https://badge.fury.io/py/tensorflow)
|
||||
|
||||
|
||||
**`Documentation`** |
|
||||
------------------- |
|
||||
[](https://www.tensorflow.org/api_docs/) |
|
||||
@ -61,6 +60,7 @@ commands.
|
||||
*Nightly binaries are available for testing using the
|
||||
[tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
|
||||
[tf-nightly-cpu](https://pypi.python.org/pypi/tf-nightly-cpu) packages on PyPi.*
|
||||
|
||||
#### *Try your first TensorFlow program*
|
||||
|
||||
```shell
|
||||
@ -103,23 +103,22 @@ open-source software development:
|
||||
|
||||
### Official Builds
|
||||
|
||||
Build Type | Status | Artifacts
|
||||
------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
|
||||
**Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Linux XLA** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
|
||||
**macOS** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Android** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
|
||||
**Raspberry Pi 0 and 1** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
|
||||
**Raspberry Pi 2 and 3** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
|
||||
**Libtensorflow MacOS CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
**Libtensorflow Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
**Libtensorflow Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
**Libtensorflow Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
**Libtensorflow Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
|
||||
Build Type | Status | Artifacts
|
||||
----------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
|
||||
**Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Linux XLA** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
|
||||
**macOS** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Android** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
|
||||
**Raspberry Pi 0 and 1** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
|
||||
**Raspberry Pi 2 and 3** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
|
||||
**Libtensorflow MacOS CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
|
||||
**Libtensorflow Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
|
||||
**Libtensorflow Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
|
||||
**Libtensorflow Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
|
||||
**Libtensorflow Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
|
||||
|
||||
### Community Supported Builds
|
||||
|
||||
@ -133,12 +132,20 @@ Build Type
|
||||
**Linux ppc64le CPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/)
|
||||
**Linux ppc64le GPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
|
||||
**Linux ppc64le GPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/)
|
||||
**Linux aarch64 CPU** Nightly <br> Python 3.6 | [](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) | [Nightly](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master)
|
||||
**Linux aarch64 CPU** Stable Release | [](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) | Release [1.15](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) / [2.x](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)
|
||||
**Linux aarch64 CPU** Nightly (Linaro)<br> Python 3.8 | [](https://ci.linaro.org/jenkins/job/ldcg-hpc-tensorflow/) | [Nightly](http://snapshots.linaro.org/hpc/python/tensorflow/latest/)
|
||||
**Linux aarch64 CPU** Stable Release (Linaro) | [](https://ci.linaro.org/jenkins/job/ldcg-hpc-tensorflow/) | Release [1.x & 2.x](http://snapshots.linaro.org/hpc/python/tensorflow/latest/)
|
||||
**Linux aarch64 CPU** Nightly (OpenLab)<br> Python 3.6 | [](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) | [Nightly](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master)
|
||||
**Linux aarch64 CPU** Stable Release (OpenLab) | [](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) | Release [1.15](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) / [2.x](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)
|
||||
**Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Nightly | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
|
||||
**Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Stable Release |  | Release [1.15](https://pypi.org/project/intel-tensorflow/1.15.0/) / [2.x](https://pypi.org/project/intel-tensorflow/)
|
||||
**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br> Python 2.7, 3.6 | [](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/)
|
||||
|
||||
### Community Supported Containers
|
||||
|
||||
Container Type | Status | Artifacts
|
||||
----------------------------------------------------------------- | ------ | ---------
|
||||
**TensorFlow aarch64 Neoverse-N1 CPU** Stable (Linaro)<br> Debian | Static | Release [2.3](https://hub.docker.com/r/linaro/tensorflow-arm-neoverse-n1)
|
||||
|
||||
## Resources
|
||||
|
||||
* [TensorFlow.org](https://www.tensorflow.org)
|
||||
@ -151,8 +158,7 @@ Build Type
|
||||
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
|
||||
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
|
||||
* [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp)
|
||||
* [TensorFlow Chat Room on StackOverflow (not actively monitored by the
|
||||
TensorFlow team)](https://chat.stackoverflow.com/rooms/216694/tensorflow)
|
||||
* [TensorFlow Codelabs](https://codelabs.developers.google.com/?cat=TensorFlow)
|
||||
* [TensorFlow Blog](https://blog.tensorflow.org)
|
||||
* [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
|
||||
* [TensorFlow Twitter](https://twitter.com/tensorflow)
|
||||
|
211
RELEASE.md
211
RELEASE.md
@ -1,3 +1,91 @@
|
||||
# Release 2.5.0
|
||||
|
||||
<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
* <DOCUMENT BREAKING CHANGES HERE>
|
||||
* <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>
|
||||
|
||||
## Known Caveats
|
||||
|
||||
* <CAVEATS REGARDING THE RELEASE (BUT NOT BREAKING CHANGES).>
|
||||
* <ADDING/BUMPING DEPENDENCIES SHOULD GO HERE>
|
||||
* <KNWON LACK OF SUPPORT ON SOME PLATFORM, SHOULD GO HERE>
|
||||
|
||||
## Major Features and Improvements
|
||||
|
||||
* <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
|
||||
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
|
||||
|
||||
* TPU embedding support
|
||||
* Added `profile_data_directory` to `EmbeddingConfigSpec` in
|
||||
`_tpu_estimator_embedding.py`. This allows embedding lookup statistics
|
||||
gathered at runtime to be used in embedding layer partitioning decisions.
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
|
||||
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
|
||||
* <NOTES SHOULD BE GROUPED PER AREA>
|
||||
* `tf.keras`:
|
||||
* Improvements to Keras preprocessing layers:
|
||||
* Discretization combiner implemented, with additional arg `epsilon`.
|
||||
|
||||
* `tf.data`:
|
||||
* Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used
|
||||
to control how external state should be handled during dataset
|
||||
serialization or iterator checkpointing.
|
||||
* XLA compilation:
|
||||
* `tf.function(experimental_compile=True)` has become a stable API,
|
||||
renamed `tf.function(jit_compile=True)`.
|
||||
|
||||
* `tf.lite`:
|
||||
* NNAPI
|
||||
* Removed deprecated `Interpreter::UseNNAPI(bool)` C++ API.
|
||||
* Use `NnApiDelegate()` and related delegate configuration methods
|
||||
directly.
|
||||
* 16 bits quantization
|
||||
* Added int16x8 support for ABS, REDUCE_MAX and REDUCE_MIN operators.
|
||||
* Added support for saved model's session initializer through
|
||||
`TFLiteConverter.from_saved_model`.
|
||||
* Added dynamic range quantization support for the BatchMatMul op.
|
||||
|
||||
* TF Core:
|
||||
* Corrected higher-order gradients of control flow constructs (`tf.cond`,
|
||||
`tf.while_loop`, and compositions like `tf.foldl`) computed with
|
||||
`tf.GradientTape` inside a `tf.function`.
|
||||
* Changed the default step size in `gradient_checker_v2.compute_gradients` to be exactly representable as a binary floating point numbers. This avoids poluting gradient approximations needlessly, which is some cases leads to false negatives in op gradient tests.
|
||||
|
||||
* `tf.summary`:
|
||||
* New `tf.summary.graph` allows manual write of TensorFlow graph
|
||||
(`tf.Graph` or `tf.compat.v1.GraphDef`) as a summary. This is not a
|
||||
replacement for the trace-based API.
|
||||
|
||||
* Set `/d2ReducedOptimizeHugeFunctions` by default for Windows builds. This
|
||||
provides a big compile-time speedup, and effectively raises the minimum
|
||||
supported MSVC version to 16.4 (current: 16.8).
|
||||
* See: https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
|
||||
|
||||
* TensorRT
|
||||
* Removed the deprecated `session_config` parameter for the TF1-TRT
|
||||
converter `TrtGraphConverter`. Previously, we issued a warning when the
|
||||
value of the parameter is not None.
|
||||
* The TF2-TRT converter `TrtGraphConverterV2` takes an object of class
|
||||
TrtConversionParams as a parameter. Removed three deprecated fields from
|
||||
this class: `rewriter_config_template`, `is_dynamic_op`, and
|
||||
`max_batch_size`. Previously, we issued a warning when the value of
|
||||
`rewriter_config_template` is not None. We issued an error when the
|
||||
value of `is_dynamic_op` is not True. We didn't use the value for
|
||||
`max_batch_size` for building TensorRT engines.
|
||||
* Issue a warning when function get_tensorrt_rewriter_config is used.
|
||||
|
||||
## Thanks to our Contributors
|
||||
|
||||
This release contains contributions from many people at Google, as well as:
|
||||
|
||||
<INSERT>, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
|
||||
|
||||
# Release 2.4.0
|
||||
|
||||
<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
|
||||
@ -6,6 +94,15 @@
|
||||
|
||||
* <DOCUMENT BREAKING CHANGES HERE>
|
||||
* <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>
|
||||
* Certain float32 ops run in lower precsion on Ampere based GPUs, including
|
||||
matmuls and convolutions, due to the use of
|
||||
[TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/).
|
||||
Specifically, inputs to such ops are rounded from 23 bits of precision to 10
|
||||
bits of precision. This is unlikely to cause issues in practice for deep
|
||||
learning models. In some cases, TensorFloat-32 is also used for complex64 ops.
|
||||
TensorFloat-32 can be disabled by running
|
||||
`config.experimental.enable_tensor_float_32_execution(False)`. The "Major
|
||||
Features and Improvements" section has more details.
|
||||
* The byte layout for string tensors across the C-API has been updated to match
|
||||
TF Core/C++; i.e., a contiguous array of `tensorflow::tstring`/`TF_TString`s.
|
||||
* C-API functions `TF_StringDecode`, `TF_StringEncode`, and
|
||||
@ -54,6 +151,42 @@
|
||||
tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...).
|
||||
* `tf.distribute.Strategy.experimental_make_numpy_dataset` is removed. Please
|
||||
use `tf.data.Dataset.from_tensor_slices` instead.
|
||||
* `experimental_hints` in `tf.distribute.StrategyExtended.reduce_to`,
|
||||
`tf.distribute.StrategyExtended.batch_reduce_to`,
|
||||
`tf.distribute.ReplicaContext.all_reduce` are renamed to `options`.
|
||||
`tf.distribute.experimental.CollectiveHints` is renamed
|
||||
`tf.distribute.experimental.CommunicationOptions`.
|
||||
`tf.distribute.experimental.CollectiveCommunication` is renamed
|
||||
`tf.distribute.experimental.CommunicationImplementation`.
|
||||
* `tf.keras.mixed_precision.experimental`:
|
||||
* `AutoCastVariable.dtype` now refers to the actual variable dtype, not the
|
||||
dtype it will be casted to.
|
||||
* When mixed precision is enabled, `tf.keras.layers.Embedding` now outputs a
|
||||
float16 or bfloat16 tensor instead of a float32 tensor.
|
||||
* The property
|
||||
`tf.keras.mixed_precision.experimental.LossScaleOptimizer.loss_scale` is now
|
||||
a tensor, not a `LossScale` object. This means to get a loss scale of a
|
||||
`LossScaleOptimizer` as a tensor, you must now call `opt.loss_scale` instead
|
||||
of `opt.loss_scale()`.
|
||||
* The property `should_cast_variables` has been removed from
|
||||
`tf.keras.mixed_precision.experimental.Policy`
|
||||
* When passing a `tf.mixed_precision.experimental.DynamicLossScale` to
|
||||
`tf.keras.mixed_precision.experimental.LossScaleOptimizer`, the
|
||||
`DynamicLossScale`'s multiplier must be 2.
|
||||
* When passing a `tf.mixed_precision.experimental.DynamicLossScale` to
|
||||
`tf.keras.mixed_precision.experimental.LossScaleOptimizer`, the weights of
|
||||
the `DynanmicLossScale` are copied into the `LossScaleOptimizer` instead of
|
||||
being reused. This means modifying the weights of the `DynamicLossScale`
|
||||
will no longer affect the weights of the LossScaleOptimizer, and vice versa.
|
||||
* The global policy can no longer be set to a non-floating point policy in
|
||||
`tf.keras.mixed_precision.experimental.set_policy`
|
||||
* In `Layer.call`, `AutoCastVariable`s will no longer be casted within
|
||||
`MirroredStrategy.run` or `ReplicaContext.merge_call`. This is because a
|
||||
thread local variable is used to determine whether `AutoCastVariable`s are
|
||||
casted, and those two functions run with a different thread. Note this only
|
||||
applies if one of these two functions is called within `Layer.call`; if one
|
||||
of those two functions calls `Layer.call`, `AutoCastVariable`s will still be
|
||||
casted.
|
||||
|
||||
## Known Caveats
|
||||
|
||||
@ -65,9 +198,40 @@
|
||||
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
|
||||
* A new module named `tf.experimental.numpy` is added, which is a NumPy-compatible API for writing TF programs. This module provides class `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are provided. Their inter-operation with TF facilities is seamless in most cases. See [tensorflow/python/ops/numpy_ops/README.md](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/numpy_ops/README.md) for details of what operations are supported and what are the differences from NumPy.
|
||||
* A major refactoring of the internals of the Keras Functional API has been completed, that should improve the reliability, stability, and performance of constructing Functional models.
|
||||
* Support for
|
||||
[TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/)
|
||||
on Ampere based GPUs has been added. TensorFloat-32, or TF32 for short, is a
|
||||
math mode for NVIDIA Ampere GPUs which causes certain float32 ops, such as
|
||||
matrix multiplications and convolutions, to run much faster on Ampere GPUs but
|
||||
with reduced precision. This reduced precision has not been found to effect
|
||||
convergence quality of deep learning models in practice. TensorFloat-32 is
|
||||
enabled by default, but can be disabled with
|
||||
`tf.config.experimental.enable_tensor_float_32_execution`.
|
||||
|
||||
* `tf.distribute`:
|
||||
* `MultiWorkerMirroredStrategy` is graduated out of experimental.
|
||||
* Peer failure will no longer cause the cluster to hang.
|
||||
* Major issues with saving are fixed.
|
||||
* See [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for a tutorial.
|
||||
* Deprecated `experimental_distribute_datasets_from_function` method and renamed it to `distribute_datasets_from_function` as it is no longer experimental.
|
||||
* The `tf.keras.mixed_precision` API has been made non-experimental. The major
|
||||
changes to the new non-experimental API are:
|
||||
* `tf.keras.mixed_precision.Policy` no longer takes in a
|
||||
`tf.mixed_precision.experimental.LossScale` in the constructor, and no
|
||||
longer has a `LossScale` associated with it. Instead, `Model.compile` will
|
||||
automatically wrap the optimizer with a `LossScaleOptimizer` using dynamic
|
||||
loss scaling if `Policy.name` is "mixed_float16".
|
||||
* `tf.keras.mixed_precision.LossScaleOptimizer`'s constructor takes in
|
||||
different arguments. In particular, it no longer takes in a `LossScale`, and
|
||||
there is no longer a `LossScale` associated with the `LossScaleOptimizer`.
|
||||
Instead, `LossScaleOptimizer` directly implements fixed or dynamic loss
|
||||
scaling. See the documentation of
|
||||
`tf.keras.mixed_precision.experimental.LossScaleOptimizer` for details on
|
||||
the differences between the experimental `LossScaleOptimizer` and the new
|
||||
non-experimental `LossScaleOptimizer`.
|
||||
* `tf.mixed_precision.experimental.LossScale` and its subclasses are
|
||||
deprecated, as all of its functionality now exists within
|
||||
`tf.keras.mixed_precision.LossScaleOptimizer`
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
@ -117,6 +281,10 @@
|
||||
([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212),
|
||||
[CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
|
||||
[CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
|
||||
* Fixes a segfault in `tf.quantization.quantize_and_dequantize`
|
||||
([CVE-2020-15265](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15265))
|
||||
* Fixes an undefined behavior float cast causing a crash
|
||||
([CVE-2020-15266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15266))
|
||||
* TF Core:
|
||||
* `tf.types.experimental.TensorLike` is a new `Union` type that can be
|
||||
used as type annotation for variables representing a Tensor or a value
|
||||
@ -138,6 +306,8 @@
|
||||
stateful ops.
|
||||
* Added `tf.config.experimental.get_memory_usage` to return total memory
|
||||
usage of the device.
|
||||
* Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`.
|
||||
* Improve shape inference of nested function calls by supporting constant folding across Arg nodes which makes more static values available to shape inference functions.
|
||||
* `tf.data`:
|
||||
* tf.data service:
|
||||
* Added new `tf.data.experimental.service.register_dataset` and
|
||||
@ -182,7 +352,16 @@
|
||||
how many times the function is called, and independent of global seed
|
||||
settings.
|
||||
* `tf.distribute`:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* (Experimental) Parameter server training:
|
||||
* Replaced the existing
|
||||
`tf.distribute.experimental.ParameterServerStrategy` symbol with
|
||||
a new class that is for parameter server training in TF2. Usage with
|
||||
the old symbol, usually with Estimator, should be replaced with
|
||||
`tf.compat.v1.distribute.experimental.ParameterServerStrategy`.
|
||||
* Added `tf.distribute.experimental.coordinator.*` namespace,
|
||||
including the main API `ClusterCoordinator` for coordinating the
|
||||
training cluster, the related data structure `RemoteValue`
|
||||
and `PerWorkerValue`.
|
||||
* `tf.keras`:
|
||||
* Improvements from the functional API refactoring:
|
||||
* Functional model construction does not need to maintain a global
|
||||
@ -217,6 +396,8 @@
|
||||
* Improvements to Keras preprocessing layers:
|
||||
* TextVectorization can now accept a vocabulary list or file as an
|
||||
init arg.
|
||||
* TextVectorization, StringLookup, and IntegerLookup can now accept a
|
||||
vocabulary file via the `set_vocab_from_file` method.
|
||||
* Normalization can now accept mean and variance values as init args.
|
||||
* In `Attention` and `AdditiveAttention` layers, the `call()` method now
|
||||
accepts a `return_attention_scores` argument. When set to
|
||||
@ -224,15 +405,28 @@
|
||||
argument.
|
||||
* Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints
|
||||
with the same implementation as their `tf.losses` equivalent.
|
||||
* For Keras model, the individual call of `Model.evaluate` uses no cached
|
||||
data for evaluation, while `Model.fit` uses cached data when
|
||||
`validation_data` arg is provided for better performance.
|
||||
* Added a `save_traces` argument to `model.save`/
|
||||
`tf.keras.models.save_model` which determines whether the SavedModel
|
||||
format stores the Keras model/layer call functions. The traced functions
|
||||
allow Keras to revive custom models and layers without the original
|
||||
class definition, but if this isn't required the tracing can be
|
||||
disabled with the added option.
|
||||
* `tf.function` / AutoGraph:
|
||||
* Added `experimental_follow_type_hints` argument for `tf.function`. When
|
||||
True, the function may use type annotations to optimize the tracing
|
||||
performance.
|
||||
* Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
|
||||
* AutoGraph now allows creating new symbols inside a TensorFLow loop, if
|
||||
* AutoGraph now allows creating new symbols inside a TensorFlow loop, if
|
||||
the values of these symbols at an iteration does not depend on the
|
||||
previous iteration. These types of loops must run at least one
|
||||
iteration, and will raise a runtime error otherwise.
|
||||
* Variables contained in `tf.Module`s that are set as attributes of
|
||||
custom Keras `Layer`s and `Model`s are now tracked in
|
||||
the properties `layer.trainable_variables` and
|
||||
`layer.non_trainable_variables`.
|
||||
|
||||
Example:
|
||||
|
||||
@ -269,6 +463,7 @@
|
||||
`TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
|
||||
* `DynamicBuffer::AddJoinedString()` will now add a separator if the first
|
||||
string to be joined is empty.
|
||||
* Added support for cumulative sum (cumsum), both as builtin op and MLIR conversion.
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
* `tf.random`:
|
||||
@ -277,7 +472,7 @@
|
||||
|
||||
* Math and Linear Algebra:
|
||||
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* Add `tf.math.erfcinv`, the inverse to `tf.math.erfc`.
|
||||
|
||||
* TPU Enhancements:
|
||||
|
||||
@ -323,6 +518,12 @@
|
||||
didn't have the keys sorted, the keys and values were not being printed
|
||||
in accordance with their correct mapping.
|
||||
|
||||
* `TensorRT`
|
||||
|
||||
* We now issue a warning when the `session_config` parameter for the TF1
|
||||
converter is used or the `rewrite_config_template` field in the TF2
|
||||
converter parameter object is used.
|
||||
|
||||
* Other:
|
||||
|
||||
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
|
||||
@ -331,6 +532,8 @@
|
||||
context.
|
||||
* Add `tf.config.experimental.mlir_bridge_rollout` which will help us
|
||||
rollout the new MLIR TPU bridge.
|
||||
* Added `tf.experimental.register_filesystem_plugin` to load modular
|
||||
filesystem plugins from Python
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
## Thanks to our Contributors
|
||||
@ -703,6 +906,7 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
|
||||
* Add `tf.saved_model.LoadOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/LoadOptions?hl=en) as arg with default value `None` to choose the I/O device for loading models and weights.
|
||||
* Update `tf.saved_model.SaveOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/SaveOptions?hl=en) as arg with default value `None` to choose the I/O device for saving models and weights.
|
||||
* Mutable tables now restore checkpointed values when loaded from SavedModel.
|
||||
* The user object metadata field in the SavedModel proto has been deprecated as part of the updates to Keras SavedModel. Keras was the only consumer of this field prior to the update.
|
||||
* GPU
|
||||
* TF 2.3 includes PTX kernels only for [compute capability](https://developer.nvidia.com/cuda-gpus) 7.0 to reduce the TF pip binary size. Earlier releases included PTX for a variety of older compute capabilities.
|
||||
* Remove environmental variable `TF_USE_CUDNN`.
|
||||
@ -731,6 +935,7 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
|
||||
* Fix the issue that `strategy.reduce()` inside `tf.function` may raise exceptions when the values to reduce are from loops or if-clauses.
|
||||
* Fix the issue that `tf.distribute.MirroredStrategy` cannot be used together with `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
|
||||
* Add a `tf.distribute.cluster_resolver.TPUClusterResolver.connect` API to simplify TPU initialization.
|
||||
* Add `tf.distribute.Strategy.gather` and `tf.distribute.ReplicaContext.all_gather` methods to gather and concatenate `tf.distribute.DistributedValues` across workers and devices.
|
||||
|
||||
### `tf.keras`:
|
||||
* Introduces experimental preprocessing layers API (`tf.keras.layers.experimental.preprocessing`) to handle data preprocessing operations such as categorical feature encoding, text vectorization, data normalization, and data discretization (binning). The newly added layers provide a replacement for the legacy feature column API, and support composite tensor inputs.
|
||||
|
20
WORKSPACE
20
WORKSPACE
@ -113,26 +113,10 @@ http_archive(
|
||||
# Required for dependency @com_github_grpc_grpc
|
||||
|
||||
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
|
||||
|
||||
grpc_deps()
|
||||
|
||||
load(
|
||||
"@build_bazel_rules_apple//apple:repositories.bzl",
|
||||
"apple_rules_dependencies",
|
||||
)
|
||||
|
||||
apple_rules_dependencies()
|
||||
|
||||
load(
|
||||
"@build_bazel_apple_support//lib:repositories.bzl",
|
||||
"apple_support_dependencies",
|
||||
)
|
||||
|
||||
apple_support_dependencies()
|
||||
|
||||
load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")
|
||||
|
||||
bazel_version_repository(name = "bazel_version")
|
||||
load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps")
|
||||
grpc_extra_deps()
|
||||
|
||||
load("//third_party/googleapis:repository_rules.bzl", "config_googleapis")
|
||||
|
||||
|
58
configure.py
58
configure.py
@ -55,16 +55,16 @@ NCCL_LIB_PATHS = [
|
||||
|
||||
# List of files to configure when building Bazel on Apple platforms.
|
||||
APPLE_BAZEL_FILES = [
|
||||
'tensorflow/lite/experimental/ios/BUILD',
|
||||
'tensorflow/lite/experimental/objc/BUILD',
|
||||
'tensorflow/lite/experimental/swift/BUILD',
|
||||
'tensorflow/lite/ios/BUILD',
|
||||
'tensorflow/lite/objc/BUILD',
|
||||
'tensorflow/lite/swift/BUILD',
|
||||
'tensorflow/lite/tools/benchmark/experimental/ios/BUILD'
|
||||
]
|
||||
|
||||
# List of files to move when building for iOS.
|
||||
IOS_FILES = [
|
||||
'tensorflow/lite/experimental/objc/TensorFlowLiteObjC.podspec',
|
||||
'tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec',
|
||||
'tensorflow/lite/objc/TensorFlowLiteObjC.podspec',
|
||||
'tensorflow/lite/swift/TensorFlowLiteSwift.podspec',
|
||||
]
|
||||
|
||||
|
||||
@ -1163,49 +1163,18 @@ def set_system_libs_flag(environ_cp):
|
||||
syslibs = ','.join(sorted(syslibs.split()))
|
||||
write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs)
|
||||
|
||||
if 'PREFIX' in environ_cp:
|
||||
write_to_bazelrc('build --define=PREFIX=%s' % environ_cp['PREFIX'])
|
||||
if 'LIBDIR' in environ_cp:
|
||||
write_to_bazelrc('build --define=LIBDIR=%s' % environ_cp['LIBDIR'])
|
||||
if 'INCLUDEDIR' in environ_cp:
|
||||
write_to_bazelrc('build --define=INCLUDEDIR=%s' % environ_cp['INCLUDEDIR'])
|
||||
|
||||
|
||||
def is_reduced_optimize_huge_functions_available(environ_cp):
|
||||
"""Check to see if the system supports /d2ReducedOptimizeHugeFunctions.
|
||||
|
||||
The above compiler flag is a new compiler flag introduced to the Visual Studio
|
||||
compiler in version 16.4 (available in Visual Studio 2019, Preview edition
|
||||
only, as of 2019-11-19). TensorFlow needs this flag to massively reduce
|
||||
compile times, but until 16.4 is officially released, we can't depend on it.
|
||||
|
||||
See also
|
||||
https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
|
||||
|
||||
Because it's very annoying to check this manually (to check the MSVC installed
|
||||
versions, you need to use the registry, and it's not clear if Bazel will be
|
||||
using that install version anyway), we expect enviroments who know they may
|
||||
use this flag to export TF_VC_VERSION=16.4
|
||||
|
||||
TODO(angerson, gunan): Remove this function when TensorFlow's minimum VS
|
||||
version is upgraded to 16.4.
|
||||
|
||||
Arguments:
|
||||
environ_cp: Environment of the current execution
|
||||
|
||||
Returns:
|
||||
boolean, whether or not /d2ReducedOptimizeHugeFunctions is available on this
|
||||
machine.
|
||||
"""
|
||||
return float(environ_cp.get('TF_VC_VERSION', '0')) >= 16.4
|
||||
for varname in ('PREFIX', 'LIBDIR', 'INCLUDEDIR', 'PROTOBUF_INCLUDE_PATH'):
|
||||
if varname in environ_cp:
|
||||
write_to_bazelrc('build --define=%s=%s' % (varname, environ_cp[varname]))
|
||||
|
||||
|
||||
def set_windows_build_flags(environ_cp):
|
||||
"""Set Windows specific build options."""
|
||||
if is_reduced_optimize_huge_functions_available(environ_cp):
|
||||
write_to_bazelrc(
|
||||
'build --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions'
|
||||
)
|
||||
|
||||
# First available in VS 16.4. Speeds up Windows compile times by a lot. See
|
||||
# https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
|
||||
# pylint: disable=line-too-long
|
||||
write_to_bazelrc('build --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions')
|
||||
|
||||
if get_var(
|
||||
environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline',
|
||||
@ -1487,7 +1456,6 @@ def main():
|
||||
config_info_line('mkl', 'Build with MKL support.')
|
||||
config_info_line('mkl_aarch64', 'Build with oneDNN support for Aarch64.')
|
||||
config_info_line('monolithic', 'Config for mostly static monolithic build.')
|
||||
config_info_line('ngraph', 'Build with Intel nGraph support.')
|
||||
config_info_line('numa', 'Build with NUMA support.')
|
||||
config_info_line(
|
||||
'dynamic_kernels',
|
||||
|
@ -3,6 +3,7 @@
|
||||
# learning applications.
|
||||
|
||||
load("@bazel_skylib//lib:selects.bzl", "selects")
|
||||
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
|
||||
load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
|
||||
load(
|
||||
"//tensorflow/core/platform:build_config.bzl",
|
||||
@ -22,10 +23,6 @@ load(
|
||||
"//tensorflow/python/tools/api/generator:api_init_files_v1.bzl",
|
||||
"TENSORFLOW_API_INIT_FILES_V1", # @unused
|
||||
)
|
||||
load(
|
||||
"//third_party/ngraph:build_defs.bzl",
|
||||
"if_ngraph",
|
||||
)
|
||||
load(
|
||||
"//third_party/mkl:build_defs.bzl",
|
||||
"if_mkl_ml",
|
||||
@ -75,6 +72,14 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Config setting that disables the default logger, only logging
|
||||
# to registered TFLogSinks
|
||||
config_setting(
|
||||
name = "no_default_logger",
|
||||
define_values = {"no_default_logger": "true"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Config setting for determining if we are building for Android.
|
||||
config_setting(
|
||||
name = "android",
|
||||
@ -238,6 +243,12 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "linux_mips64",
|
||||
values = {"cpu": "mips64"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "debug",
|
||||
values = {
|
||||
@ -465,14 +476,6 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# This flag is set from the configure step when the user selects with nGraph option.
|
||||
# By default it should be false
|
||||
config_setting(
|
||||
name = "with_ngraph_support",
|
||||
values = {"define": "with_ngraph_support=true"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# This flag specifies whether TensorFlow 2.0 API should be built instead
|
||||
# of 1.* API. Note that TensorFlow 2.0 API is currently under development.
|
||||
config_setting(
|
||||
@ -563,18 +566,47 @@ selects.config_setting_group(
|
||||
],
|
||||
)
|
||||
|
||||
# 'enable_registration_v2' opts-in to a different implementation of op and
|
||||
# kernel registration - REGISTER_OP, REGISTER_KERNEL_BUILDER, etc.
|
||||
#
|
||||
# This setting is currently experimental. The 'v2' implementation does _not_
|
||||
# correspond to a particular, finalized design; rather, it relates to
|
||||
# developing one.
|
||||
#
|
||||
# The current aim of the 'v2' implementation is to allow 'unused' ops and
|
||||
# kernels to be discarded by the linker (to the benefit of binary size).
|
||||
bool_flag(
|
||||
name = "enable_registration_v2",
|
||||
build_setting_default = False,
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "registration_v1",
|
||||
flag_values = {":enable_registration_v2": "False"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "registration_v2",
|
||||
flag_values = {":enable_registration_v2": "True"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
|
||||
# Instead, please use public APIs or public build rules TF provides.
|
||||
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
|
||||
# TODO(b/173549186): Move Google-internal TF code out of learning/brain
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = ["//tensorflow/..."],
|
||||
packages = [
|
||||
"//learning/brain/mlir/...",
|
||||
"//learning/lib/ami/simple_ml/...",
|
||||
"//tensorflow/...",
|
||||
],
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "ndarray_tensor_allow_list",
|
||||
packages = ["//learning/pathways/..."],
|
||||
)
|
||||
package_group(name = "ndarray_tensor_allow_list")
|
||||
|
||||
# Packages that use private types symbols, until they are exported.
|
||||
# TODO(b/154650521) Remove.
|
||||
@ -605,7 +637,7 @@ bzl_library(
|
||||
"//tensorflow/core/platform/default:cuda_build_defs_bzl",
|
||||
"//third_party/mkl:build_defs_bzl",
|
||||
"//third_party/mkl_dnn:build_defs_bzl",
|
||||
"//third_party/ngraph:build_defs_bzl",
|
||||
"@bazel_skylib//rules:common_settings",
|
||||
"@local_config_cuda//cuda:build_defs_bzl",
|
||||
"@local_config_rocm//rocm:build_defs_bzl",
|
||||
"@local_config_tensorrt//:build_defs_bzl",
|
||||
@ -706,6 +738,10 @@ tf_cc_shared_object(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
|
||||
"//tensorflow/c:kernels_hdrs",
|
||||
"//tensorflow/c:logging",
|
||||
"//tensorflow/c:ops_hdrs",
|
||||
"//tensorflow/cc/saved_model:loader_lite_impl",
|
||||
"//tensorflow/core/common_runtime:core_cpu_impl",
|
||||
"//tensorflow/core:framework_internal_impl",
|
||||
@ -809,7 +845,7 @@ tf_cc_shared_object(
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/core:tensorflow",
|
||||
] + if_ngraph(["@ngraph_tf//:ngraph_tf"]),
|
||||
],
|
||||
)
|
||||
|
||||
# ** Targets for Windows build (start) **
|
||||
|
@ -199,9 +199,11 @@ tf_cuda_library(
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":logging",
|
||||
":tf_status",
|
||||
":tf_tensor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"//tensorflow/c/experimental/filesystem:modular_filesystem",
|
||||
"//tensorflow/cc/saved_model:loader_lite",
|
||||
"//tensorflow/cc:gradients",
|
||||
"//tensorflow/cc:ops",
|
||||
@ -511,6 +513,19 @@ tf_cuda_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "kernels_hdrs",
|
||||
hdrs = ["kernels.h"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":c_api_internal",
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
":tf_tensor",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "kernels",
|
||||
srcs = [
|
||||
@ -528,13 +543,17 @@ tf_cuda_library(
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
":c_api_internal",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":c_api_internal",
|
||||
":tf_tensor",
|
||||
"//tensorflow/stream_executor:stream",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor_internal",
|
||||
],
|
||||
}),
|
||||
)
|
||||
@ -564,6 +583,16 @@ tf_cuda_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ops_hdrs",
|
||||
hdrs = ["ops.h"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tests
|
||||
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/platform.h" // NOLINT
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
|
||||
#include "tensorflow/cc/framework/gradients.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/framework/scope_internal.h"
|
||||
@ -2606,4 +2607,14 @@ void TF_RegisterLogListener(void (*listener)(const char*)) {
|
||||
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
}
|
||||
|
||||
void TF_RegisterFilesystemPlugin(const char* plugin_filename,
|
||||
TF_Status* status) {
|
||||
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
|
||||
status->status = tensorflow::errors::Unimplemented(
|
||||
"FileSystem plugin functionality is not supported on mobile");
|
||||
#else
|
||||
status->status = tensorflow::RegisterFilesystemPlugin(plugin_filename);
|
||||
#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
|
||||
}
|
||||
|
||||
} // end extern "C"
|
||||
|
@ -1205,7 +1205,7 @@ typedef struct TF_Session TF_Session;
|
||||
// Return a new execution session with the associated graph, or NULL on
|
||||
// error. Does not take ownership of any input parameters.
|
||||
//
|
||||
// *`graph` must be a valid graph (not deleted or nullptr). `graph` will be be
|
||||
// *`graph` must be a valid graph (not deleted or nullptr). `graph` will be
|
||||
// kept alive for the lifetime of the returned TF_Session. New nodes can still
|
||||
// be added to `graph` after this call.
|
||||
TF_CAPI_EXPORT extern TF_Session* TF_NewSession(TF_Graph* graph,
|
||||
@ -1577,6 +1577,13 @@ TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server);
|
||||
TF_CAPI_EXPORT extern void TF_RegisterLogListener(
|
||||
void (*listener)(const char*));
|
||||
|
||||
// Register a FileSystem plugin from filename `plugin_filename`.
|
||||
//
|
||||
// On success, place OK in status.
|
||||
// On failure, place an error status in status.
|
||||
TF_CAPI_EXPORT extern void TF_RegisterFilesystemPlugin(
|
||||
const char* plugin_filename, TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -561,15 +561,15 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
|
||||
collective_executor_handle->get()->StartAbort(status->status);
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
|
||||
const char* task,
|
||||
TF_Status* status) {
|
||||
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
|
||||
TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
auto collective_executor_handle = context->GetCollectiveExecutorHandle();
|
||||
tensorflow::Notification done;
|
||||
collective_executor_handle->get()->remote_access()->CheckPeerHealth(
|
||||
task, [&done, status](const Status& s) {
|
||||
task, timeout_in_ms, [&done, status](const Status& s) {
|
||||
status->status = s;
|
||||
done.Notify();
|
||||
});
|
||||
|
@ -86,7 +86,7 @@ TF_CAPI_EXPORT void TF_SetXlaConstantFoldingDisabled(
|
||||
|
||||
// Create a serialized tensorflow.ConfigProto proto, where:
|
||||
//
|
||||
// a) ConfigProto.optimizer_options.global_jit_level is set to to ON_1 if
|
||||
// a) ConfigProto.optimizer_options.global_jit_level is set to ON_1 if
|
||||
// `enable_xla_compilation` is non-zero, and OFF otherwise.
|
||||
// b) ConfigProto.gpu_options.allow_growth is set to `gpu_memory_allow_growth`.
|
||||
// c) ConfigProto.device_count is set to `num_cpu_devices`.
|
||||
@ -241,9 +241,9 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
|
||||
// Checks the health of collective ops peers. Explicit health check is needed in
|
||||
// multi worker collective ops to detect failures in the cluster. If a peer is
|
||||
// down, collective ops may hang.
|
||||
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
|
||||
const char* task,
|
||||
TF_Status* status);
|
||||
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
|
||||
TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
|
||||
TF_Status* status);
|
||||
|
||||
// Information about the shape of a Tensor and its type.
|
||||
struct TF_ShapeAndType {
|
||||
|
@ -10,6 +10,9 @@ load(
|
||||
"tf_cuda_library",
|
||||
)
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "filegroup")
|
||||
|
||||
@ -48,6 +51,7 @@ tf_cuda_library(
|
||||
":immediate_execution_context",
|
||||
":immediate_execution_operation",
|
||||
":immediate_execution_tensor_handle",
|
||||
":immediate_execution_distributed_manager",
|
||||
":abstract_tensor_handle",
|
||||
":tfe_context_internal",
|
||||
":tfe_cancellation_manager_internal",
|
||||
@ -67,6 +71,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:context_distributed_manager",
|
||||
"//tensorflow/core/common_runtime/eager:core",
|
||||
"//tensorflow/core/common_runtime/eager:eager_executor",
|
||||
"//tensorflow/core/common_runtime/eager:execute",
|
||||
@ -94,6 +99,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core/distributed_runtime:remote_device",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core/distributed_runtime:worker_interface",
|
||||
"//tensorflow/core:gpu_runtime",
|
||||
] + internal_tfrt_deps(),
|
||||
alwayslink = 1,
|
||||
@ -106,6 +112,7 @@ filegroup(
|
||||
"abstract_function.h",
|
||||
"abstract_operation.h",
|
||||
"abstract_tensor_handle.h",
|
||||
"c_api.h",
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
"c_api_unified_experimental.h",
|
||||
@ -114,6 +121,7 @@ filegroup(
|
||||
"gradients.h",
|
||||
"gradients_internal.h",
|
||||
"immediate_execution_context.h",
|
||||
"immediate_execution_distributed_manager.h",
|
||||
"immediate_execution_operation.h",
|
||||
"immediate_execution_tensor_handle.h",
|
||||
"tape.h",
|
||||
@ -171,6 +179,7 @@ cc_library(
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/platform:casts",
|
||||
"//tensorflow/core/platform:types",
|
||||
],
|
||||
@ -219,6 +228,34 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "unified_api_testutil",
|
||||
testonly = 1,
|
||||
srcs = [
|
||||
"unified_api_testutil.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"unified_api_testutil.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":abstract_context",
|
||||
":abstract_tensor_handle",
|
||||
":c_api_experimental",
|
||||
":c_api_test_util",
|
||||
":c_api_unified_internal",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core/platform:status",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "gradients_test",
|
||||
size = "small",
|
||||
@ -235,6 +272,7 @@ tf_cuda_cc_test(
|
||||
":c_api_test_util",
|
||||
":c_api_unified_internal",
|
||||
":gradients_internal",
|
||||
":unified_api_testutil",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
@ -255,6 +293,29 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "unified_api_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"unified_api_test.cc",
|
||||
],
|
||||
args = ["--heap_check=local"],
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156
|
||||
deps = [
|
||||
":c_api_experimental",
|
||||
":c_api_unified_internal",
|
||||
":unified_api_testutil",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"//tensorflow/core/platform:errors",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gradients_util",
|
||||
srcs = [
|
||||
@ -408,6 +469,7 @@ tf_cuda_cc_test(
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags() + [
|
||||
"nomac",
|
||||
"no_cuda_asan", # b/173825513
|
||||
],
|
||||
deps = [
|
||||
":abstract_tensor_handle",
|
||||
@ -444,8 +506,10 @@ cc_library(
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:refcount",
|
||||
"//tensorflow/core/platform:status",
|
||||
],
|
||||
)
|
||||
|
||||
@ -524,6 +588,19 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "immediate_execution_distributed_manager",
|
||||
hdrs = ["immediate_execution_distributed_manager.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "immediate_execution_context",
|
||||
hdrs = ["immediate_execution_context.h"],
|
||||
@ -532,12 +609,14 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":abstract_context",
|
||||
":immediate_execution_distributed_manager",
|
||||
":immediate_execution_operation",
|
||||
":immediate_execution_tensor_handle",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
@ -638,6 +717,19 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_header_only_library(
|
||||
name = "tfe_tensorhandle_internal_hdrs_only",
|
||||
extra_deps = [
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":tfe_tensorhandle_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_test_util",
|
||||
testonly = 1,
|
||||
|
@ -32,7 +32,7 @@ namespace tensorflow {
|
||||
// environment, a traced representation etc.
|
||||
class AbstractContext {
|
||||
protected:
|
||||
enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape };
|
||||
enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape, kOpHandler };
|
||||
explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
|
||||
virtual ~AbstractContext() {}
|
||||
|
||||
|
@ -30,7 +30,14 @@ namespace tensorflow {
|
||||
// tracing or immediate execution mode.
|
||||
class AbstractOperation {
|
||||
protected:
|
||||
enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt, kTape };
|
||||
enum AbstractOperationKind {
|
||||
kGraph,
|
||||
kMlir,
|
||||
kEager,
|
||||
kTfrt,
|
||||
kTape,
|
||||
kOpHandler
|
||||
};
|
||||
explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
|
||||
virtual ~AbstractOperation() {}
|
||||
|
||||
|
@ -17,8 +17,10 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/refcount.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to a Tensor handle in either tracing or immediate
|
||||
@ -32,6 +34,9 @@ class AbstractTensorHandle : public core::RefCounted {
|
||||
public:
|
||||
// Returns tensor dtype.
|
||||
virtual tensorflow::DataType DataType() const = 0;
|
||||
// Returns tensor shape. If tensor has unknown rank, shape remains untouched.
|
||||
virtual tensorflow::Status Shape(
|
||||
tensorflow::PartialTensorShape* shape) const = 0;
|
||||
|
||||
AbstractTensorHandleKind getKind() const { return kind_; }
|
||||
|
||||
|
@ -21,16 +21,11 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
|
||||
// clang-format off
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
// clang-format on
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
@ -39,58 +34,40 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
|
||||
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
|
||||
#endif
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/device_filters.pb.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
#include "tensorflow/core/common_runtime/copy_tensor.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/execute.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
|
||||
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/remote_device.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/platform/blocking_counter.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/notification.h"
|
||||
#include "tensorflow/core/platform/random.h"
|
||||
#include "tensorflow/core/platform/refcount.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
// "tensorflow/core/platform/platform.h" must be included first before using
|
||||
// PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc.
|
||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
|
||||
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
|
||||
#include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed.h"
|
||||
#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
namespace {
|
||||
@ -99,611 +76,6 @@ string DeviceName(const tensorflow::Device* d) {
|
||||
return (d == nullptr) ? "cpu:0" : d->name();
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context,
|
||||
const tensorflow::ServerDef& server_def) {
|
||||
if (server_def.job_name() != context->HostCPU()->parsed_name().job) {
|
||||
return false;
|
||||
}
|
||||
return server_def.default_session_config().SerializeAsString() ==
|
||||
context->session_options().config.SerializeAsString();
|
||||
}
|
||||
|
||||
tensorflow::Status AddRemoteDevicesToMgr(
|
||||
const std::vector<string>& added_remote_workers,
|
||||
tensorflow::WorkerCacheInterface* worker_cache,
|
||||
tensorflow::DynamicDeviceMgr* remote_device_mgr) {
|
||||
std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
|
||||
tensorflow::mutex remote_devices_mu;
|
||||
int num_added_workers = added_remote_workers.size();
|
||||
tensorflow::BlockingCounter counter(num_added_workers);
|
||||
std::vector<tensorflow::Status> statuses(num_added_workers);
|
||||
for (int i = 0; i < num_added_workers; i++) {
|
||||
tensorflow::NewRemoteDevices(
|
||||
tensorflow::Env::Default(), worker_cache, added_remote_workers[i],
|
||||
[i, &statuses, &counter, &remote_devices, &remote_devices_mu](
|
||||
const tensorflow::Status& s,
|
||||
std::vector<tensorflow::Device*>* devices) {
|
||||
statuses[i] = s;
|
||||
if (s.ok()) {
|
||||
tensorflow::mutex_lock l(remote_devices_mu);
|
||||
for (tensorflow::Device* d : *devices) {
|
||||
remote_devices.emplace_back(d);
|
||||
}
|
||||
}
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
for (int i = 0; i < num_added_workers; i++) {
|
||||
TF_RETURN_IF_ERROR(statuses[i]);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices)));
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status GetAllRemoteDevices(
|
||||
const std::vector<string>& remote_workers,
|
||||
tensorflow::WorkerCacheInterface* worker_cache,
|
||||
std::unique_ptr<tensorflow::DynamicDeviceMgr>* device_mgr) {
|
||||
auto remote_device_mgr = absl::make_unique<tensorflow::DynamicDeviceMgr>();
|
||||
TF_RETURN_IF_ERROR(AddRemoteDevicesToMgr(remote_workers, worker_cache,
|
||||
remote_device_mgr.get()));
|
||||
*device_mgr = std::move(remote_device_mgr);
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status RemoveRemoteDevicesFromMgr(
|
||||
const std::vector<string>& removed_remote_workers,
|
||||
tensorflow::DynamicDeviceMgr* remote_device_mgr) {
|
||||
const std::vector<tensorflow::Device*> remote_devices =
|
||||
(remote_device_mgr->ListDevices());
|
||||
std::vector<tensorflow::Device*> devices_to_remove;
|
||||
for (tensorflow::Device* d : remote_devices) {
|
||||
for (const string& remote_worker : removed_remote_workers) {
|
||||
if (tensorflow::DeviceNameUtils::IsSameAddressSpace(remote_worker,
|
||||
d->name())) {
|
||||
devices_to_remove.emplace_back(d);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(remote_device_mgr->RemoveDevices(devices_to_remove));
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status ListRemoteWorkers(tensorflow::ServerInterface* server,
|
||||
const string& local_worker,
|
||||
std::vector<string>* remote_workers) {
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
dynamic_cast<tensorflow::GrpcServer*>(server);
|
||||
if (grpc_server == nullptr) {
|
||||
return tensorflow::errors::Internal(
|
||||
"Currently, TFE_NewContext only supports tensorflow::GrpcServer.");
|
||||
}
|
||||
grpc_server->master_env()->worker_cache->ListWorkers(remote_workers);
|
||||
remote_workers->erase(
|
||||
std::remove(remote_workers->begin(), remote_workers->end(), local_worker),
|
||||
remote_workers->end());
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
void DifferentiateWorkerLists(const std::vector<string>* current_list,
|
||||
const std::vector<string>* new_list,
|
||||
std::vector<string>* added,
|
||||
std::vector<string>* removed,
|
||||
std::vector<string>* existing) {
|
||||
// Get STL set_difference and set_intersection with one list traversal.
|
||||
// Similar to the set_difference library function, the input lists
|
||||
// (`current_list` and `new_list`) must be sorted before calling the function.
|
||||
added->resize(new_list->size());
|
||||
removed->resize(current_list->size());
|
||||
existing->resize(current_list->size());
|
||||
std::vector<string>::const_iterator curr_it = current_list->begin();
|
||||
std::vector<string>::const_iterator new_it = new_list->begin();
|
||||
std::vector<string>::iterator added_it = added->begin();
|
||||
std::vector<string>::iterator removed_it = removed->begin();
|
||||
std::vector<string>::iterator existing_it = existing->begin();
|
||||
while (curr_it != current_list->end() && new_it != new_list->end()) {
|
||||
if (*curr_it < *new_it) {
|
||||
*removed_it++ = *curr_it++;
|
||||
} else if (*curr_it > *new_it) {
|
||||
*added_it++ = *new_it++;
|
||||
} else {
|
||||
*existing_it++ = *curr_it++;
|
||||
new_it++;
|
||||
}
|
||||
}
|
||||
removed_it = std::copy(curr_it, current_list->end(), removed_it);
|
||||
added_it = std::copy(new_it, new_list->end(), added_it);
|
||||
added->resize(added_it - added->begin());
|
||||
removed->resize(removed_it - removed->begin());
|
||||
existing->resize(existing_it - existing->begin());
|
||||
}
|
||||
|
||||
tensorflow::Status GetReplacedFromExistingWorkers(
|
||||
const std::vector<string>* existing_workers, tensorflow::uint64 context_id,
|
||||
tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
|
||||
tensorflow::eager::EagerClientCache* client_cache,
|
||||
std::vector<string>* replaced_workers) {
|
||||
tensorflow::BlockingCounter counter(existing_workers->size());
|
||||
std::vector<tensorflow::Status> statuses(existing_workers->size());
|
||||
tensorflow::eager::KeepAliveRequest request;
|
||||
request.set_context_id(context_id);
|
||||
std::vector<tensorflow::eager::KeepAliveResponse> responses(
|
||||
existing_workers->size());
|
||||
for (int i = 0; i < existing_workers->size(); i++) {
|
||||
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
|
||||
statuses[i] =
|
||||
client_cache->GetClient(existing_workers->at(i), &eager_client);
|
||||
if (!statuses[i].ok()) {
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
eager_client->KeepAliveAsync(
|
||||
&request, &responses[i],
|
||||
[i, &statuses, &counter](const tensorflow::Status& s) {
|
||||
statuses[i] = s;
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
for (int i = 0; i < existing_workers->size(); i++) {
|
||||
// If the RPC fails (indicating that the requested ID doesn't exist on
|
||||
// remote), or the returned view ID is not equal to the local one
|
||||
// (indicating that the remote worker has a stale view of cluster), treat
|
||||
// the worker as replaced.
|
||||
if (!statuses[i].ok() ||
|
||||
responses[i].context_view_id() != context_view_id) {
|
||||
replaced_workers->emplace_back(existing_workers->at(i));
|
||||
}
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status CreateRemoteContexts(
|
||||
TFE_Context* ctx, const std::vector<string>& remote_workers,
|
||||
tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
|
||||
int keep_alive_secs, const tensorflow::ServerDef& server_def,
|
||||
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
|
||||
const bool lazy_copy_remote_function_inputs,
|
||||
const tensorflow::eager::CreateContextRequest& base_request) {
|
||||
int num_remote_workers = remote_workers.size();
|
||||
tensorflow::BlockingCounter counter(num_remote_workers);
|
||||
std::vector<tensorflow::Status> statuses(num_remote_workers);
|
||||
for (int i = 0; i < num_remote_workers; i++) {
|
||||
const string& remote_worker = remote_workers[i];
|
||||
tensorflow::DeviceNameUtils::ParsedName parsed_name;
|
||||
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
|
||||
&parsed_name)) {
|
||||
statuses[i] = tensorflow::errors::InvalidArgument(
|
||||
"Unable to parse ", remote_worker, " as a device name");
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
|
||||
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
|
||||
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
|
||||
if (eager_client == nullptr) {
|
||||
statuses[i] = tensorflow::errors::Internal(
|
||||
"Cannot find a client for the given target:", remote_worker);
|
||||
}
|
||||
if (!statuses[i].ok()) {
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
|
||||
tensorflow::eager::CreateContextRequest request;
|
||||
tensorflow::eager::CreateContextResponse* response =
|
||||
new tensorflow::eager::CreateContextResponse();
|
||||
request.set_context_id(context_id);
|
||||
request.set_context_view_id(context_view_id);
|
||||
*request.mutable_server_def() = server_def;
|
||||
request.mutable_server_def()->set_job_name(parsed_name.job);
|
||||
request.mutable_server_def()->set_task_index(parsed_name.task);
|
||||
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
|
||||
server_def.default_session_config());
|
||||
|
||||
std::vector<bool> filtered_device_mask;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->FilterDevicesForRemoteWorkers(
|
||||
remote_worker, base_request.cluster_device_attributes(),
|
||||
&filtered_device_mask);
|
||||
DCHECK_EQ(filtered_device_mask.size(),
|
||||
base_request.cluster_device_attributes_size());
|
||||
for (int i = 0; i < filtered_device_mask.size(); i++) {
|
||||
if (filtered_device_mask[i]) {
|
||||
const auto& da = base_request.cluster_device_attributes(i);
|
||||
*request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
}
|
||||
request.set_async(async);
|
||||
request.set_keep_alive_secs(keep_alive_secs);
|
||||
request.set_lazy_copy_remote_function_inputs(
|
||||
lazy_copy_remote_function_inputs);
|
||||
|
||||
eager_client->CreateContextAsync(
|
||||
&request, response,
|
||||
[i, &statuses, &counter, response](const tensorflow::Status& s) {
|
||||
statuses[i] = s;
|
||||
delete response;
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
tensorflow::StatusGroup sg;
|
||||
for (int i = 0; i < num_remote_workers; i++) {
|
||||
if (TF_PREDICT_FALSE(!statuses[i].ok())) {
|
||||
sg.Update(statuses[i]);
|
||||
}
|
||||
}
|
||||
return sg.as_summary_status();
|
||||
}
|
||||
|
||||
tensorflow::Status UpdateRemoteContexts(
|
||||
TFE_Context* ctx, const std::vector<string>& remote_workers,
|
||||
const std::vector<string>& added_workers,
|
||||
const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
|
||||
tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
|
||||
tensorflow::eager::EagerClientCache* remote_eager_workers,
|
||||
const tensorflow::eager::CreateContextRequest& base_request) {
|
||||
int num_remote_workers = remote_workers.size();
|
||||
tensorflow::BlockingCounter counter(num_remote_workers);
|
||||
std::vector<tensorflow::Status> statuses(num_remote_workers);
|
||||
|
||||
int cluster_device_count = base_request.cluster_device_attributes_size();
|
||||
std::unordered_set<string> added_or_removed(added_workers.begin(),
|
||||
added_workers.end());
|
||||
std::copy(removed_workers.begin(), removed_workers.end(),
|
||||
std::inserter(added_or_removed, added_or_removed.end()));
|
||||
// Whether each device is in the updated (added or removed) workers
|
||||
std::vector<bool> device_added_or_removed(cluster_device_count);
|
||||
for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
|
||||
const auto& da = base_request.cluster_device_attributes().at(i);
|
||||
tensorflow::DeviceNameUtils::ParsedName pn;
|
||||
tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
|
||||
string task_name;
|
||||
tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
|
||||
if (added_or_removed.find(task_name) != added_or_removed.end()) {
|
||||
device_added_or_removed[i] = true;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_remote_workers; i++) {
|
||||
const string& remote_worker = remote_workers[i];
|
||||
tensorflow::DeviceNameUtils::ParsedName parsed_name;
|
||||
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
|
||||
&parsed_name)) {
|
||||
statuses[i] = tensorflow::errors::InvalidArgument(
|
||||
"Unable to parse ", remote_worker, " as a device name");
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
|
||||
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
|
||||
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
|
||||
if (eager_client == nullptr) {
|
||||
statuses[i] = tensorflow::errors::Internal(
|
||||
"Cannot find a client for the given target:", remote_worker);
|
||||
}
|
||||
if (!statuses[i].ok()) {
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<bool> filtered_device_mask;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->FilterDevicesForRemoteWorkers(
|
||||
remote_worker, base_request.cluster_device_attributes(),
|
||||
&filtered_device_mask);
|
||||
DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
|
||||
|
||||
// If any of the devices that match the device filters are in the set of
|
||||
// added or removed workers, we must send a complete UpdateContextRequest.
|
||||
// Otherwise, only send a simple request to increment context view ID.
|
||||
std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
|
||||
std::transform(device_added_or_removed.begin(),
|
||||
device_added_or_removed.end(), filtered_device_mask.begin(),
|
||||
added_or_removed_filtered_devices.begin(),
|
||||
std::logical_and<bool>());
|
||||
const bool full_update_request =
|
||||
std::accumulate(added_or_removed_filtered_devices.begin(),
|
||||
added_or_removed_filtered_devices.end(), false,
|
||||
std::logical_or<bool>());
|
||||
|
||||
tensorflow::eager::UpdateContextRequest request;
|
||||
auto* response = new tensorflow::eager::UpdateContextResponse();
|
||||
request.set_context_id(context_id);
|
||||
request.set_context_view_id(context_view_id);
|
||||
if (full_update_request) {
|
||||
*request.mutable_server_def() = server_def;
|
||||
request.mutable_server_def()->set_job_name(parsed_name.job);
|
||||
request.mutable_server_def()->set_task_index(parsed_name.task);
|
||||
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
|
||||
server_def.default_session_config());
|
||||
for (int i = 0; i < cluster_device_count; i++) {
|
||||
if (filtered_device_mask[i]) {
|
||||
const auto& da = base_request.cluster_device_attributes(i);
|
||||
*request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eager_client->UpdateContextAsync(
|
||||
&request, response,
|
||||
[i, &statuses, &counter, response](const tensorflow::Status& s) {
|
||||
statuses[i] = s;
|
||||
delete response;
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
for (int i = 0; i < num_remote_workers; i++) {
|
||||
TF_RETURN_IF_ERROR(statuses[i]);
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
int keep_alive_secs, const tensorflow::ServerDef& server_def,
|
||||
TFE_Context* ctx, bool reset_context) {
|
||||
// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
|
||||
// server object (which currently CHECK-fails) and we miss the error, instead,
|
||||
// we log the error, and then return to allow the user to see the error
|
||||
// message.
|
||||
#define LOG_AND_RETURN_IF_ERROR(...) \
|
||||
do { \
|
||||
const ::tensorflow::Status _status = (__VA_ARGS__); \
|
||||
if (TF_PREDICT_FALSE(!_status.ok())) { \
|
||||
LOG(ERROR) << _status.error_message(); \
|
||||
return _status; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
string worker_name =
|
||||
tensorflow::strings::StrCat("/job:", server_def.job_name(),
|
||||
"/replica:0/task:", server_def.task_index());
|
||||
|
||||
// List of current remote workers before updating server_def. Unused if
|
||||
// resetting the server_def.
|
||||
std::vector<string> curr_remote_workers;
|
||||
// List of updated remote workers.
|
||||
std::vector<string> remote_workers;
|
||||
|
||||
// New server created for new server_def. Unused if updating server_def.
|
||||
std::unique_ptr<tensorflow::ServerInterface> new_server;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server;
|
||||
if (reset_context) {
|
||||
const tensorflow::DeviceMgr* device_mgr =
|
||||
AreLocalDevicesCompatible(context, server_def)
|
||||
? context->local_device_mgr()
|
||||
: nullptr;
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
|
||||
server_def, {device_mgr}, &new_server));
|
||||
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
|
||||
} else {
|
||||
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
|
||||
&curr_remote_workers));
|
||||
// No need to check the cast here, since `ListRemoteWorkers` already checks
|
||||
// if the server is a GRPC server or not.
|
||||
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
|
||||
}
|
||||
|
||||
tensorflow::uint64 context_id = context->GetContextId();
|
||||
tensorflow::uint64 context_view_id = context->GetContextViewId();
|
||||
if (reset_context) {
|
||||
context_id = tensorflow::EagerContext::NewContextId();
|
||||
context_view_id = 0;
|
||||
// Make master eager context accessible by local eager service, which might
|
||||
// receive send tensor requests from remote workers.
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->AddMasterEagerContextToEagerService(context_id, context));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->master_env()->worker_cache->GetEagerClientCache(
|
||||
&remote_eager_workers));
|
||||
|
||||
// For cluster update, use a status group to aggregate statuses from
|
||||
// * adding and removing remote devices
|
||||
// * creating remote contexts on newly added workers
|
||||
// * updating remote contexts on existing workers
|
||||
// * updating the master context
|
||||
// Note that we should not return immediately on errors in the middle of these
|
||||
// updates to prevent cluster from having inconsistent context views.
|
||||
//
|
||||
// Unused if `reset_context` is True.
|
||||
tensorflow::StatusGroup sg;
|
||||
|
||||
// When updating an existing context, populate the following lists with:
|
||||
// * added_workers: set(remote_workers) - set(curr_remote_workers)
|
||||
// * removed_workers: set(curr_remote_workers) - set(remote_workers)
|
||||
// * existing_workers: set(curr_remote_workers) intersect set(remote_workers)
|
||||
// * replaced_workers: workers with the same task names and potentially the
|
||||
// same `hostname:port`s, but replaced by different processes
|
||||
std::vector<string> added_workers;
|
||||
std::vector<string> removed_workers;
|
||||
std::vector<string> existing_workers;
|
||||
std::vector<string> replaced_workers;
|
||||
|
||||
// New remote device manager created for new server_def. Unused if updating
|
||||
// server_def.
|
||||
std::unique_ptr<tensorflow::DynamicDeviceMgr> new_remote_device_mgr;
|
||||
tensorflow::DynamicDeviceMgr* remote_device_mgr = nullptr;
|
||||
if (reset_context) {
|
||||
LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
|
||||
remote_workers, grpc_server->master_env()->worker_cache,
|
||||
&new_remote_device_mgr));
|
||||
remote_device_mgr = new_remote_device_mgr.get();
|
||||
} else {
|
||||
context->ClearCachesAndDefaultExecutor();
|
||||
// TODO(b/143914772): Potential memory leak if rendezvous has pending
|
||||
// tensors for removed / replaced workers.
|
||||
|
||||
remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
|
||||
if (remote_device_mgr == nullptr) {
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
|
||||
"Updating context with an invalid set of remote devices."));
|
||||
}
|
||||
std::sort(curr_remote_workers.begin(), curr_remote_workers.end());
|
||||
std::sort(remote_workers.begin(), remote_workers.end());
|
||||
DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
|
||||
&added_workers, &removed_workers,
|
||||
&existing_workers);
|
||||
sg.Update(GetReplacedFromExistingWorkers(
|
||||
&existing_workers, context_id, context->GetContextViewId(), server_def,
|
||||
remote_eager_workers.get(), &replaced_workers));
|
||||
if (VLOG_IS_ON(1)) {
|
||||
VLOG(1) << "Updating cluster with following changes";
|
||||
for (const string& w : added_workers) VLOG(1) << " Added worker " << w;
|
||||
for (const string& w : removed_workers)
|
||||
VLOG(1) << " Removed worker " << w;
|
||||
for (const string& w : replaced_workers)
|
||||
VLOG(1) << " Replaced worker " << w;
|
||||
}
|
||||
if (!replaced_workers.empty()) {
|
||||
// Treat replaced workers as removed then added back, so that we recreate
|
||||
// remote devices and contexts, and re-register functions on those workers
|
||||
removed_workers.insert(removed_workers.end(), replaced_workers.begin(),
|
||||
replaced_workers.end());
|
||||
added_workers.insert(added_workers.end(), replaced_workers.begin(),
|
||||
replaced_workers.end());
|
||||
for (const string& w : replaced_workers) {
|
||||
existing_workers.erase(
|
||||
std::remove(existing_workers.begin(), existing_workers.end(), w),
|
||||
existing_workers.end());
|
||||
}
|
||||
}
|
||||
sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
|
||||
sg.Update(AddRemoteDevicesToMgr(added_workers,
|
||||
grpc_server->master_env()->worker_cache,
|
||||
remote_device_mgr));
|
||||
}
|
||||
|
||||
std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
|
||||
remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);
|
||||
|
||||
std::vector<tensorflow::DeviceAttributes> local_device_attributes;
|
||||
grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
|
||||
&local_device_attributes);
|
||||
|
||||
// This request make sure that we can create Rendezvous properly between
|
||||
// Local and Remote context.
|
||||
tensorflow::eager::CreateContextRequest base_request;
|
||||
for (const auto& da : cluster_device_attributes) {
|
||||
*base_request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
for (const auto& da : local_device_attributes) {
|
||||
*base_request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
|
||||
// Initialize remote eager workers.
|
||||
if (reset_context) {
|
||||
const tensorflow::Status s = CreateRemoteContexts(
|
||||
ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
|
||||
server_def, remote_eager_workers.get(), context->Executor().Async(),
|
||||
context->LazyCopyFunctionRemoteInputs(), base_request);
|
||||
// NOTE: the remote tasks could fail after `GetAllRemoteDevices` and cause
|
||||
// the CreateRemoteContexts to fail. We currently only log instead of
|
||||
// directly returning the error, since returning here will cause the server
|
||||
// object to be destroyed (which currently CHECK-fails). The client will
|
||||
// see additional errors if ops are subsequently sent to the failed workers.
|
||||
if (TF_PREDICT_FALSE(!s.ok())) {
|
||||
LOG(ERROR) << "Error when creating contexts on remote targets: "
|
||||
<< s.error_message()
|
||||
<< "\nExecuting remote ops or functions on these remote "
|
||||
"targets will fail.";
|
||||
}
|
||||
} else {
|
||||
if (sg.ok()) {
|
||||
// Create remote contexts on the newly added workers only if the master
|
||||
// has collected all device information from them (i.e., the
|
||||
// GetAllRemoteDevices call returns succussfully). Note that in rare cases
|
||||
// GetAllRemoteDevices can still fail even with RPCs configured to wait
|
||||
// until the remote workers to become alive. If the master creates remote
|
||||
// contexts on the workers whose devices are still not collected, those
|
||||
// workers will be treated as existing workers subsequently, so the master
|
||||
// will never get devices from them even with retrying UpdateServerDef.
|
||||
sg.Update(CreateRemoteContexts(
|
||||
ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
|
||||
server_def, remote_eager_workers.get(), context->Executor().Async(),
|
||||
context->LazyCopyFunctionRemoteInputs(), base_request));
|
||||
}
|
||||
if (!existing_workers.empty()) {
|
||||
if (VLOG_IS_ON(1)) {
|
||||
for (const string& w : existing_workers) {
|
||||
VLOG(1) << "Updating cluster with existing worker " << w;
|
||||
}
|
||||
}
|
||||
// The master's context_view_id will be incremented by one in the
|
||||
// UpdateRemoteMaster call later. We want existing workers to also have
|
||||
// the updated context_view_id, so we must set their context_view_id to
|
||||
// the master's current context_view_id + 1.
|
||||
sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers,
|
||||
removed_workers, context_id,
|
||||
context_view_id + 1, server_def,
|
||||
remote_eager_workers.get(), base_request));
|
||||
}
|
||||
}
|
||||
|
||||
auto session_name = tensorflow::strings::StrCat("eager_", context_id);
|
||||
if (reset_context) {
|
||||
tensorflow::RemoteRendezvous* r =
|
||||
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
|
||||
auto* device_mgr = grpc_server->worker_env()->device_mgr;
|
||||
std::shared_ptr<tensorflow::WorkerSession> worker_session;
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->worker_env()->session_mgr->CreateSession(
|
||||
session_name, server_def, base_request.cluster_device_attributes(),
|
||||
true));
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
|
||||
session_name, &worker_session));
|
||||
|
||||
// Initialize remote tensor communication based on worker session.
|
||||
LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
|
||||
|
||||
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
|
||||
tensorflow::eager::CreateClusterFLR(context_id, context,
|
||||
worker_session.get());
|
||||
auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
|
||||
/*is_master=*/true, context);
|
||||
|
||||
LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
|
||||
std::move(new_server), grpc_server->worker_env(), worker_session,
|
||||
std::move(remote_eager_workers), std::move(new_remote_device_mgr),
|
||||
remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
|
||||
std::move(remote_mgr)));
|
||||
|
||||
// NOTE: We start the server after all other initialization, because the
|
||||
// GrpcServer cannot be destroyed after it is started.
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
|
||||
} else {
|
||||
sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
|
||||
session_name, server_def, base_request.cluster_device_attributes(),
|
||||
/*isolate_session_state=*/true));
|
||||
sg.Update(context->UpdateRemoteMaster(context_id,
|
||||
std::move(remote_eager_workers),
|
||||
added_workers, removed_workers));
|
||||
LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
|
||||
}
|
||||
#undef LOG_AND_RETURN_IF_ERROR
|
||||
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
@ -730,11 +102,21 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
|
||||
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
if (opts->use_tfrt) {
|
||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
|
||||
return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
|
||||
tfrt::tf::ContextInterface* tfrt_context = new tfrt::tf::ContextInterface(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
opts->async);
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
tfrt_context->SetDistributedManager(
|
||||
std::make_unique<tfrt::tf::DistributedManagerContextInterface>(
|
||||
tfrt_context->GetCoreRuntime()->GetHostContext()));
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
return tensorflow::wrap(tfrt_context);
|
||||
#else
|
||||
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
|
||||
return nullptr;
|
||||
#endif
|
||||
#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
|
||||
}
|
||||
std::vector<std::unique_ptr<tensorflow::Device>> devices;
|
||||
status->status = tensorflow::DeviceFactory::AddDevices(
|
||||
@ -746,13 +128,18 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||
|
||||
return tensorflow::wrap(new tensorflow::EagerContext(
|
||||
tensorflow::EagerContext* eager_context = new tensorflow::EagerContext(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
|
||||
/*device_mgr_owned*/ true, r));
|
||||
/*device_mgr_owned*/ true, r);
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
eager_context->SetDistributedManager(
|
||||
std::make_unique<tensorflow::EagerContextDistributedManager>(
|
||||
eager_context));
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
return tensorflow::wrap(eager_context);
|
||||
}
|
||||
|
||||
void TFE_DeleteContext(TFE_Context* ctx) {
|
||||
@ -790,26 +177,9 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
"Invalid tensorflow.ServerDef protocol buffer");
|
||||
return;
|
||||
}
|
||||
if (server_def.has_cluster_device_filters()) {
|
||||
const auto& cdf = server_def.cluster_device_filters();
|
||||
for (const auto& jdf : cdf.jobs()) {
|
||||
const string remote_prefix = "/job:" + jdf.name() + "/task:";
|
||||
for (const auto& tdf : jdf.tasks()) {
|
||||
const int32_t task_index = tdf.first;
|
||||
std::vector<string> device_filters(tdf.second.device_filters_size());
|
||||
for (int i = 0; i < tdf.second.device_filters_size(); i++) {
|
||||
device_filters[i] = tdf.second.device_filters(i);
|
||||
}
|
||||
const string remote_worker = remote_prefix + std::to_string(task_index);
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status =
|
||||
context->SetRemoteDeviceFilters(remote_worker, device_filters);
|
||||
}
|
||||
}
|
||||
}
|
||||
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
|
||||
ctx, /*reset_context=*/true);
|
||||
status->status =
|
||||
tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
|
||||
server_def, /*reset_context=*/true, keep_alive_secs);
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
@ -834,14 +204,9 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Trying to update a context with invalid context id.");
|
||||
}
|
||||
if (server_def.has_cluster_device_filters()) {
|
||||
LOG(WARNING) << "Device filters can only be specified when initializing "
|
||||
"the cluster. Any changes in device filters are ignored "
|
||||
"when updating the server def.";
|
||||
}
|
||||
// TODO(haoyuzhang): Check server_def compatibility before the update
|
||||
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
|
||||
ctx, /*reset_context=*/false);
|
||||
status->status =
|
||||
tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
|
||||
server_def, /*reset_context=*/false, keep_alive_secs);
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
@ -853,43 +218,11 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
||||
"TFE_ContextSetServerDef not supported on mobile");
|
||||
return false;
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
// TODO(yuefengz): support partially specified `worker_name`.
|
||||
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
|
||||
status->status = context->GetClient(worker_name, &eager_client);
|
||||
if (!status->status.ok()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Send a rpc request to the worker to check aliveness.
|
||||
tensorflow::eager::KeepAliveRequest request;
|
||||
request.set_context_id(context->GetContextId());
|
||||
tensorflow::eager::KeepAliveResponse response;
|
||||
|
||||
tensorflow::Status keep_alive_status;
|
||||
tensorflow::Notification done;
|
||||
eager_client->KeepAliveAsync(
|
||||
&request, &response,
|
||||
[&keep_alive_status, &done](const tensorflow::Status& s) {
|
||||
keep_alive_status = s;
|
||||
done.Notify();
|
||||
});
|
||||
done.WaitForNotification();
|
||||
|
||||
status->status = tensorflow::Status::OK();
|
||||
|
||||
// If `context_id` doesn't exist on the remote worker, an InvalidArgument
|
||||
// error will return. But this still indicates that the remote worker is
|
||||
// alive.
|
||||
if (keep_alive_status.ok() ||
|
||||
keep_alive_status.code() == tensorflow::error::INVALID_ARGUMENT) {
|
||||
return true;
|
||||
} else {
|
||||
LOG(INFO) << "Remote worker " << worker_name
|
||||
<< " is not alive: " << keep_alive_status.error_message();
|
||||
return false;
|
||||
}
|
||||
bool is_alive;
|
||||
status->status =
|
||||
tensorflow::unwrap(ctx)->GetDistributedManager()->CheckRemoteAlive(
|
||||
worker_name, &is_alive);
|
||||
return is_alive;
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
@ -1445,13 +778,11 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
|
||||
|
||||
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->Executor().WaitForAllPendingNodes();
|
||||
auto* context = tensorflow::unwrap(ctx);
|
||||
status->status = context->AsyncWait();
|
||||
if (!status->status.ok()) return;
|
||||
tensorflow::mutex_lock ml(*context->MetadataMu());
|
||||
status->status = MessageToBuffer(*context->RunMetadataProto(), buf);
|
||||
context->ClearRunMetadata();
|
||||
auto run_metadata = context->ExportRunMetadata();
|
||||
status->status = MessageToBuffer(*run_metadata, buf);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -638,3 +638,19 @@ void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
||||
TF_Status* status) {
|
||||
tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
|
||||
}
|
||||
|
||||
const char* TFE_TensorHandleDeviceType(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
return tensorflow::unwrap(h)->DeviceType(&status->status);
|
||||
}
|
||||
|
||||
int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return -1;
|
||||
}
|
||||
return tensorflow::unwrap(h)->DeviceId(&status->status);
|
||||
}
|
||||
|
@ -481,7 +481,7 @@ typedef struct TFE_CustomDevice {
|
||||
// "/job:localhost/replica:0/task:0/device:CUSTOM:0".
|
||||
//
|
||||
// The custom device defines copy operations for moving TensorHandles on and
|
||||
// off, and an an execution operation for named operations. Often execution will
|
||||
// off, and an execution operation for named operations. Often execution will
|
||||
// simply wrap op execution on one or more physical devices.
|
||||
//
|
||||
// device_info is an opaque caller-defined type stored with the custom device
|
||||
@ -553,6 +553,14 @@ TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
|
||||
unsigned char enable,
|
||||
TF_Status* status);
|
||||
|
||||
// Returns the device type of the operation that produced `h`.
|
||||
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType(
|
||||
TFE_TensorHandle* h, TF_Status* status);
|
||||
|
||||
// Returns the device ID of the operation that produced `h`.
|
||||
TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h,
|
||||
TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -411,5 +411,109 @@ TEST(CAPI, TensorHandleOnDeviceMemory) {
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
TEST(CAPI, TensorHandleNullptr) {
|
||||
TFE_TensorHandle* h = nullptr;
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
|
||||
const char* device_type = TFE_TensorHandleDeviceType(h, status.get());
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
ASSERT_EQ(device_type, nullptr);
|
||||
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
|
||||
|
||||
TF_SetStatus(status.get(), TF_OK, "");
|
||||
|
||||
int device_id = TFE_TensorHandleDeviceID(h, status.get());
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
ASSERT_EQ(device_id, -1);
|
||||
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
|
||||
}
|
||||
|
||||
TEST(CAPI, TensorHandleDevices) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
|
||||
const char* device_type = TFE_TensorHandleDeviceType(hcpu, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
|
||||
int device_id = TFE_TensorHandleDeviceID(hcpu, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_EQ(0, device_id) << device_id;
|
||||
|
||||
// Disable the test if no GPU is present.
|
||||
string gpu_device_name;
|
||||
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
|
||||
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
|
||||
hcpu, ctx, gpu_device_name.c_str(), status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TFE_Op* shape_op = ShapeOp(ctx, hgpu);
|
||||
TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
device_type = TFE_TensorHandleDeviceType(retvals[0], status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type;
|
||||
|
||||
device_id = TFE_TensorHandleDeviceID(retvals[0], status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_EQ(0, device_id) << device_id;
|
||||
|
||||
TFE_DeleteOp(shape_op);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
TFE_DeleteTensorHandle(hgpu);
|
||||
}
|
||||
|
||||
TFE_DeleteTensorHandle(hcpu);
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST(CAPI, TensorHandleDefaults) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* h_default = TestMatrixTensorHandle(ctx);
|
||||
const char* device_type = TFE_TensorHandleDeviceType(h_default, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
|
||||
int device_id = TFE_TensorHandleDeviceID(h_default, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_EQ(0, device_id) << device_id;
|
||||
|
||||
TFE_TensorHandle* h_cpu = TFE_TensorHandleCopyToDevice(
|
||||
h_default, ctx, "/device:CPU:0", status.get());
|
||||
const char* device_type_cpu = TFE_TensorHandleDeviceType(h_cpu, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_TRUE(absl::StrContains(device_type_cpu, "CPU")) << device_type_cpu;
|
||||
int device_id_cpu = TFE_TensorHandleDeviceID(h_cpu, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_EQ(0, device_id_cpu) << device_id_cpu;
|
||||
|
||||
TFE_DeleteTensorHandle(h_default);
|
||||
TFE_DeleteTensorHandle(h_cpu);
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -16,9 +16,9 @@ limitations under the License.
|
||||
#define TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_
|
||||
|
||||
// Run a function containing a MatMul op and check its output.
|
||||
// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one
|
||||
// which creates a remote remote input, to simulate a scenario that the remote
|
||||
// input is not ready when we start running an op or a function.
|
||||
// If heavy_load_on_streaming_rpc is true, send some rpc requests before the one
|
||||
// which creates a remote input, to simulate a scenario that the remote input
|
||||
// is not ready when we start running an op or a function.
|
||||
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
|
||||
bool heavy_load_on_streaming_rpc,
|
||||
bool remote_func_outputs = false);
|
||||
|
@ -696,7 +696,7 @@ TEST(CAPI, ExecuteAddForwardAsync) {
|
||||
/*tfrt*/ false);
|
||||
}
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
// TODO(b/153349425): Add add forwarding tests for TFRT
|
||||
// TODO(b/153349425): Add forwarding tests for TFRT
|
||||
TEST(CAPI, ExecuteAddTfrt) {
|
||||
ExecuteAdd(
|
||||
/*async=*/false,
|
||||
@ -769,7 +769,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
||||
EXPECT_NE(TF_OK, TF_GetCode(status));
|
||||
EXPECT_EQ(nullptr, t);
|
||||
const char* msg = "Matrix size-incompatible: In[0]: [2,2], In[1]: [3,2]";
|
||||
const char* msg = "In[0] mismatch In[1] shape: 2 vs. 3: [2,2] [3,2]";
|
||||
EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr)
|
||||
<< TF_Message(status);
|
||||
// Since error is not cleared, the following copy with correct device will
|
||||
|
@ -134,7 +134,9 @@ TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
|
||||
}
|
||||
|
||||
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
|
||||
TF_DataType dtype, TF_Status* s) {
|
||||
TF_DataType dtype, TF_Shape shape,
|
||||
TF_Status* s) {
|
||||
DCHECK_GE(shape.num_dims, -1);
|
||||
TracingTensorHandle* t;
|
||||
TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(func));
|
||||
if (!tracing_ctx) {
|
||||
@ -143,8 +145,20 @@ TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
|
||||
"TF_AddFunctionParameter must be called on a TracingContext."));
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::PartialTensorShape partial_shape;
|
||||
if (shape.num_dims != -1) {
|
||||
DCHECK(shape.dim_sizes != nullptr);
|
||||
Status status = tensorflow::PartialTensorShape::MakePartialShape(
|
||||
reinterpret_cast<tensorflow::int64*>(shape.dim_sizes), shape.num_dims,
|
||||
&partial_shape);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(s, status);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
Set_TF_Status_from_Status(
|
||||
s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), &t));
|
||||
s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), partial_shape,
|
||||
&t));
|
||||
return wrap(t);
|
||||
}
|
||||
|
||||
|
@ -64,10 +64,16 @@ TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
|
||||
TF_Status* s);
|
||||
void TF_DeleteExecutionContext(TF_ExecutionContext*);
|
||||
|
||||
// Represents a (partially-defined) shape.
|
||||
typedef struct TF_Shape {
|
||||
int num_dims; // Must be >= -1; -1 represents unknown rank.
|
||||
int64_t* dim_sizes;
|
||||
} TF_Shape;
|
||||
|
||||
// Add a new parameter to a TensorFlow Function.
|
||||
// TODO(aminim): what about shape?
|
||||
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
|
||||
TF_DataType dtype, TF_Status* s);
|
||||
TF_DataType dtype, TF_Shape shape,
|
||||
TF_Status* s);
|
||||
|
||||
// Create an operation suitable to use with the provided context. The operation
|
||||
// requires its type (e.g. "AddV2") to be set independently.
|
||||
|
@ -25,6 +25,8 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
@ -43,22 +45,50 @@ class GraphContext;
|
||||
class GraphOperation;
|
||||
class GraphTensor;
|
||||
|
||||
auto& kUnknownDim = shape_inference::InferenceContext::kUnknownDim;
|
||||
auto& kUnknownRank = shape_inference::InferenceContext::kUnknownRank;
|
||||
|
||||
// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
|
||||
// into the list of outputs for the operation.
|
||||
class GraphTensor : public TracingTensorHandle {
|
||||
public:
|
||||
explicit GraphTensor(TF_Output output)
|
||||
: TracingTensorHandle(kGraph), output_(output) {}
|
||||
explicit GraphTensor(TF_Output output, TF_Graph* graph)
|
||||
: TracingTensorHandle(kGraph), output_(output), graph_(graph) {}
|
||||
|
||||
tensorflow::DataType DataType() const override {
|
||||
return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));
|
||||
}
|
||||
|
||||
tensorflow::Status Shape(
|
||||
tensorflow::PartialTensorShape* shape) const override {
|
||||
DCHECK(shape != nullptr);
|
||||
TF_Status status;
|
||||
int num_dims = TF_GraphGetTensorNumDims(graph_, output_, &status);
|
||||
DCHECK_GE(num_dims, -1);
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
|
||||
if (num_dims == kUnknownRank) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<int64> dims(num_dims, kUnknownDim);
|
||||
TF_GraphGetTensorShape(graph_, output_,
|
||||
reinterpret_cast<int64_t*>(dims.data()), num_dims,
|
||||
&status);
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
|
||||
TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TF_Output output_;
|
||||
|
||||
// For LLVM style RTTI.
|
||||
static bool classof(const AbstractTensorHandle* ptr) {
|
||||
return ptr->getKind() == kGraph;
|
||||
}
|
||||
|
||||
private:
|
||||
TF_Graph* graph_; // For shape inference.
|
||||
};
|
||||
|
||||
// GraphOperation wraps and populates a TF_OperationDescription.
|
||||
@ -135,7 +165,7 @@ class GraphOperation : public TracingOperation {
|
||||
TF_DeleteStatus(s);
|
||||
*num_retvals = TF_OperationNumOutputs(operation);
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals[i] = new GraphTensor({operation, i});
|
||||
retvals[i] = new GraphTensor({operation, i}, g_);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -326,12 +356,18 @@ class GraphContext : public TracingContext {
|
||||
return new GraphOperation(graph_.get());
|
||||
}
|
||||
|
||||
Status AddParameter(DataType dtype, TracingTensorHandle** output) override {
|
||||
Status AddParameter(DataType dtype, const PartialTensorShape& shape,
|
||||
TracingTensorHandle** output) override {
|
||||
TracingOperationPtr operation(CreateOperation());
|
||||
TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr));
|
||||
TF_RETURN_IF_ERROR(
|
||||
operation->SetOpName(absl::StrCat("_input_", inputs_.size()).c_str()));
|
||||
TF_RETURN_IF_ERROR(operation->SetAttrType("dtype", dtype));
|
||||
if (!shape.unknown_rank()) {
|
||||
TF_RETURN_IF_ERROR(operation->SetAttrShape(
|
||||
"shape", reinterpret_cast<int64_t*>(shape.dim_sizes().data()),
|
||||
shape.dims()));
|
||||
}
|
||||
int num_outputs = 1;
|
||||
std::vector<AbstractTensorHandle*> outputs(num_outputs);
|
||||
TF_RETURN_IF_ERROR(operation->Execute(
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
@ -107,7 +108,8 @@ class TracingContext : public AbstractContext {
|
||||
|
||||
public:
|
||||
// Add a function parameter and return the corresponding tensor.
|
||||
virtual Status AddParameter(DataType dtype, TracingTensorHandle**) = 0;
|
||||
virtual Status AddParameter(DataType dtype, const PartialTensorShape& shape,
|
||||
TracingTensorHandle**) = 0;
|
||||
|
||||
// Finalize this context and make a function out of it. The context is in a
|
||||
// invalid state after this call and must be destroyed.
|
||||
|
@ -359,7 +359,7 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
auto* placeholder_t =
|
||||
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
|
||||
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract operation.
|
||||
@ -450,7 +450,7 @@ TEST_P(UnifiedCAPI, TestBasicGraphMatMul) {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
auto* placeholder_t =
|
||||
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
|
||||
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract operation.
|
||||
@ -553,9 +553,9 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
|
||||
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
|
||||
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Create a first "Add" computing `arg0 + arg1`.
|
||||
@ -709,9 +709,9 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraphMatMul) {
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
|
||||
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
|
||||
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Create a first "Add" computing `arg0 + arg1`.
|
||||
@ -975,7 +975,7 @@ TEST_P(UnifiedCAPI, TF_AbstractTensorGetEagerTensorOnGraphTensorRaises) {
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
auto placeholder_t =
|
||||
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
|
||||
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get());
|
||||
TF_AbstractTensorGetEagerTensor(placeholder_t, status.get());
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/tensor_float_32_utils.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -56,6 +57,9 @@ Status RegisterGradients(GradientRegistry* registry) {
|
||||
}
|
||||
|
||||
TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
|
||||
// Computing numerical gradients with TensorFloat-32 is numerically unstable
|
||||
enable_tensor_float_32_execution(false);
|
||||
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
|
@ -226,7 +226,7 @@ void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
|
||||
|
||||
// Helper functions which delegate to `AbstractOperation`, update
|
||||
// the state of the ForwardOperation and call the tape as appropriate.
|
||||
// These APIs are mainly to faciliate testing and are subject to change.
|
||||
// These APIs are mainly to facilitate testing and are subject to change.
|
||||
namespace internal {
|
||||
Status Reset(AbstractOperation* op_, const char* op,
|
||||
const char* raw_device_name, ForwardOperation* forward_op_) {
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/eager/gradients_internal.h"
|
||||
#include "tensorflow/c/eager/unified_api_testutil.h"
|
||||
#include "tensorflow/c/experimental/gradients/array_grad.h"
|
||||
#include "tensorflow/c/experimental/gradients/math_grad.h"
|
||||
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
|
||||
@ -62,23 +63,29 @@ Status RegisterGradients(GradientRegistry* registry) {
|
||||
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Mul", MulRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Log1p", Log1pRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("DivNoNan", DivNoNanRegisterer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
// Computes
|
||||
// y = inputs[0] + inputs[1]
|
||||
// return grad(y, {inputs[0], inputs[1]})
|
||||
Status AddGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(®istry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> add_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
|
||||
TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(add_outputs),
|
||||
"Add")); // Compute x+y.
|
||||
@ -97,7 +104,6 @@ Status AddGradModel(AbstractContext* ctx,
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -106,13 +112,15 @@ Status AddGradModel(AbstractContext* ctx,
|
||||
// return grad(y, {inputs[0]})
|
||||
Status ExpGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(®istry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
std::vector<AbstractTensorHandle*> exp_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ops::Exp(tape_ctx.get(), inputs, absl::MakeSpan(exp_outputs), "Exp"));
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
@ -128,7 +136,6 @@ Status ExpGradModel(AbstractContext* ctx,
|
||||
exp_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -137,13 +144,15 @@ Status ExpGradModel(AbstractContext* ctx,
|
||||
// return grad(y, {inputs[0]})
|
||||
Status SqrtGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(®istry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
std::vector<AbstractTensorHandle*> sqrt_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt"));
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
@ -159,7 +168,6 @@ Status SqrtGradModel(AbstractContext* ctx,
|
||||
sqrt_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -169,15 +177,17 @@ Status SqrtGradModel(AbstractContext* ctx,
|
||||
// This should return [nullptr, 1].
|
||||
Status IdentityNGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(®istry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0]));
|
||||
tape->Watch(ToId(inputs[1]));
|
||||
|
||||
vector<AbstractTensorHandle*> identity_n_outputs(2);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
|
||||
TF_RETURN_IF_ERROR(ops::IdentityN(
|
||||
tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN"));
|
||||
|
||||
@ -195,126 +205,181 @@ Status IdentityNGradModel(AbstractContext* ctx,
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Computes
|
||||
// y = - inputs[0]
|
||||
// return grad(y, {inputs[0]})
|
||||
Status NegGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(®istry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0]));
|
||||
|
||||
std::vector<AbstractTensorHandle*> neg_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ops::Neg(tape_ctx.get(), inputs, absl::MakeSpan(neg_outputs), "Neg"));
|
||||
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(neg_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
for (auto neg_output : neg_outputs) {
|
||||
neg_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Computes
|
||||
// y = inputs[0] - inputs[1]
|
||||
// return grad(y, {inputs[0], inputs[1]})
|
||||
Status SubGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(®istry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> sub_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
|
||||
TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(sub_outputs),
|
||||
"Sub")); // Compute x-y.
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(sub_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
for (auto sub_output : sub_outputs) {
|
||||
sub_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Computes
|
||||
// y = inputs[0] * inputs[1]
|
||||
// return grad(y, {inputs[0], inputs[1]})
|
||||
Status MulGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(®istry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> mul_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(mul_outputs),
|
||||
"Mul")); // Compute x*y.
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(mul_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
for (auto mul_output : mul_outputs) {
|
||||
mul_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
AbstractContext* BuildFunction(const char* fn_name) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
|
||||
return unwrap(graph_ctx);
|
||||
}
|
||||
// Computes
|
||||
// y = log(1 + inputs[0])
|
||||
// return grad(y, {inputs[0]})
|
||||
Status Log1pGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(®istry));
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
std::vector<AbstractTensorHandle*> log1p_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::Log1p(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(log1p_outputs),
|
||||
"Log1p")); // Compute log(1 + x).
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
Status CreateParamsForInputs(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
std::vector<AbstractTensorHandle*>* params) {
|
||||
tracing::TracingTensorHandle* handle = nullptr;
|
||||
for (auto input : inputs) {
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
|
||||
input->DataType(), &handle));
|
||||
params->emplace_back(handle);
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(log1p_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
for (auto log1p_output : log1p_outputs) {
|
||||
log1p_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
using Model = std::function<Status(
|
||||
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
|
||||
absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
|
||||
// Computes
|
||||
// y = inputs[0] / inputs[1]
|
||||
// return grad(y, {inputs[0], inputs[1]})
|
||||
Status DivNoNanGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(®istry));
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> div_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::DivNoNan(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(div_outputs),
|
||||
"DivNoNan")); // Compute x / y.
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
// Runs `model` maybe wrapped in a function.
|
||||
Status RunModel(Model model, AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
|
||||
const GradientRegistry& registry) {
|
||||
if (use_function) {
|
||||
const char* fn_name = "test_fn";
|
||||
std::unique_ptr<AbstractFunction> scoped_func;
|
||||
// Returning null tensors from a tf.function is not supported, so we keep
|
||||
// track of indices in the model's outputs are nullptr in this set.
|
||||
// The FunctionDef only outputs the non-null tensors. We later pad the
|
||||
// function op outputs to have nullptrs at the `null_indices`.
|
||||
absl::flat_hash_set<int> null_indices;
|
||||
{
|
||||
AbstractContextPtr func_ctx(BuildFunction(fn_name));
|
||||
std::vector<AbstractTensorHandle*> func_inputs;
|
||||
func_inputs.reserve(inputs.size());
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
|
||||
vector<AbstractTensorHandle*> model_outputs;
|
||||
model_outputs.resize(outputs.size());
|
||||
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
|
||||
absl::MakeSpan(model_outputs), registry));
|
||||
for (auto func_input : func_inputs) {
|
||||
func_input->Unref();
|
||||
}
|
||||
AbstractFunction* func = nullptr;
|
||||
OutputList output_list;
|
||||
output_list.expected_num_outputs = 0;
|
||||
output_list.outputs.reserve(outputs.size());
|
||||
for (int i = 0; i < model_outputs.size(); i++) {
|
||||
if (model_outputs[i]) {
|
||||
output_list.outputs.emplace_back(model_outputs[i]);
|
||||
output_list.expected_num_outputs += 1;
|
||||
} else {
|
||||
null_indices.insert(i);
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
|
||||
->Finalize(&output_list, &func));
|
||||
scoped_func.reset(func);
|
||||
for (auto output : output_list.outputs) {
|
||||
output->Unref();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
|
||||
}
|
||||
|
||||
AbstractOperationPtr fn_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
|
||||
for (auto input : inputs) {
|
||||
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
|
||||
}
|
||||
int retvals = outputs.size() - null_indices.size();
|
||||
vector<AbstractTensorHandle*> fn_outputs(retvals);
|
||||
TF_RETURN_IF_ERROR(fn_op->Execute(
|
||||
absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
|
||||
&retvals));
|
||||
int skipped_indices = 0;
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
if (!null_indices.contains(i)) {
|
||||
outputs[i] = fn_outputs[i - skipped_indices];
|
||||
} else {
|
||||
skipped_indices += 1;
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
|
||||
return Status::OK();
|
||||
} else {
|
||||
return model(ctx, inputs, outputs, registry);
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(div_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
for (auto div_output : div_outputs) {
|
||||
div_output->Unref();
|
||||
}
|
||||
}
|
||||
|
||||
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
*ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -369,7 +434,7 @@ TEST_P(CppGradients, TestAddGrad) {
|
||||
std::vector<AbstractTensorHandle*> outputs(2);
|
||||
s = RunModel(AddGradModel, ctx.get(), {x.get(), y.get()},
|
||||
absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
@ -409,18 +474,15 @@ TEST_P(CppGradients, TestExpGrad) {
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
// y = exp(x)
|
||||
// outputs = tape.gradient(y, x)
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
s = RunModel(ExpGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
Status s =
|
||||
RunModel(ExpGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
@ -453,18 +515,15 @@ TEST_P(CppGradients, TestSqrtGrad) {
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
// y = sqrt(x)
|
||||
// outputs = tape.gradient(y, x)
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
s = RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
Status s =
|
||||
RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
@ -522,7 +581,7 @@ TEST_P(CppGradients, TestIdentityNGrad) {
|
||||
std::vector<AbstractTensorHandle*> outputs(2);
|
||||
s = RunModel(IdentityNGradModel, ctx.get(), {x1.get(), x2.get()},
|
||||
absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
EXPECT_EQ(outputs[0], nullptr);
|
||||
@ -536,6 +595,262 @@ TEST_P(CppGradients, TestIdentityNGrad) {
|
||||
result_tensor = nullptr;
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestNegGrad) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
// y = - x
|
||||
// outputs = tape.gradient(y, x)
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
s = RunModel(NegGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
s = getValue(outputs[0], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, -1.0);
|
||||
outputs[0]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
result_tensor = nullptr;
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestSubGrad) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr y;
|
||||
{
|
||||
AbstractTensorHandle* y_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
y.reset(y_raw);
|
||||
}
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
// tape.watch(y)
|
||||
// y = x - y
|
||||
// outputs = tape.gradient(y, [x, y])
|
||||
std::vector<AbstractTensorHandle*> outputs(2);
|
||||
Status s = RunModel(SubGradModel, ctx.get(), {x.get(), y.get()},
|
||||
absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
s = getValue(outputs[0], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, 1.0);
|
||||
outputs[0]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
result_tensor = nullptr;
|
||||
|
||||
s = getValue(outputs[1], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, -1.0);
|
||||
outputs[1]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestMulGrad) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr y;
|
||||
{
|
||||
AbstractTensorHandle* y_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
y.reset(y_raw);
|
||||
}
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
// tape.watch(y)
|
||||
// y = x * y
|
||||
// outputs = tape.gradient(y, [x, y])
|
||||
std::vector<AbstractTensorHandle*> outputs(2);
|
||||
Status s = RunModel(MulGradModel, ctx.get(), {x.get(), y.get()},
|
||||
absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
s = getValue(outputs[0], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, 2.0);
|
||||
outputs[0]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
result_tensor = nullptr;
|
||||
|
||||
s = getValue(outputs[1], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, 1.0);
|
||||
outputs[1]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestLog1pGrad) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
// y = log(1 + x)
|
||||
// outputs = tape.gradient(y, x)
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
Status s =
|
||||
RunModel(Log1pGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
s = getValue(outputs[0], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_NEAR(*result_value, 0.5, 0.001);
|
||||
outputs[0]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
result_tensor = nullptr;
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestDivNoNanGrad) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr y;
|
||||
{
|
||||
AbstractTensorHandle* y_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
y.reset(y_raw);
|
||||
}
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
// tape.watch(y)
|
||||
// y = x / y
|
||||
// outputs = tape.gradient(y, [x, y])
|
||||
std::vector<AbstractTensorHandle*> outputs(2);
|
||||
Status s = RunModel(DivNoNanGradModel, ctx.get(), {x.get(), y.get()},
|
||||
absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
s = getValue(outputs[0], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_NEAR(*result_value, 0.5, 0.001);
|
||||
outputs[0]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
result_tensor = nullptr;
|
||||
|
||||
s = getValue(outputs[1], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_NEAR(*result_value, -0.25, 0.001);
|
||||
outputs[1]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestSetAttrString) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
@ -575,7 +890,7 @@ TEST_P(CppGradients, TestSetAttrString) {
|
||||
int num_retvals = 1;
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
GradientRegistry registry;
|
||||
std::unique_ptr<Tape> tape(new Tape(/*persistent=*/false));
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs),
|
||||
&num_retvals, &forward_op, tape.get(), registry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/ops/nn_ops.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
@ -224,8 +225,10 @@ Status CreateParamsForInputs(AbstractContext* ctx,
|
||||
vector<AbstractTensorHandle*>* params) {
|
||||
tracing::TracingTensorHandle* handle = nullptr;
|
||||
for (auto input : inputs) {
|
||||
PartialTensorShape shape;
|
||||
TF_RETURN_IF_ERROR(input->Shape(&shape));
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
|
||||
input->DataType(), &handle));
|
||||
input->DataType(), shape, &handle));
|
||||
params->emplace_back(handle);
|
||||
}
|
||||
return Status::OK();
|
||||
@ -314,4 +317,4 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
|
||||
}
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
} // namespace tensorflow
|
||||
|
@ -21,14 +21,18 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_distributed_manager.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -124,14 +128,21 @@ class ImmediateExecutionContext : public AbstractContext {
|
||||
// Returns the device placement policy for the current thread.
|
||||
virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;
|
||||
|
||||
// Configure graph collection in RunMetadata.
|
||||
virtual void SetShouldStoreGraphs(bool value) = 0;
|
||||
|
||||
// Return the collected RunMetadata. This method will transfer the ownership
|
||||
// to the caller.
|
||||
virtual std::unique_ptr<RunMetadata> ExportRunMetadata() = 0;
|
||||
|
||||
// For LLVM style RTTI.
|
||||
static bool classof(const AbstractContext* ptr) {
|
||||
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Following are legacy features in TF Eager Runtime.
|
||||
// TODO(tf-runtime): Figure out a way to deprecate following features after
|
||||
// Following are features in current TF Eager Runtime.
|
||||
// TODO(tfrt-devs): Figure out a way to deprecate following features after
|
||||
// migrated to TFRT.
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Clear pending nodes in thread executors and kernel caches.
|
||||
@ -149,8 +160,33 @@ class ImmediateExecutionContext : public AbstractContext {
|
||||
// Update the Eager Executor for current thread.
|
||||
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
|
||||
|
||||
// Configure graph collection in RunMetadata.
|
||||
virtual void SetShouldStoreGraphs(bool value) = 0;
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Following are helper functions to assist integrating TFRT with current
|
||||
// TF eager runtime.
|
||||
// TODO(b/172877902): These helper functions are currently used to support
|
||||
// PyFuncOp on TFRT, and might be useful for ops that directly use low
|
||||
// level TF APIs. Remove/replace the following functions when TFRT native
|
||||
// ops are implemented.
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Create an abstract tensor handle from tensorflow::Tensor.
|
||||
virtual ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
|
||||
tensorflow::Tensor& t, const char* d_name) = 0;
|
||||
|
||||
// Convert a TFRT TensorHandle to tensorflow::TensorHandle.
|
||||
virtual ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
|
||||
ImmediateExecutionTensorHandle* handle) = 0;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Distributed runtime related functions.
|
||||
//===--------------------------------------------------------------------===//
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
// Set a distributed manager that helps set up, update, and check liveness
|
||||
// of member tasks in the cluster.
|
||||
virtual void SetDistributedManager(
|
||||
std::unique_ptr<ImmediateExecutionDistributedManager> distributed) = 0;
|
||||
|
||||
virtual ImmediateExecutionDistributedManager* GetDistributedManager() = 0;
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
protected:
|
||||
explicit ImmediateExecutionContext(AbstractContextKind kind)
|
||||
|
45
tensorflow/c/eager/immediate_execution_distributed_manager.h
Normal file
45
tensorflow/c/eager/immediate_execution_distributed_manager.h
Normal file
@ -0,0 +1,45 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EAGER_immediate_execution_distributed_manager_H_
|
||||
#define TENSORFLOW_C_EAGER_immediate_execution_distributed_manager_H_
|
||||
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
class ImmediateExecutionContext;
|
||||
class ServerDef;
|
||||
|
||||
class ImmediateExecutionDistributedManager {
|
||||
public:
|
||||
virtual ~ImmediateExecutionDistributedManager() {}
|
||||
|
||||
// Set up distributed execution environment on local and remote tasks.
|
||||
// When `reset_context` is true, initialize new cluster context state based on
|
||||
// cluster configurations provided in `server_def`; otherwise, update existing
|
||||
// context state with the provided `server_def`.
|
||||
// Contexts created on remote tasks will be considered stale and garbage
|
||||
// collected after `keep_alive_secs` of inactivity.
|
||||
virtual Status SetOrUpdateServerDef(const ServerDef& server_def,
|
||||
bool reset_context,
|
||||
int keep_alive_secs) = 0;
|
||||
|
||||
// Check if the remote task is alive.
|
||||
virtual Status CheckRemoteAlive(const std::string& remote_task_name,
|
||||
bool* is_alive) = 0;
|
||||
};
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_immediate_execution_distributed_manager_H_
|
@ -27,7 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/util/abstract_stack_trace.h"
|
||||
#include "tensorflow/core/util/managed_stack_trace.h"
|
||||
|
||||
struct TFE_Op;
|
||||
|
||||
@ -48,10 +48,10 @@ class ImmediateExecutionOperation : public AbstractOperation {
|
||||
virtual Status OutputLength(const char* output_name, int* length) = 0;
|
||||
|
||||
// Set stack trace to be used for potential async error reporting.
|
||||
virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0;
|
||||
virtual void SetStackTrace(ManagedStackTrace stack_trace) = 0;
|
||||
|
||||
// Returns the stack trace set by `SetStackTrace` if exists.
|
||||
virtual absl::optional<AbstractStackTrace> GetStackTrace() = 0;
|
||||
virtual absl::optional<ManagedStackTrace> GetStackTrace() = 0;
|
||||
|
||||
// For LLVM style RTTI.
|
||||
static bool classof(const AbstractOperation* ptr) {
|
||||
|
@ -44,6 +44,10 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
|
||||
virtual const char* DeviceName(Status* status) const = 0;
|
||||
// Returns the device where the tensor was placed.
|
||||
virtual const char* BackingDeviceName(Status* status) const = 0;
|
||||
// Returns the device type which created the handle.
|
||||
virtual const char* DeviceType(Status* status) const = 0;
|
||||
// Returns the device ID which created the handle.
|
||||
virtual int DeviceId(Status* status) const = 0;
|
||||
// Returns a tensor for the handle. If tensor is remote, it will be copied.
|
||||
virtual AbstractTensorInterface* Resolve(Status* status) = 0;
|
||||
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/tensor_float_32_utils.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -43,6 +44,11 @@ class CppGradients
|
||||
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
|
||||
Status s = StatusFromTF_Status(status.get());
|
||||
CHECK_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Computing numerical gradients with TensorFloat-32 is numerically
|
||||
// unstable. Some forward pass tests also fail with TensorFloat-32 due to
|
||||
// low tolerances
|
||||
enable_tensor_float_32_execution(false);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -76,6 +76,7 @@ cc_library(
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
|
@ -58,7 +58,7 @@ using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
|
||||
class DeviceThread {
|
||||
public:
|
||||
// Starts a background thread waiting for `StartExecute`.
|
||||
explicit DeviceThread(const std::string& device)
|
||||
explicit DeviceThread(const std::string& device, const bool is_async)
|
||||
: status_(TF_NewStatus()),
|
||||
device_(device),
|
||||
// If the context's default exector is set to async, re-using that in
|
||||
@ -67,7 +67,7 @@ class DeviceThread {
|
||||
//
|
||||
// TODO(allenl): We should have an async API that works with the
|
||||
// parallel device.
|
||||
executor_(TFE_NewExecutor(/*is_async=*/false)),
|
||||
executor_(TFE_NewExecutor(is_async)),
|
||||
op_(nullptr),
|
||||
thread_(tensorflow::Env::Default()->StartThread(
|
||||
tensorflow::ThreadOptions(), "parallel_device_execute",
|
||||
@ -236,12 +236,13 @@ void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
|
||||
}
|
||||
}
|
||||
|
||||
ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
|
||||
ParallelDevice::ParallelDevice(const std::vector<std::string>& devices,
|
||||
const bool is_async)
|
||||
: underlying_devices_(devices) {
|
||||
device_threads_.reserve(devices.size());
|
||||
for (int device_index = 0; device_index < devices.size(); ++device_index) {
|
||||
device_threads_.emplace_back(
|
||||
new DeviceThread(devices[device_index].c_str()));
|
||||
new DeviceThread(devices[device_index].c_str(), is_async));
|
||||
}
|
||||
}
|
||||
|
||||
@ -327,6 +328,17 @@ ParallelDevice::Execute(TFE_Context* context,
|
||||
const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int expected_max_outputs,
|
||||
TF_Status* status) const {
|
||||
std::vector<PartialTensorShape> expected_output_shapes(expected_max_outputs);
|
||||
return Execute(context, inputs, operation_name, attributes,
|
||||
expected_output_shapes, status);
|
||||
}
|
||||
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
|
||||
ParallelDevice::Execute(
|
||||
TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
const std::vector<PartialTensorShape>& expected_output_shapes,
|
||||
TF_Status* status) const {
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
|
||||
// Compute per-device per-output tensors
|
||||
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
|
||||
@ -343,7 +355,7 @@ ParallelDevice::Execute(TFE_Context* context,
|
||||
}
|
||||
device_thread->StartExecute(context, operation_name,
|
||||
std::move(device_inputs), attributes,
|
||||
expected_max_outputs);
|
||||
expected_output_shapes.size());
|
||||
}
|
||||
StatusPtr first_bad_status(nullptr);
|
||||
for (int device_index = 0; device_index < underlying_devices_.size();
|
||||
@ -385,8 +397,15 @@ ParallelDevice::Execute(TFE_Context* context,
|
||||
for (int j = 0; j < underlying_devices_.size(); ++j) {
|
||||
components.push_back(std::move(per_device_output_tensors[j][i]));
|
||||
}
|
||||
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
|
||||
*this, std::move(components), status));
|
||||
if (expected_output_shapes[i].IsFullyDefined()) {
|
||||
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
|
||||
*this, std::move(components),
|
||||
absl::Span<const int64>(expected_output_shapes[i].dim_sizes()),
|
||||
status));
|
||||
} else {
|
||||
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
|
||||
*this, std::move(components), status));
|
||||
}
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
}
|
||||
result.emplace(std::move(per_device_outputs));
|
||||
@ -395,9 +414,27 @@ ParallelDevice::Execute(TFE_Context* context,
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
|
||||
const ParallelDevice& parallel_device,
|
||||
std::vector<TensorHandlePtr> components, TF_Status* status) {
|
||||
std::vector<TensorHandlePtr> components, absl::Span<const int64> shape,
|
||||
TF_Status* status) {
|
||||
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
|
||||
std::vector<int64_t> shape(
|
||||
// Verify that the TensorHandle's shape and dtype match all of the component
|
||||
// shapes and dtypes.
|
||||
for (TensorHandlePtr& component : components) {
|
||||
if (TFE_TensorHandleDataType(component.get()) != dtype) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Components of a ParallelTensor must all have "
|
||||
"the same dtype");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return std::unique_ptr<ParallelTensor>(
|
||||
new ParallelTensor(parallel_device, std::move(components), shape, dtype));
|
||||
}
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
|
||||
const ParallelDevice& parallel_device,
|
||||
std::vector<TensorHandlePtr> components, TF_Status* status) {
|
||||
std::vector<int64> shape(
|
||||
TFE_TensorHandleNumDims(components[0].get(), status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
@ -405,11 +442,10 @@ std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
|
||||
// Verify that the TensorHandle's shape and dtype match all of the component
|
||||
// shapes and dtypes.
|
||||
// Verify that the TensorHandle's shape matches all of the component shapes.
|
||||
for (TensorHandlePtr& component : components) {
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
|
||||
int64 tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
if (tensor_dim != shape[i]) {
|
||||
// TODO(allenl): Allow shapes to differ.
|
||||
@ -418,17 +454,10 @@ std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
|
||||
"the same shape");
|
||||
return nullptr;
|
||||
}
|
||||
if (TFE_TensorHandleDataType(component.get()) != dtype) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Components of a ParallelTensor must all have "
|
||||
"the same dtype");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
|
||||
parallel_device, std::move(components), std::move(shape), dtype));
|
||||
return FromTensorHandles(parallel_device, std::move(components),
|
||||
absl::Span<const int64>(shape), status);
|
||||
}
|
||||
|
||||
} // namespace parallel_device
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace parallel_device {
|
||||
@ -49,7 +50,10 @@ class DeviceThread;
|
||||
// placed on each underlying device.
|
||||
class ParallelDevice {
|
||||
public:
|
||||
explicit ParallelDevice(const std::vector<std::string>& devices);
|
||||
// Eager async execution is only supported when remote eager is not in use
|
||||
// (b/157523095).
|
||||
explicit ParallelDevice(const std::vector<std::string>& devices,
|
||||
const bool is_async = false);
|
||||
|
||||
~ParallelDevice();
|
||||
|
||||
@ -90,6 +94,15 @@ class ParallelDevice {
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
int expected_max_outputs, TF_Status* status) const;
|
||||
|
||||
// Accepts inferred shapes for outputs, which if fully defined will avoid
|
||||
// querying the shapes of the underlying TensorHandles. This allows async
|
||||
// computation to continue without blocking.
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
|
||||
TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
const std::vector<PartialTensorShape>& expected_output_shapes,
|
||||
TF_Status* status) const;
|
||||
|
||||
private:
|
||||
// A sequence of device names, indicating which devices replicated operations
|
||||
// are forwarded to.
|
||||
@ -114,10 +127,15 @@ class ParallelDevice {
|
||||
class ParallelTensor {
|
||||
public:
|
||||
// Construct a ParallelTensor from TensorHandles placed on the component
|
||||
// devices of a ParallelDevice.
|
||||
// devices of a ParallelDevice. Inspects `components` to determine a shape.
|
||||
static std::unique_ptr<ParallelTensor> FromTensorHandles(
|
||||
const ParallelDevice& parallel_device,
|
||||
std::vector<TensorHandlePtr> components, TF_Status* status);
|
||||
// Uses the provided shape without additional checks, which avoids blocking.
|
||||
static std::unique_ptr<ParallelTensor> FromTensorHandles(
|
||||
const ParallelDevice& parallel_device,
|
||||
std::vector<TensorHandlePtr> components, absl::Span<const int64> shape,
|
||||
TF_Status* status);
|
||||
|
||||
size_t num_tensors() const { return tensors_.size(); }
|
||||
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
|
||||
@ -129,10 +147,10 @@ class ParallelTensor {
|
||||
private:
|
||||
ParallelTensor(const ParallelDevice& device,
|
||||
std::vector<TensorHandlePtr> tensors,
|
||||
std::vector<int64_t> shape, const TF_DataType dtype)
|
||||
absl::Span<const int64> shape, const TF_DataType dtype)
|
||||
: device_(device),
|
||||
tensors_(std::move(tensors)),
|
||||
shape_(std::move(shape)),
|
||||
shape_(shape.begin(), shape.end()),
|
||||
dtype_(dtype) {}
|
||||
|
||||
const ParallelDevice& device_;
|
||||
|
@ -80,5 +80,41 @@ TEST(PARALLEL_DEVICE_LIB, TestOpWithError) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE_LIB, TestExplicitOutputShape) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
|
||||
TF_CreateConfig(
|
||||
/*xla*/ false,
|
||||
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
|
||||
2),
|
||||
TF_DeleteBuffer);
|
||||
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
|
||||
status.get());
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
std::vector<std::string> devices{
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:1"};
|
||||
ParallelDevice parallel_device(std::move(devices));
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> handle_op(
|
||||
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT);
|
||||
TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
auto outputs = parallel_device.Execute(
|
||||
context.get(), std::vector<ParallelTensor*>(), "VarHandleOp",
|
||||
TFE_OpGetAttrs(handle_op.get()), {PartialTensorShape({})}, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs;
|
||||
EXPECT_EQ(0, handles[0]->shape().size());
|
||||
}
|
||||
|
||||
} // namespace parallel_device
|
||||
} // namespace tensorflow
|
||||
|
205
tensorflow/c/eager/unified_api_test.cc
Normal file
205
tensorflow/c/eager/unified_api_test.cc
Normal file
@ -0,0 +1,205 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/eager/unified_api_testutil.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
class UnifiedAPI
|
||||
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
TF_StatusPtr status(TF_NewStatus());
|
||||
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
|
||||
Status s = StatusFromTF_Status(status.get());
|
||||
CHECK_EQ(errors::OK, s.code()) << s.error_message();
|
||||
}
|
||||
|
||||
public:
|
||||
bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
|
||||
bool UseFunction() const { return std::get<2>(GetParam()); }
|
||||
};
|
||||
|
||||
// Checks that inputs[0] is a scalar.
|
||||
Status TestScalarShape(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
PartialTensorShape shape;
|
||||
TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape));
|
||||
if (shape.dims() != 0) {
|
||||
return errors::InvalidArgument(
|
||||
"Tensor expected to have scalar shape found rank: ", shape.dims());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TEST_P(UnifiedAPI, TestTensorShapeScalar) {
|
||||
if (UseFunction() && UseMlir()) {
|
||||
// TODO(b/173074167): Remove this.
|
||||
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
|
||||
}
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
Status s = RunModel(TestScalarShape, ctx.get(),
|
||||
/*inputs=*/{x.get()},
|
||||
/*outputs=*/{},
|
||||
/*use_function=*/UseFunction());
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
}
|
||||
|
||||
// Checks that inputs[0] is a matrix with shape 2x4.
|
||||
Status TestTensorShape2x4(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
PartialTensorShape shape;
|
||||
TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape));
|
||||
if (shape.dims() != 2) {
|
||||
return errors::InvalidArgument(
|
||||
"Tensor expected to have rank 2 found rank: ", shape.dims());
|
||||
}
|
||||
int64 dim_sizes[] = {2, 4};
|
||||
for (int i = 0; i < shape.dims(); i++) {
|
||||
if (shape.dim_size(i) != dim_sizes[i]) {
|
||||
return errors::InvalidArgument("Dim ", i, " expected to be of size ",
|
||||
dim_sizes[i],
|
||||
" found: ", shape.dim_size(i));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TEST_P(UnifiedAPI, TestTensorShape2x4) {
|
||||
if (UseFunction() && UseMlir()) {
|
||||
// TODO(b/173074167): Remove this.
|
||||
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
|
||||
}
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
float data[] = {0., 0., 0., 0., 0., 0., 0., 0};
|
||||
int64 dim_sizes[] = {2, 4};
|
||||
Status s =
|
||||
TestTensorHandleWithDimsFloat(ctx.get(), data, dim_sizes, 2, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
Status s = RunModel(TestTensorShape2x4, ctx.get(),
|
||||
/*inputs=*/{x.get()},
|
||||
/*outputs=*/{},
|
||||
/*use_function=*/UseFunction());
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
}
|
||||
|
||||
TEST_P(UnifiedAPI, TestUnknownShapeTracing) {
|
||||
if (!UseFunction()) {
|
||||
GTEST_SKIP() << "Tracing only test.";
|
||||
}
|
||||
if (UseMlir()) {
|
||||
// TODO(b/173074167): Remove this.
|
||||
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
|
||||
}
|
||||
AbstractContextPtr ctx(BuildFunction("test_fn"));
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
tracing::TracingTensorHandle* x_raw = nullptr;
|
||||
PartialTensorShape shape;
|
||||
Status s = dyn_cast<tracing::TracingContext>(ctx.get())->AddParameter(
|
||||
DT_FLOAT, shape, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
PartialTensorShape shape;
|
||||
Status s = x->Shape(&shape);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ASSERT_TRUE(shape.unknown_rank());
|
||||
}
|
||||
|
||||
TEST_P(UnifiedAPI, TestPartialShapeTracing) {
|
||||
if (!UseFunction()) {
|
||||
GTEST_SKIP() << "Tracing only test.";
|
||||
}
|
||||
if (UseMlir()) {
|
||||
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
|
||||
}
|
||||
AbstractContextPtr ctx(BuildFunction("test_fn"));
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
tracing::TracingTensorHandle* x_raw = nullptr;
|
||||
PartialTensorShape shape;
|
||||
int64 dim_sizes[] = {2, -1};
|
||||
Status s = PartialTensorShape::MakePartialShape(dim_sizes, 2, &shape);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
s = dyn_cast<tracing::TracingContext>(ctx.get())->AddParameter(
|
||||
DT_FLOAT, shape, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
PartialTensorShape shape;
|
||||
Status s = x->Shape(&shape);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ASSERT_FALSE(shape.unknown_rank());
|
||||
|
||||
ASSERT_EQ(2, shape.dim_size(0));
|
||||
ASSERT_EQ(-1, shape.dim_size(1));
|
||||
}
|
||||
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
UnifiedCppAPI, UnifiedAPI,
|
||||
::testing::Combine(::testing::Values("graphdef", "mlir"),
|
||||
/*tfrt*/ ::testing::Values(true, false),
|
||||
/*use_function*/ ::testing::Values(true, false)));
|
||||
#else
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
UnifiedCppAPI, UnifiedAPI,
|
||||
::testing::Combine(::testing::Values("graphdef", "mlir"),
|
||||
/*tfrt*/ ::testing::Values(false),
|
||||
/*use_function*/ ::testing::Values(true, false)));
|
||||
#endif
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
161
tensorflow/c/eager/unified_api_testutil.cc
Normal file
161
tensorflow/c/eager/unified_api_testutil.cc
Normal file
@ -0,0 +1,161 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/eager/unified_api_testutil.h"
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
AbstractContext* BuildFunction(const char* fn_name) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
|
||||
return unwrap(graph_ctx);
|
||||
}
|
||||
|
||||
Status CreateParamsForInputs(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
std::vector<AbstractTensorHandle*>* params) {
|
||||
tracing::TracingTensorHandle* handle = nullptr;
|
||||
for (auto input : inputs) {
|
||||
PartialTensorShape shape;
|
||||
TF_RETURN_IF_ERROR(input->Shape(&shape));
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
|
||||
input->DataType(), shape, &handle));
|
||||
params->emplace_back(handle);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Runs `model` maybe wrapped in a function.
|
||||
Status RunModel(Model model, AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, bool use_function) {
|
||||
if (use_function) {
|
||||
const char* fn_name = "test_fn";
|
||||
std::unique_ptr<AbstractFunction> scoped_func;
|
||||
// Returning null tensors from a tf.function is not supported, so we keep
|
||||
// track of indices in the model's outputs are nullptr in this set.
|
||||
// The FunctionDef only outputs the non-null tensors. We later pad the
|
||||
// function op outputs to have nullptrs at the `null_indices`.
|
||||
absl::flat_hash_set<int> null_indices;
|
||||
{
|
||||
AbstractContextPtr func_ctx(BuildFunction(fn_name));
|
||||
std::vector<AbstractTensorHandle*> func_inputs;
|
||||
func_inputs.reserve(inputs.size());
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
|
||||
std::vector<AbstractTensorHandle*> model_outputs;
|
||||
model_outputs.resize(outputs.size());
|
||||
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
|
||||
absl::MakeSpan(model_outputs)));
|
||||
for (auto func_input : func_inputs) {
|
||||
func_input->Unref();
|
||||
}
|
||||
AbstractFunction* func = nullptr;
|
||||
OutputList output_list;
|
||||
output_list.expected_num_outputs = 0;
|
||||
output_list.outputs.reserve(outputs.size());
|
||||
for (int i = 0; i < model_outputs.size(); i++) {
|
||||
if (model_outputs[i]) {
|
||||
output_list.outputs.emplace_back(model_outputs[i]);
|
||||
output_list.expected_num_outputs += 1;
|
||||
} else {
|
||||
null_indices.insert(i);
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
|
||||
->Finalize(&output_list, &func));
|
||||
scoped_func.reset(func);
|
||||
for (auto output : output_list.outputs) {
|
||||
output->Unref();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
|
||||
}
|
||||
|
||||
AbstractOperationPtr fn_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
|
||||
for (auto input : inputs) {
|
||||
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
|
||||
}
|
||||
int retvals = outputs.size() - null_indices.size();
|
||||
std::vector<AbstractTensorHandle*> fn_outputs(retvals);
|
||||
TF_RETURN_IF_ERROR(fn_op->Execute(
|
||||
absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
|
||||
&retvals));
|
||||
int skipped_indices = 0;
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
if (!null_indices.contains(i)) {
|
||||
outputs[i] = fn_outputs[i - skipped_indices];
|
||||
} else {
|
||||
skipped_indices += 1;
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
|
||||
return Status::OK();
|
||||
} else {
|
||||
return model(ctx, inputs, outputs);
|
||||
}
|
||||
}
|
||||
|
||||
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
*ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
|
||||
int64* dims, int num_dims,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager = TestTensorHandleWithDimsFloat(
|
||||
eager_ctx, data, reinterpret_cast<int64_t*>(dims), num_dims);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
61
tensorflow/c/eager/unified_api_testutil.h
Normal file
61
tensorflow/c/eager/unified_api_testutil.h
Normal file
@ -0,0 +1,61 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_
|
||||
#define TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_
|
||||
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Builds and returns a `TracingContext` using the default tracing impl.
|
||||
AbstractContext* BuildFunction(const char* fn_name);
|
||||
|
||||
// Creates parameters (placeholders) in the tracing `ctx` using the shape and
|
||||
// dtype of `inputs`.
|
||||
Status CreateParamsForInputs(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
std::vector<AbstractTensorHandle*>* params);
|
||||
|
||||
// A callable that takes tensor inputs and returns zero or more tensor outputs.
|
||||
using Model = std::function<Status(AbstractContext*,
|
||||
absl::Span<AbstractTensorHandle* const>,
|
||||
absl::Span<AbstractTensorHandle*>)>;
|
||||
|
||||
// Runs `model` maybe wrapped in a function call op. This can be thought as
|
||||
// being equivalent to the following python code.
|
||||
//
|
||||
// if use_function:
|
||||
// outputs = tf.function(model)(inputs)
|
||||
// else:
|
||||
// outputs = model(inputs)
|
||||
Status RunModel(Model model, AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, bool use_function);
|
||||
|
||||
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
|
||||
|
||||
// Get a Scalar TensorHandle with given float value.
|
||||
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
|
||||
AbstractTensorHandle** tensor);
|
||||
|
||||
// Get a Matrix TensorHandle with given float values and dimensions.
|
||||
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
|
||||
int64* dims, int num_dims,
|
||||
AbstractTensorHandle** tensor);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_
|
@ -182,9 +182,8 @@ hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file,
|
||||
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
|
||||
|
||||
std::string cacheKey(scheme);
|
||||
hdfsBuilder* builder = libhdfs->hdfsNewBuilder();
|
||||
if (scheme == "file") {
|
||||
libhdfs->hdfsBuilderSetNameNode(builder, nullptr);
|
||||
namenode = "";
|
||||
} else if (scheme == "viewfs") {
|
||||
char* defaultFS = nullptr;
|
||||
libhdfs->hdfsConfGetStr("fs.defaultFS", &defaultFS);
|
||||
@ -200,24 +199,27 @@ hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file,
|
||||
// The default NameNode configuration will be used (from the XML
|
||||
// configuration files). See:
|
||||
// https://github.com/tensorflow/tensorflow/blob/v1.0.0/third_party/hadoop/hdfs.h#L259
|
||||
libhdfs->hdfsBuilderSetNameNode(builder, "default");
|
||||
namenode = "default";
|
||||
} else if (scheme == "har") {
|
||||
std::string path_har = path;
|
||||
SplitArchiveNameAndPath(&path_har, &namenode, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str());
|
||||
cacheKey += namenode;
|
||||
} else {
|
||||
libhdfs->hdfsBuilderSetNameNode(
|
||||
builder, namenode.empty() ? "default" : namenode.c_str());
|
||||
cacheKey += namenode;
|
||||
if (namenode.empty()) {
|
||||
namenode = "default";
|
||||
}
|
||||
}
|
||||
cacheKey += namenode;
|
||||
|
||||
absl::MutexLock l(&hadoop_file->connection_cache_lock);
|
||||
if (hadoop_file->connection_cache.find(cacheKey) ==
|
||||
hadoop_file->connection_cache.end()) {
|
||||
hdfsBuilder* builder = libhdfs->hdfsNewBuilder();
|
||||
libhdfs->hdfsBuilderSetNameNode(
|
||||
builder, namenode.empty() ? nullptr : namenode.c_str());
|
||||
auto cacheFs = libhdfs->hdfsBuilderConnect(builder);
|
||||
if (cacheFs == nullptr) {
|
||||
TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno));
|
||||
TF_SetStatusFromIOError(status, TF_ABORTED, strerror(errno));
|
||||
return cacheFs;
|
||||
}
|
||||
hadoop_file->connection_cache[cacheKey] = cacheFs;
|
||||
|
@ -21,9 +21,14 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/ops/nn_ops.h"
|
||||
|
||||
using std::vector;
|
||||
using tensorflow::ops::Add;
|
||||
using tensorflow::ops::Conj;
|
||||
using tensorflow::ops::Div;
|
||||
using tensorflow::ops::DivNoNan;
|
||||
using tensorflow::ops::MatMul;
|
||||
using tensorflow::ops::Mul;
|
||||
using tensorflow::ops::Neg;
|
||||
using tensorflow::ops::OnesLike;
|
||||
using tensorflow::ops::SqrtGrad;
|
||||
|
||||
namespace tensorflow {
|
||||
@ -201,6 +206,204 @@ class MatMulGradientFunction : public GradientFunction {
|
||||
AttrBuilder forward_attrs;
|
||||
};
|
||||
|
||||
class NegGradientFunction : public GradientFunction {
|
||||
public:
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
/* Given upstream grad U and a Neg op Y = -X, the gradients are:
|
||||
*
|
||||
* dX = -U
|
||||
*
|
||||
*/
|
||||
|
||||
grad_outputs->resize(1);
|
||||
std::string name = "Neg_Grad";
|
||||
TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
|
||||
absl::MakeSpan(*grad_outputs), name.c_str()));
|
||||
return Status::OK();
|
||||
}
|
||||
~NegGradientFunction() override {}
|
||||
};
|
||||
|
||||
class SubGradientFunction : public GradientFunction {
|
||||
public:
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
/* Given upstream grad U and a Sub op A-B, the gradients are:
|
||||
*
|
||||
* dA = U
|
||||
* dB = -U
|
||||
*
|
||||
*/
|
||||
|
||||
grad_outputs->resize(2);
|
||||
|
||||
// Grad for A
|
||||
DCHECK(grad_inputs[0]);
|
||||
(*grad_outputs)[0] = grad_inputs[0];
|
||||
(*grad_outputs)[0]->Ref();
|
||||
|
||||
// Grad for B
|
||||
// negate the upstream grad
|
||||
std::vector<AbstractTensorHandle*> neg_outputs(1);
|
||||
std::string name = "Neg_Sub_Grad_B";
|
||||
TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
|
||||
absl::MakeSpan(neg_outputs), name.c_str()));
|
||||
(*grad_outputs)[1] = neg_outputs[0];
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
~SubGradientFunction() override {}
|
||||
};
|
||||
|
||||
class MulGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit MulGradientFunction(vector<AbstractTensorHandle*> f_inputs)
|
||||
: forward_inputs(f_inputs) {}
|
||||
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
/* Given upstream grad U and a mul op A*B, the gradients are:
|
||||
*
|
||||
* dA = U * B
|
||||
* dB = A * U
|
||||
*
|
||||
*/
|
||||
|
||||
AbstractTensorHandle* upstream_grad = grad_inputs[0];
|
||||
grad_outputs->resize(2);
|
||||
std::vector<AbstractTensorHandle*> mul_outputs(1);
|
||||
|
||||
// Gradient for A
|
||||
std::string name = "Mul_Grad_A";
|
||||
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {upstream_grad, forward_inputs[1]},
|
||||
absl::MakeSpan(mul_outputs), name.c_str()));
|
||||
(*grad_outputs)[0] = mul_outputs[0];
|
||||
|
||||
// Gradient for B
|
||||
name = "Mul_Grad_B";
|
||||
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {forward_inputs[0], upstream_grad},
|
||||
absl::MakeSpan(mul_outputs), name.c_str()));
|
||||
(*grad_outputs)[1] = mul_outputs[0];
|
||||
return Status::OK();
|
||||
}
|
||||
~MulGradientFunction() override {}
|
||||
|
||||
private:
|
||||
vector<AbstractTensorHandle*> forward_inputs;
|
||||
};
|
||||
|
||||
class Log1pGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit Log1pGradientFunction(vector<AbstractTensorHandle*> f_inputs)
|
||||
: forward_inputs(f_inputs) {}
|
||||
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
// TODO(vnvo2409): Add control dependency
|
||||
/* Given upstream grad U and a Log1p op: Y = log(1 + X), the gradients are:
|
||||
*
|
||||
* dX = U / (1 + X)
|
||||
*
|
||||
*/
|
||||
|
||||
AbstractTensorHandle* upstream_grad = grad_inputs[0];
|
||||
AbstractTensorHandle* X = forward_inputs[0];
|
||||
|
||||
grad_outputs->resize(1);
|
||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
|
||||
// Calculate conjugate of X
|
||||
std::string name = "Conj_Log1p_Grad_X";
|
||||
TF_RETURN_IF_ERROR(
|
||||
Conj(ctx->ctx, {X}, absl::MakeSpan(temp_outputs), name.c_str()));
|
||||
|
||||
AbstractTensorHandle* Conj_X = temp_outputs[0];
|
||||
|
||||
// Creates Ones
|
||||
name = "OnesLike_Log1p_Grad_X";
|
||||
TF_RETURN_IF_ERROR(OnesLike(ctx->ctx, {Conj_X},
|
||||
absl::MakeSpan(temp_outputs), name.c_str()));
|
||||
|
||||
AbstractTensorHandle* Ones_X = temp_outputs[0];
|
||||
|
||||
name = "Add_Log1p_Grad_X";
|
||||
// Calculate 1 + Conj(X)
|
||||
TF_RETURN_IF_ERROR(Add(ctx->ctx, {Ones_X, Conj_X},
|
||||
absl::MakeSpan(temp_outputs), name.c_str()));
|
||||
|
||||
AbstractTensorHandle* Conj_XP1 = temp_outputs[0];
|
||||
|
||||
name = "Div_Log1p_Grad_X";
|
||||
// Calculate U / (1 + Conj(X))
|
||||
TF_RETURN_IF_ERROR(Div(ctx->ctx, {upstream_grad, Conj_XP1},
|
||||
absl::MakeSpan(temp_outputs), name.c_str()));
|
||||
|
||||
(*grad_outputs)[0] = temp_outputs[0];
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
~Log1pGradientFunction() override {}
|
||||
|
||||
private:
|
||||
vector<AbstractTensorHandle*> forward_inputs;
|
||||
};
|
||||
|
||||
class DivNoNanGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit DivNoNanGradientFunction(vector<AbstractTensorHandle*> f_inputs,
|
||||
vector<AbstractTensorHandle*> f_outputs)
|
||||
: forward_inputs(f_inputs), forward_outputs(f_outputs) {}
|
||||
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
// TODO(vnvo2409): Add shape broadcasting
|
||||
/* Given upstream grad U and a Div op: Z = X/Y, the gradients are:
|
||||
*
|
||||
* dX = U / Y
|
||||
* dY = -U*X / Y^2 = (X/Y) * -U / Y = -U*Z / Y
|
||||
*
|
||||
*/
|
||||
|
||||
AbstractTensorHandle* upstream_grad = grad_inputs[0];
|
||||
AbstractTensorHandle* Y = forward_inputs[1];
|
||||
AbstractTensorHandle* Z = forward_outputs[0];
|
||||
|
||||
grad_outputs->resize(2);
|
||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
|
||||
// Calculate dX = U / Y
|
||||
std::string name = "Div_Grad_X";
|
||||
TF_RETURN_IF_ERROR(DivNoNan(ctx->ctx, {upstream_grad, Y},
|
||||
absl::MakeSpan(temp_outputs), name.c_str()));
|
||||
|
||||
(*grad_outputs)[0] = temp_outputs[0];
|
||||
|
||||
// Calculate dY = -U*Z / Y
|
||||
name = "Neg_Div_Grad_Y";
|
||||
TF_RETURN_IF_ERROR(Neg(ctx->ctx, {upstream_grad},
|
||||
absl::MakeSpan(temp_outputs), name.c_str())); // -U
|
||||
AbstractTensorHandle* MinusU = temp_outputs[0];
|
||||
|
||||
name = "Mul_Div_Grad_Y";
|
||||
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {MinusU, Z}, absl::MakeSpan(temp_outputs),
|
||||
name.c_str())); // -U*Z
|
||||
AbstractTensorHandle* UZ = temp_outputs[0];
|
||||
|
||||
name = "Div_Grad_Y";
|
||||
TF_RETURN_IF_ERROR(DivNoNan(ctx->ctx, {UZ, Y}, absl::MakeSpan(temp_outputs),
|
||||
name.c_str())); // -U*Z / Y
|
||||
|
||||
(*grad_outputs)[1] = temp_outputs[0];
|
||||
return Status::OK();
|
||||
}
|
||||
~DivNoNanGradientFunction() override {}
|
||||
|
||||
private:
|
||||
vector<AbstractTensorHandle*> forward_inputs;
|
||||
vector<AbstractTensorHandle*> forward_outputs;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
BackwardFunction* AddRegisterer(const ForwardOperation& op) {
|
||||
@ -239,5 +442,50 @@ BackwardFunction* SqrtRegisterer(const ForwardOperation& op) {
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
BackwardFunction* NegRegisterer(const ForwardOperation& op) {
|
||||
auto gradient_function = new NegGradientFunction;
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
BackwardFunction* SubRegisterer(const ForwardOperation& op) {
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto gradient_function = new SubGradientFunction;
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
BackwardFunction* MulRegisterer(const ForwardOperation& op) {
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto gradient_function = new MulGradientFunction(op.inputs);
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
BackwardFunction* Log1pRegisterer(const ForwardOperation& op) {
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto gradient_function = new Log1pGradientFunction(op.inputs);
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
BackwardFunction* DivNoNanRegisterer(const ForwardOperation& op) {
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto gradient_function = new DivNoNanGradientFunction(op.inputs, op.outputs);
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
@ -24,6 +24,11 @@ BackwardFunction* AddRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* NegRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* SubRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* MulRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* Log1pRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* DivNoNanRegisterer(const ForwardOperation& op);
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
@ -25,7 +25,7 @@ TapeOperation::TapeOperation(AbstractOperation* parent_op, Tape* tape,
|
||||
parent_op_(parent_op),
|
||||
tape_(tape),
|
||||
registry_(registry) {
|
||||
// TODO(srbs): Make AbstractOperation RefCounted.
|
||||
// TODO(b/172003047): Consider making AbstractOperation RefCounted.
|
||||
// parent_op_->Ref();
|
||||
}
|
||||
void TapeOperation::Release() {
|
||||
@ -33,7 +33,7 @@ void TapeOperation::Release() {
|
||||
delete this;
|
||||
}
|
||||
TapeOperation::~TapeOperation() {
|
||||
// TODO(srbs): Make AbstractOperation RefCounted.
|
||||
// TODO(b/172003047): Consider making AbstractOperation RefCounted.
|
||||
// parent_op->Unref();
|
||||
}
|
||||
Status TapeOperation::Reset(const char* op, const char* raw_device_name) {
|
||||
|
43
tensorflow/c/experimental/op_handler/BUILD
Normal file
43
tensorflow/c/experimental/op_handler/BUILD
Normal file
@ -0,0 +1,43 @@
|
||||
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "internal_test",
|
||||
srcs = ["internal_test.cc"],
|
||||
deps = [
|
||||
":internal",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/c/eager:c_api_unified_internal",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "internal",
|
||||
srcs = ["internal.cc"],
|
||||
hdrs = ["internal.h"],
|
||||
deps = [
|
||||
":wrapper_operation",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/eager:abstract_context",
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/core/platform:refcount",
|
||||
"//tensorflow/core/platform:types",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "wrapper_operation",
|
||||
srcs = ["wrapper_operation.cc"],
|
||||
hdrs = ["wrapper_operation.h"],
|
||||
deps = ["//tensorflow/c/eager:abstract_operation"],
|
||||
)
|
79
tensorflow/c/experimental/op_handler/internal.cc
Normal file
79
tensorflow/c/experimental/op_handler/internal.cc
Normal file
@ -0,0 +1,79 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_CC_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_CC_
|
||||
|
||||
#include "tensorflow/c/experimental/op_handler/internal.h"
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/op_handler/wrapper_operation.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
OpHandlerContext::OpHandlerContext(AbstractContext* parent_ctx)
|
||||
: AbstractContext(kOpHandler), parent_ctx_(parent_ctx) {}
|
||||
OpHandlerContext::~OpHandlerContext() {}
|
||||
void OpHandlerContext::Release() { delete this; }
|
||||
Status OpHandlerContext::RegisterFunction(AbstractFunction* function) {
|
||||
return parent_ctx_->RegisterFunction(function);
|
||||
}
|
||||
|
||||
Status OpHandlerContext::RemoveFunction(const string& function) {
|
||||
return parent_ctx_->RemoveFunction(function);
|
||||
}
|
||||
|
||||
void OpHandlerContext::set_default_handler(OpHandler* handler) {
|
||||
handler->Ref();
|
||||
default_handler_.reset(handler);
|
||||
}
|
||||
|
||||
OpHandlerOperation* OpHandlerContext::CreateOperation() {
|
||||
OpHandlerOperation* result =
|
||||
new OpHandlerOperation(parent_ctx_->CreateOperation());
|
||||
if (default_handler_ != nullptr) {
|
||||
result->set_handler(default_handler_.get());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
OpHandlerOperation::OpHandlerOperation(AbstractOperation* parent_op)
|
||||
: WrapperOperation(parent_op, kOpHandler) {}
|
||||
|
||||
OpHandler* OpHandlerOperation::get_handler() { return handler_.get(); }
|
||||
|
||||
void OpHandlerOperation::set_handler(OpHandler* handler) {
|
||||
if (handler != nullptr) {
|
||||
handler->Ref();
|
||||
}
|
||||
handler_.reset(handler);
|
||||
}
|
||||
|
||||
Status OpHandlerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) {
|
||||
if (handler_ == nullptr) {
|
||||
return WrapperOperation::Execute(retvals, num_retvals);
|
||||
} else {
|
||||
return handler_->Execute(this, retvals, num_retvals);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
|
117
tensorflow/c/experimental/op_handler/internal.h
Normal file
117
tensorflow/c/experimental/op_handler/internal.h
Normal file
@ -0,0 +1,117 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/op_handler/wrapper_operation.h"
|
||||
#include "tensorflow/core/platform/refcount.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class OpHandlerOperation;
|
||||
|
||||
// Op handlers are a convenient way to intercept and transform computation.
|
||||
//
|
||||
// The implementation is currently experimental and incomplete, but aims
|
||||
// eventually to support tracing and replay of function bodies, gradients
|
||||
// through copy operations, and a variety of hooks for things like debug
|
||||
// strings. A public C API for op handlers is planned.
|
||||
class OpHandler : public core::RefCounted {
|
||||
public:
|
||||
// Called on operation->Execute when operation->get_handler() == this.
|
||||
//
|
||||
// Allows the handler to customize or inspect `operation`'s execution.
|
||||
virtual Status Execute(OpHandlerOperation* operation,
|
||||
absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) = 0;
|
||||
// Creates a new handler by merging this handler with `next_handler`.
|
||||
//
|
||||
// The new handler is expected to transform operations first with this handler
|
||||
// and then execute the resulting operations on `next_handler` (by calling
|
||||
// `OpHandlerOperation::set_handler` and passing `next_handler`). If this is
|
||||
// not possible then the merge operation should fail.
|
||||
virtual Status Merge(OpHandler* next_handler,
|
||||
core::RefCountPtr<OpHandler>& merged_handler) = 0;
|
||||
};
|
||||
|
||||
// Keeps some handler-specific metadata, but otherwise wraps a single
|
||||
// AbstractOperation in the underlying context. The operation is created, its
|
||||
// attributes set, etc., and at execution time it is presented to its handler,
|
||||
// which may choose to execute it or simply inspect it and do something else.
|
||||
//
|
||||
// This is somewhat different than the Context approach, where the operation's
|
||||
// construction is streamed through each layered Context. The streaming approach
|
||||
// would require a much larger op handler public API, one function pointer per
|
||||
// attribute type, and there is some ambiguity before an op is finalized about
|
||||
// whether it should be presented as-is to handlers (regular operations) or
|
||||
// replayed (function calls and control flow operations).
|
||||
class OpHandlerOperation : public WrapperOperation {
|
||||
public:
|
||||
explicit OpHandlerOperation(AbstractOperation*);
|
||||
OpHandler* get_handler();
|
||||
void set_handler(OpHandler* handler);
|
||||
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) override;
|
||||
|
||||
protected:
|
||||
core::RefCountPtr<OpHandler> handler_;
|
||||
};
|
||||
|
||||
// A context which allows a default handler to be set for new operations. It
|
||||
// otherwise defers to the context it wraps.
|
||||
//
|
||||
// TODO(allenl): A stack of contexts and a stack of handlers look pretty similar
|
||||
// in some ways. Having each handler be its own context seems almost doable,
|
||||
// with things like copy operations and function/control flow replay being
|
||||
// somewhat tricky (since they should be generated at the top of the handler
|
||||
// stack and "caught" at the bottom). After handlers have evolved for a bit we
|
||||
// should re-evaluate whether the handler+context concepts can be merged.
|
||||
class OpHandlerContext : public AbstractContext {
|
||||
public:
|
||||
explicit OpHandlerContext(AbstractContext*);
|
||||
void Release() override;
|
||||
OpHandlerOperation* CreateOperation() override;
|
||||
Status RegisterFunction(AbstractFunction*) override;
|
||||
Status RemoveFunction(const string&) override;
|
||||
// For LLVM style RTTI.
|
||||
static bool classof(const AbstractContext* ptr) {
|
||||
return ptr->getKind() == kOpHandler;
|
||||
}
|
||||
~OpHandlerContext() override;
|
||||
|
||||
void set_default_handler(OpHandler* handler);
|
||||
|
||||
private:
|
||||
AbstractContext* parent_ctx_; // Not owned.
|
||||
core::RefCountPtr<OpHandler> default_handler_;
|
||||
};
|
||||
|
||||
class ReleaseOpHandlerOperation {
|
||||
public:
|
||||
void operator()(OpHandlerOperation* operation) { operation->Release(); }
|
||||
};
|
||||
|
||||
typedef std::unique_ptr<OpHandlerOperation, ReleaseOpHandlerOperation>
|
||||
OpHandlerOperationPtr;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
|
102
tensorflow/c/experimental/op_handler/internal_test.cc
Normal file
102
tensorflow/c/experimental/op_handler/internal_test.cc
Normal file
@ -0,0 +1,102 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/op_handler/internal.h"
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TestOpHandler : public OpHandler {
|
||||
public:
|
||||
TestOpHandler() : last_operation_(new std::string("")) {}
|
||||
Status Execute(OpHandlerOperation* operation,
|
||||
absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) override {
|
||||
CHECK(operation->get_handler() == this);
|
||||
*last_operation_ = operation->Name();
|
||||
operation->set_handler(next_handler_.get());
|
||||
return operation->Execute(retvals, num_retvals);
|
||||
}
|
||||
Status Merge(OpHandler* next_handler,
|
||||
core::RefCountPtr<OpHandler>& merged_handler) override {
|
||||
merged_handler.reset(new TestOpHandler(next_handler, last_operation_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
core::RefCountPtr<OpHandler> next_handler_ = nullptr;
|
||||
// Shared between merged handlers of this type.
|
||||
std::shared_ptr<std::string> last_operation_;
|
||||
|
||||
private:
|
||||
TestOpHandler(OpHandler* next_handler,
|
||||
std::shared_ptr<std::string> last_operation)
|
||||
: next_handler_(next_handler), last_operation_(last_operation) {
|
||||
next_handler->Ref();
|
||||
}
|
||||
};
|
||||
|
||||
TEST(INTERNAL_TEST, UseOpHandler) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TF_ExecutionContext, decltype(&TF_DeleteExecutionContext)>
|
||||
c_ctx(TF_NewEagerExecutionContext(opts.get(), status.get()),
|
||||
TF_DeleteExecutionContext);
|
||||
OpHandlerContext ctx(unwrap(c_ctx.get()));
|
||||
core::RefCountPtr<TestOpHandler> outer_handler(new TestOpHandler());
|
||||
core::RefCountPtr<TestOpHandler> inner_handler(new TestOpHandler());
|
||||
ctx.set_default_handler(outer_handler.get());
|
||||
OpHandlerOperationPtr op(ctx.CreateOperation());
|
||||
Status s = op->Reset("NoOp", "");
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
std::vector<AbstractTensorHandle*> retvals;
|
||||
int num_retvals = 0;
|
||||
EXPECT_EQ("", *outer_handler->last_operation_);
|
||||
s = op->Execute(absl::Span<AbstractTensorHandle*>(retvals), &num_retvals);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
EXPECT_EQ("NoOp", *outer_handler->last_operation_);
|
||||
*outer_handler->last_operation_ = "";
|
||||
EXPECT_EQ("", *inner_handler->last_operation_);
|
||||
|
||||
// This op executes on both handlers, changing the state of `inner_handler`
|
||||
// since the handler has decided to preserve that state across merges.
|
||||
core::RefCountPtr<OpHandler> merged;
|
||||
s = inner_handler->Merge(outer_handler.get(), merged);
|
||||
ctx.set_default_handler(merged.get());
|
||||
op.reset(ctx.CreateOperation());
|
||||
s = op->Reset("NoOp", "");
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
s = op->Execute(absl::Span<AbstractTensorHandle*>(retvals), &num_retvals);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
EXPECT_EQ("NoOp", *inner_handler->last_operation_);
|
||||
EXPECT_EQ("NoOp", *outer_handler->last_operation_);
|
||||
|
||||
inner_handler.reset();
|
||||
outer_handler.reset();
|
||||
op.reset(ctx.CreateOperation());
|
||||
s = op->Reset("NoOp", "");
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
s = op->Execute(absl::Span<AbstractTensorHandle*>(retvals), &num_retvals);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
120
tensorflow/c/experimental/op_handler/wrapper_operation.cc
Normal file
120
tensorflow/c/experimental/op_handler/wrapper_operation.cc
Normal file
@ -0,0 +1,120 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/op_handler/wrapper_operation.h"
|
||||
|
||||
namespace tensorflow {
|
||||
WrapperOperation::WrapperOperation(AbstractOperation* parent_op,
|
||||
AbstractOperationKind kind)
|
||||
: AbstractOperation(kind), parent_op_(parent_op) {
|
||||
// TODO(b/172003047): Consider making AbstractOperation RefCounted.
|
||||
// parent_op_->Ref();
|
||||
}
|
||||
void WrapperOperation::Release() {
|
||||
parent_op_->Release();
|
||||
// TODO(b/172003047): Consider making AbstractOperation RefCounted.
|
||||
delete this;
|
||||
}
|
||||
|
||||
Status WrapperOperation::Reset(const char* op, const char* raw_device_name) {
|
||||
return parent_op_->Reset(op, raw_device_name);
|
||||
}
|
||||
const string& WrapperOperation::Name() const { return parent_op_->Name(); }
|
||||
const string& WrapperOperation::DeviceName() const {
|
||||
return parent_op_->DeviceName();
|
||||
}
|
||||
Status WrapperOperation::SetDeviceName(const char* name) {
|
||||
return parent_op_->SetDeviceName(name);
|
||||
}
|
||||
Status WrapperOperation::AddInput(AbstractTensorHandle* input) {
|
||||
return parent_op_->AddInput(input);
|
||||
}
|
||||
Status WrapperOperation::AddInputList(
|
||||
absl::Span<AbstractTensorHandle* const> inputs) {
|
||||
return parent_op_->AddInputList(inputs);
|
||||
}
|
||||
Status WrapperOperation::SetAttrString(const char* attr_name, const char* data,
|
||||
size_t length) {
|
||||
return parent_op_->SetAttrString(attr_name, data, length);
|
||||
}
|
||||
Status WrapperOperation::SetAttrInt(const char* attr_name, int64_t value) {
|
||||
return parent_op_->SetAttrInt(attr_name, value);
|
||||
}
|
||||
Status WrapperOperation::SetAttrFloat(const char* attr_name, float value) {
|
||||
return parent_op_->SetAttrFloat(attr_name, value);
|
||||
}
|
||||
Status WrapperOperation::SetAttrBool(const char* attr_name, bool value) {
|
||||
return parent_op_->SetAttrBool(attr_name, value);
|
||||
}
|
||||
Status WrapperOperation::SetAttrType(const char* attr_name, DataType value) {
|
||||
return parent_op_->SetAttrType(attr_name, value);
|
||||
}
|
||||
Status WrapperOperation::SetAttrShape(const char* attr_name,
|
||||
const int64_t* dims, const int num_dims) {
|
||||
return parent_op_->SetAttrShape(attr_name, dims, num_dims);
|
||||
}
|
||||
Status WrapperOperation::SetAttrFunction(const char* attr_name,
|
||||
const AbstractOperation* value) {
|
||||
return parent_op_->SetAttrFunction(attr_name, value);
|
||||
}
|
||||
Status WrapperOperation::SetAttrFunctionName(const char* attr_name,
|
||||
const char* value, size_t length) {
|
||||
return parent_op_->SetAttrFunctionName(attr_name, value, length);
|
||||
}
|
||||
Status WrapperOperation::SetAttrTensor(const char* attr_name,
|
||||
AbstractTensorInterface* tensor) {
|
||||
return parent_op_->SetAttrTensor(attr_name, tensor);
|
||||
}
|
||||
Status WrapperOperation::SetAttrStringList(const char* attr_name,
|
||||
const void* const* values,
|
||||
const size_t* lengths,
|
||||
int num_values) {
|
||||
return parent_op_->SetAttrStringList(attr_name, values, lengths, num_values);
|
||||
}
|
||||
Status WrapperOperation::SetAttrFloatList(const char* attr_name,
|
||||
const float* values, int num_values) {
|
||||
return parent_op_->SetAttrFloatList(attr_name, values, num_values);
|
||||
}
|
||||
Status WrapperOperation::SetAttrIntList(const char* attr_name,
|
||||
const int64_t* values, int num_values) {
|
||||
return parent_op_->SetAttrIntList(attr_name, values, num_values);
|
||||
}
|
||||
Status WrapperOperation::SetAttrTypeList(const char* attr_name,
|
||||
const DataType* values,
|
||||
int num_values) {
|
||||
return parent_op_->SetAttrTypeList(attr_name, values, num_values);
|
||||
}
|
||||
Status WrapperOperation::SetAttrBoolList(const char* attr_name,
|
||||
const unsigned char* values,
|
||||
int num_values) {
|
||||
return parent_op_->SetAttrBoolList(attr_name, values, num_values);
|
||||
}
|
||||
Status WrapperOperation::SetAttrShapeList(const char* attr_name,
|
||||
const int64_t** dims,
|
||||
const int* num_dims, int num_values) {
|
||||
return parent_op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
|
||||
}
|
||||
Status WrapperOperation::SetAttrFunctionList(
|
||||
const char* attr_name, absl::Span<const AbstractOperation*> values) {
|
||||
return parent_op_->SetAttrFunctionList(attr_name, values);
|
||||
}
|
||||
AbstractOperation* WrapperOperation::GetBackingOperation() {
|
||||
return parent_op_;
|
||||
}
|
||||
Status WrapperOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) {
|
||||
return parent_op_->Execute(retvals, num_retvals);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
74
tensorflow/c/experimental/op_handler/wrapper_operation.h
Normal file
74
tensorflow/c/experimental/op_handler/wrapper_operation.h
Normal file
@ -0,0 +1,74 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_
|
||||
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Forwards all of the AbstractOperation's methods to its wrapped operation.
|
||||
//
|
||||
// Useful as a base class to default to forwarding while adding some
|
||||
// customization.
|
||||
class WrapperOperation : public AbstractOperation {
|
||||
public:
|
||||
explicit WrapperOperation(AbstractOperation*, AbstractOperationKind kind);
|
||||
void Release() override;
|
||||
Status Reset(const char* op, const char* raw_device_name) override;
|
||||
const string& Name() const override;
|
||||
const string& DeviceName() const override;
|
||||
Status SetDeviceName(const char* name) override;
|
||||
Status AddInput(AbstractTensorHandle* input) override;
|
||||
Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
|
||||
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) override;
|
||||
Status SetAttrString(const char* attr_name, const char* data,
|
||||
size_t length) override;
|
||||
Status SetAttrInt(const char* attr_name, int64_t value) override;
|
||||
Status SetAttrFloat(const char* attr_name, float value) override;
|
||||
Status SetAttrBool(const char* attr_name, bool value) override;
|
||||
Status SetAttrType(const char* attr_name, DataType value) override;
|
||||
Status SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||
const int num_dims) override;
|
||||
Status SetAttrFunction(const char* attr_name,
|
||||
const AbstractOperation* value) override;
|
||||
Status SetAttrFunctionName(const char* attr_name, const char* value,
|
||||
size_t length) override;
|
||||
Status SetAttrTensor(const char* attr_name,
|
||||
AbstractTensorInterface* tensor) override;
|
||||
Status SetAttrStringList(const char* attr_name, const void* const* values,
|
||||
const size_t* lengths, int num_values) override;
|
||||
Status SetAttrFloatList(const char* attr_name, const float* values,
|
||||
int num_values) override;
|
||||
Status SetAttrIntList(const char* attr_name, const int64_t* values,
|
||||
int num_values) override;
|
||||
Status SetAttrTypeList(const char* attr_name, const DataType* values,
|
||||
int num_values) override;
|
||||
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
|
||||
int num_values) override;
|
||||
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
|
||||
const int* num_dims, int num_values) override;
|
||||
Status SetAttrFunctionList(
|
||||
const char* attr_name,
|
||||
absl::Span<const AbstractOperation*> values) override;
|
||||
AbstractOperation* GetBackingOperation();
|
||||
|
||||
private:
|
||||
AbstractOperation* parent_op_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_
|
@ -81,5 +81,17 @@ Status ExpandDims(AbstractContext* ctx,
|
||||
return op->Execute(outputs, &num_retvals);
|
||||
}
|
||||
|
||||
Status OnesLike(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(op.get(), name));
|
||||
TF_RETURN_IF_ERROR(op->AddInput(inputs[0]));
|
||||
|
||||
int num_retvals = 1;
|
||||
return op->Execute(outputs, &num_retvals);
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
@ -42,6 +42,10 @@ Status ExpandDims(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status OnesLike(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -44,8 +44,18 @@ Status Conj(AbstractContext* ctx,
|
||||
if (DataTypeIsFloating(BaseType(dtype)) ||
|
||||
DataTypeIsInteger(BaseType(dtype))) {
|
||||
TF_RETURN_IF_ERROR(Identity(ctx, inputs, outputs, name));
|
||||
} else if (DataTypeIsComplex(BaseType(dtype)) ||
|
||||
BaseType(dtype) == DT_VARIANT) {
|
||||
AbstractOperationPtr conj_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(conj_op->Reset("Conj", /*raw_device_name=*/nullptr));
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(conj_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(conj_op->AddInput(inputs[0]));
|
||||
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(conj_op->Execute(outputs, &num_retvals));
|
||||
} else {
|
||||
return errors::Unimplemented("Conj does not support complex types yet.");
|
||||
return errors::InvalidArgument(
|
||||
"Expected numeric or variant tensor, got dtype ", dtype);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -118,6 +128,19 @@ Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Div(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr div_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(div_op->Reset("Div", /*raw_device_name=*/nullptr));
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(div_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(div_op->AddInput(inputs[0])); // x
|
||||
TF_RETURN_IF_ERROR(div_op->AddInput(inputs[1])); // y
|
||||
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(div_op->Execute(outputs, &num_retvals)); // z = x / y
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DivNoNan(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
@ -172,5 +195,18 @@ Status SqrtGrad(AbstractContext* ctx,
|
||||
return s;
|
||||
}
|
||||
|
||||
Status Log1p(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr log1p_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(log1p_op->Reset("Log1p", /*raw_device_name=*/nullptr));
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(log1p_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(log1p_op->AddInput(inputs[0]));
|
||||
|
||||
int num_retvals = 1;
|
||||
Status s = log1p_op->Execute(outputs, &num_retvals);
|
||||
return s;
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
@ -44,6 +44,9 @@ Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status Div(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status DivNoNan(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
@ -59,6 +62,10 @@ Status SqrtGrad(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status Log1p(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -146,6 +146,7 @@ cc_library(
|
||||
":tf_signature_def_function",
|
||||
":variable",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -343,7 +343,8 @@ Status InitializeCreateResourceFunctions(ImmediateExecutionContext* ctx,
|
||||
std::unique_ptr<TFConcreteFunction> out;
|
||||
TF_RETURN_IF_ERROR(CreateConcreteFunction(ctx, *create_resource_fn,
|
||||
obj_graph, objects, &out));
|
||||
revived->concrete_functions[create_resource_fn->node_id] = std::move(out);
|
||||
revived->concrete_functions.Insert(std::move(out),
|
||||
create_resource_fn->node_id);
|
||||
}
|
||||
return Status();
|
||||
}
|
||||
@ -352,8 +353,6 @@ Status InitializeAllFunctions(ImmediateExecutionContext* ctx,
|
||||
const SavedObjectGraph& obj_graph,
|
||||
const PartiallyRevivedObjects& objects,
|
||||
RevivedObjects* revived) {
|
||||
gtl::FlatMap<int, std::unique_ptr<TFConcreteFunction>>* destination_func_map =
|
||||
&revived->concrete_functions;
|
||||
gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>>*
|
||||
destination_sig_map = &revived->signature_def_functions;
|
||||
|
||||
@ -361,7 +360,7 @@ Status InitializeAllFunctions(ImmediateExecutionContext* ctx,
|
||||
int node_id = id_and_func.first;
|
||||
const TFConcreteFunctionRevivalState& func = id_and_func.second;
|
||||
|
||||
if (destination_func_map->find(node_id) != destination_func_map->end()) {
|
||||
if (revived->concrete_functions.Find(node_id)) {
|
||||
// The function has already been initialized in the destination_map,
|
||||
// so we can skip this node. This can occur because we initialize
|
||||
// CreateResource functions before calling this function.
|
||||
@ -371,7 +370,7 @@ Status InitializeAllFunctions(ImmediateExecutionContext* ctx,
|
||||
std::unique_ptr<TFConcreteFunction> out;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateConcreteFunction(ctx, func, obj_graph, objects, &out));
|
||||
(*destination_func_map)[node_id] = std::move(out);
|
||||
revived->concrete_functions.Insert(std::move(out), node_id);
|
||||
}
|
||||
|
||||
for (const auto& id_and_func : objects.signature_def_functions) {
|
||||
@ -398,20 +397,16 @@ Status CreateAllResourceHandles(ImmediateExecutionContext* ctx,
|
||||
for (auto& id_and_resource : objects->restored_resources) {
|
||||
RestoredResourceRevivalState& resource = id_and_resource.second;
|
||||
int create_resource_fn_node = resource.create_resource->node_id;
|
||||
const gtl::FlatMap<int, std::unique_ptr<TFConcreteFunction>>&
|
||||
revived_functions = revived->concrete_functions;
|
||||
|
||||
const auto& revived_functions_iter =
|
||||
revived_functions.find(create_resource_fn_node);
|
||||
if (revived_functions_iter == revived_functions.end()) {
|
||||
const TFConcreteFunction* create_resource_fn =
|
||||
revived->concrete_functions.Find(create_resource_fn_node);
|
||||
if (create_resource_fn == nullptr) {
|
||||
return errors::FailedPrecondition(
|
||||
"ConcreteFunction at node ", create_resource_fn_node,
|
||||
" should have been initialized prior to being called.");
|
||||
}
|
||||
const TFConcreteFunction& create_resource_fn =
|
||||
*revived_functions_iter->second;
|
||||
ImmediateOpPtr function_op;
|
||||
TF_RETURN_IF_ERROR(create_resource_fn.MakeCallOp({}, &function_op));
|
||||
TF_RETURN_IF_ERROR(create_resource_fn->MakeCallOp({}, &function_op));
|
||||
TF_RETURN_IF_ERROR(function_op->SetDeviceName(resource.device.c_str()));
|
||||
|
||||
AbstractTensorHandle* resource_handle = nullptr;
|
||||
@ -431,21 +426,6 @@ Status CreateAllResourceHandles(ImmediateExecutionContext* ctx,
|
||||
return Status();
|
||||
}
|
||||
|
||||
// Finds a ConcreteFunction with node id `node` in `objects`, and sets *out to
|
||||
// point to it. If node doesn't exist in `objects`, out is untouched, and an
|
||||
// error status is returned.
|
||||
Status FindConcreteFunction(int node, RevivedObjects* objects,
|
||||
TFConcreteFunction** out) {
|
||||
auto func_iter = objects->concrete_functions.find(node);
|
||||
if (func_iter == objects->concrete_functions.end()) {
|
||||
return errors::FailedPrecondition(
|
||||
"Failed to find ConcreteFunction with node id ", node,
|
||||
" in revived objects");
|
||||
}
|
||||
*out = func_iter->second.get();
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status BuildResources(ImmediateExecutionContext* ctx,
|
||||
const SavedObjectGraph& obj_graph,
|
||||
PartiallyRevivedObjects* objects,
|
||||
@ -460,22 +440,35 @@ Status BuildResources(ImmediateExecutionContext* ctx,
|
||||
// Check all the functions associated with the resource have already been
|
||||
// initialized in `revived`
|
||||
if (resource_revival_state.create_resource != nullptr) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
FindConcreteFunction(resource_revival_state.create_resource->node_id,
|
||||
revived, &create_resource));
|
||||
create_resource = revived->concrete_functions.Find(
|
||||
resource_revival_state.create_resource->node_id);
|
||||
if (create_resource == nullptr) {
|
||||
return errors::FailedPrecondition(
|
||||
"'create_resource' function with node id ",
|
||||
resource_revival_state.create_resource->node_id, " not found");
|
||||
}
|
||||
}
|
||||
|
||||
TFConcreteFunction* initialize = nullptr;
|
||||
if (resource_revival_state.initialize != nullptr) {
|
||||
TF_RETURN_IF_ERROR(FindConcreteFunction(
|
||||
resource_revival_state.initialize->node_id, revived, &initialize));
|
||||
initialize = revived->concrete_functions.Find(
|
||||
resource_revival_state.initialize->node_id);
|
||||
if (initialize == nullptr) {
|
||||
return errors::FailedPrecondition(
|
||||
"'initialize' function with node id ",
|
||||
resource_revival_state.initialize->node_id, " not found");
|
||||
}
|
||||
}
|
||||
|
||||
TFConcreteFunction* destroy_resource = nullptr;
|
||||
if (resource_revival_state.destroy_resource != nullptr) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
FindConcreteFunction(resource_revival_state.destroy_resource->node_id,
|
||||
revived, &destroy_resource));
|
||||
destroy_resource = revived->concrete_functions.Find(
|
||||
resource_revival_state.destroy_resource->node_id);
|
||||
if (destroy_resource == nullptr) {
|
||||
return errors::FailedPrecondition(
|
||||
"'destroy_resource' function with node id ",
|
||||
resource_revival_state.destroy_resource->node_id, " not found");
|
||||
}
|
||||
}
|
||||
|
||||
if (resource_revival_state.resource_handle == nullptr) {
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h"
|
||||
@ -29,6 +30,43 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// A container for revived saved model objects.
|
||||
//
|
||||
// Most of the objects will be revived from nodes in the object graph, and for
|
||||
// those objects this container provides a map from node id to the revived
|
||||
// objects.
|
||||
//
|
||||
// For objects that have to be revived but are not part of the object graph,
|
||||
// this container provides a place where the objects can be stored so they are
|
||||
// available to the runtime.
|
||||
template <typename T>
|
||||
class RevivedObjectContainer {
|
||||
public:
|
||||
// Insert an object that is not related to a node id. This usually means the
|
||||
// object was not referenced by the object_graph, but is needed by other
|
||||
// objects.
|
||||
void Insert(std::unique_ptr<T> object) {
|
||||
objects_.push_back(std::move(object));
|
||||
}
|
||||
|
||||
// Insert an object that is tied to the given object graph node id.
|
||||
void Insert(std::unique_ptr<T> object, int node_id) {
|
||||
objects_by_id_[node_id] = object.get();
|
||||
Insert(std::move(object));
|
||||
}
|
||||
|
||||
// Find an object by the object graph node id.
|
||||
// Returns nullptr if there is no such object.
|
||||
T* Find(int node_id) {
|
||||
auto it = objects_by_id_.find(node_id);
|
||||
return it == objects_by_id_.end() ? nullptr : it->second;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<T>> objects_;
|
||||
absl::flat_hash_map<int, T*> objects_by_id_;
|
||||
};
|
||||
|
||||
// RevivedObjects is mainly used as a container for all the "state" owned by
|
||||
// SavedModel. It stores all non-"user object" nodes from a SavedModel
|
||||
// (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L57-L62)
|
||||
@ -37,12 +75,14 @@ namespace tensorflow {
|
||||
// (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L25-L29)
|
||||
// to the revived object of the corresponding type.
|
||||
struct RevivedObjects {
|
||||
// Order of declaration is important here: we want the RestoredResources to be
|
||||
// freed after TFConcreteFunctions, for example.
|
||||
gtl::FlatMap<int, std::unique_ptr<Variable>> variables;
|
||||
gtl::FlatMap<int, std::unique_ptr<Asset>> assets;
|
||||
gtl::FlatMap<int, std::unique_ptr<Constant>> constants;
|
||||
gtl::FlatMap<int, std::unique_ptr<TFConcreteFunction>> concrete_functions;
|
||||
gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>>
|
||||
signature_def_functions;
|
||||
RevivedObjectContainer<TFConcreteFunction> concrete_functions;
|
||||
gtl::FlatMap<int, RestoredResource> restored_resources;
|
||||
gtl::FlatMap<std::string, int> signatures_map;
|
||||
};
|
||||
|
@ -46,8 +46,6 @@ class SavedModelAPI {
|
||||
virtual Status GetSignatureDefFunction(const std::string& signature_def_key,
|
||||
SignatureDefFunction** function) = 0;
|
||||
|
||||
virtual std::vector<ConcreteFunction*> ListFunctions() = 0;
|
||||
|
||||
virtual ~SavedModelAPI() = default;
|
||||
};
|
||||
|
||||
|
@ -73,7 +73,6 @@ using FlatTensorFunctionMap =
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
const TrackableObjectGraph::TrackableObject::SerializedTensor*
|
||||
FindSerializedTensorInTrackable(
|
||||
const TrackableObjectGraph::TrackableObject& trackable_object,
|
||||
@ -181,12 +180,11 @@ Status TFSavedModelAPI::GetFunction(const std::string& function_path,
|
||||
return errors::NotFound("No saved object found at path ", function_path);
|
||||
}
|
||||
|
||||
auto function_iter = revived_objects_.concrete_functions.find(*node);
|
||||
if (function_iter == revived_objects_.concrete_functions.end()) {
|
||||
*function = revived_objects_.concrete_functions.Find(*node);
|
||||
if (*function == nullptr) {
|
||||
return errors::NotFound("No function found at path ", function_path);
|
||||
}
|
||||
|
||||
*function = function_iter->second.get();
|
||||
return Status();
|
||||
}
|
||||
|
||||
@ -211,15 +209,6 @@ Status TFSavedModelAPI::GetSignatureDefFunction(
|
||||
return Status();
|
||||
}
|
||||
|
||||
std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
|
||||
std::vector<ConcreteFunction*> result;
|
||||
result.reserve(revived_objects_.concrete_functions.size());
|
||||
for (auto& index_and_function : revived_objects_.concrete_functions) {
|
||||
result.push_back(index_and_function.second.get());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Status TFSavedModelAPI::GetVariable(const std::string& variable_path,
|
||||
Variable** variable) {
|
||||
absl::optional<int> node =
|
||||
@ -263,10 +252,10 @@ Status TFSavedModelAPI::Load(
|
||||
// This occurs in python here:
|
||||
// https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454
|
||||
|
||||
// Step 1: For each node in the graph, we should initialize an object of the
|
||||
// For each node in the graph, we should initialize an object of the
|
||||
// corresponding type. For objects that depend on the initialization of other
|
||||
// objects (like functions which capture resources), we will initialize them
|
||||
// in step 2.
|
||||
// later.
|
||||
PartiallyRevivedObjects partially_revived_objects;
|
||||
TF_RETURN_IF_ERROR(internal::PartiallyReviveSavedModelObjects(
|
||||
bundle.meta_graph_def(), context, directory, &partially_revived_objects));
|
||||
@ -275,6 +264,22 @@ Status TFSavedModelAPI::Load(
|
||||
TF_RETURN_IF_ERROR(partially_revived_objects.Build(
|
||||
context, bundle.saved_object_graph(), &revived_objects));
|
||||
|
||||
// Revive function library functions as concrete functions without captures.
|
||||
// This is necessary because object graph functions may refer to functions
|
||||
// _not_ in the object graph: A while loop, for example, will create two
|
||||
// auxiliary `while_cond` and `while_body` functions that are only present in
|
||||
// the graph def function library.
|
||||
for (const FunctionDef& function :
|
||||
bundle.meta_graph_def().graph_def().library().function()) {
|
||||
std::unique_ptr<TFConcreteFunction> concrete_function;
|
||||
TF_RETURN_IF_ERROR(TFConcreteFunction::Create(/*function_def=*/&function,
|
||||
/*captures=*/{},
|
||||
/*metadata=*/{},
|
||||
/*ctx=*/context,
|
||||
/*out=*/&concrete_function));
|
||||
revived_objects.concrete_functions.Insert(std::move(concrete_function));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
RestoreCheckpoint(&bundle, revived_objects, directory, context));
|
||||
|
||||
|
@ -66,8 +66,6 @@ class TFSavedModelAPI : public SavedModelAPI {
|
||||
ImmediateExecutionContext* context,
|
||||
std::unique_ptr<TFSavedModelAPI>* out);
|
||||
|
||||
std::vector<ConcreteFunction*> ListFunctions() override;
|
||||
|
||||
~TFSavedModelAPI() override = default;
|
||||
|
||||
Status GetVariable(const std::string& variable_path, Variable** variable);
|
||||
|
@ -122,9 +122,4 @@ TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model,
|
||||
return tensorflow::wrap(result);
|
||||
}
|
||||
|
||||
TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
|
||||
return new TF_ConcreteFunctionList{
|
||||
tensorflow::unwrap(model)->ListFunctions()};
|
||||
}
|
||||
|
||||
} // end extern "C"
|
||||
|
@ -524,6 +524,62 @@ TEST_P(CSavedModelAPITest, LoadSavedModelWithUninitializedVariable) {
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST_P(CSavedModelAPITest, LoadSavedModelWithWhileLoop) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TF_DeleteStatus(status);
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
std::string model_dir = tensorflow::io::JoinPath(
|
||||
tensorflow::testing::TensorFlowSrcRoot(),
|
||||
"c/experimental/saved_model/internal/testdata/SimpleWhileLoop");
|
||||
|
||||
TF_SavedModel* saved_model =
|
||||
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TF_ConcreteFunction* while_fn =
|
||||
TF_GetSavedModelConcreteFunction(saved_model, "compute", status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
std::vector<TFE_TensorHandle*> while_fn_inputs;
|
||||
while_fn_inputs.push_back(TestScalarTensorHandle(ctx, 10.0f));
|
||||
|
||||
TFE_Op* while_fn_op = TF_ConcreteFunctionMakeCallOp(
|
||||
while_fn, while_fn_inputs.data(), while_fn_inputs.size(), status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* while_fn_outputs[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
|
||||
TFE_Execute(while_fn_op, &while_fn_outputs[0], &num_retvals, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TF_Tensor* result = TFE_TensorHandleResolve(while_fn_outputs[0], status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
ASSERT_EQ(TF_NumDims(result), 0);
|
||||
float output_value = *static_cast<float*>(TF_TensorData(result));
|
||||
ASSERT_FLOAT_EQ(output_value, 55); // 10+9+...+1
|
||||
|
||||
TF_DeleteTensor(result);
|
||||
TFE_DeleteTensorHandle(while_fn_outputs[0]);
|
||||
TFE_DeleteOp(while_fn_op);
|
||||
TFE_DeleteTensorHandle(while_fn_inputs[0]);
|
||||
TF_DeleteSavedModel(saved_model);
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest,
|
||||
::testing::Bool());
|
||||
|
||||
|
@ -12,6 +12,8 @@ py_strict_binary(
|
||||
srcs = ["gen_saved_models.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
@ -21,7 +23,6 @@ py_strict_binary(
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/module",
|
||||
"//tensorflow/python/saved_model",
|
||||
"//tensorflow/python/saved_model:save_options",
|
||||
],
|
||||
)
|
||||
|
||||
@ -29,6 +30,7 @@ py_strict_binary(
|
||||
filegroup(
|
||||
name = "saved_models",
|
||||
srcs = glob([
|
||||
"SimpleWhileLoop/**",
|
||||
"UninitializedVariable/**",
|
||||
]),
|
||||
visibility = [
|
||||
|
BIN
tensorflow/c/experimental/saved_model/internal/testdata/SimpleWhileLoop/saved_model.pb
vendored
Normal file
BIN
tensorflow/c/experimental/saved_model/internal/testdata/SimpleWhileLoop/saved_model.pb
vendored
Normal file
Binary file not shown.
Binary file not shown.
BIN
tensorflow/c/experimental/saved_model/internal/testdata/SimpleWhileLoop/variables/variables.index
vendored
Normal file
BIN
tensorflow/c/experimental/saved_model/internal/testdata/SimpleWhileLoop/variables/variables.index
vendored
Normal file
Binary file not shown.
Binary file not shown.
@ -30,9 +30,11 @@ import os
|
||||
from tensorflow.python.compat import v2_compat
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import app
|
||||
@ -72,11 +74,32 @@ def _gen_uninitialized_variable(base_dir):
|
||||
to_save, export_dir=os.path.join(base_dir, "UninitializedVariable"))
|
||||
|
||||
|
||||
def _gen_simple_while_loop(base_dir):
|
||||
"""Generates a saved model with a while loop."""
|
||||
|
||||
class Module(module.Module):
|
||||
"""A module with a while loop."""
|
||||
|
||||
@def_function.function(
|
||||
input_signature=[tensor_spec.TensorSpec((), dtypes.float32)])
|
||||
def compute(self, value):
|
||||
acc, _ = control_flow_ops.while_loop(
|
||||
cond=lambda acc, i: i > 0,
|
||||
body=lambda acc, i: (acc + i, i - 1),
|
||||
loop_vars=(constant_op.constant(0.0), value))
|
||||
return acc
|
||||
|
||||
to_save = Module()
|
||||
saved_model.save(
|
||||
to_save, export_dir=os.path.join(base_dir, "SimpleWhileLoop"))
|
||||
|
||||
|
||||
def main(args):
|
||||
if len(args) != 2:
|
||||
raise app.UsageError("Expected one argument (base_dir).")
|
||||
_, base_dir = args
|
||||
_gen_uninitialized_variable(base_dir)
|
||||
_gen_simple_while_loop(base_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -100,11 +100,6 @@ TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model,
|
||||
const char* signature_def_key,
|
||||
TF_Status* status);
|
||||
|
||||
// Returns a list of all ConcreteFunctions stored in this SavedModel.
|
||||
// The lifetime of the returned list is bound to `model`.
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunctionList* TF_ListSavedModelFunctions(
|
||||
TF_SavedModel* model);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
@ -7,15 +7,36 @@ load(
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "filegroup")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "headers",
|
||||
srcs = [
|
||||
"stream_executor.h",
|
||||
],
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "stream_executor_hdrs",
|
||||
hdrs = ["stream_executor.h"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c:tf_status_headers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "stream_executor",
|
||||
srcs = ["stream_executor.cc"],
|
||||
hdrs = ["stream_executor.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":stream_executor_internal",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
@ -39,9 +60,11 @@ cc_library(
|
||||
"stream_executor.h",
|
||||
"stream_executor_internal.h",
|
||||
],
|
||||
visibility = ["//tensorflow/c:__subpackages__"],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/stream_executor:executor_cache",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
],
|
||||
|
@ -24,7 +24,6 @@ limitations under the License.
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -44,6 +43,7 @@ using tensorflow::StatusFromTF_Status;
|
||||
|
||||
namespace stream_executor {
|
||||
using tensorflow::StringPiece;
|
||||
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
|
||||
|
||||
namespace {
|
||||
|
||||
@ -188,41 +188,6 @@ port::Status ValidateSEPlatformRegistrationParams(
|
||||
}
|
||||
#undef VALIDATE_MEMBER
|
||||
|
||||
struct TFStatusDeleter {
|
||||
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
|
||||
};
|
||||
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
|
||||
|
||||
class CStream : public internal::StreamInterface {
|
||||
public:
|
||||
CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
stream_handle_(nullptr) {}
|
||||
~CStream() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
|
||||
port::Status s = StatusFromTF_Status(c_status.get());
|
||||
return s;
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (stream_handle_ != nullptr) {
|
||||
stream_executor_->destroy_stream(device_, stream_handle_);
|
||||
stream_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Stream Handle() { return stream_handle_; }
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Stream stream_handle_;
|
||||
};
|
||||
|
||||
// Converts SE_EventStatus to Event::Status.
|
||||
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
|
||||
switch (s) {
|
||||
@ -237,82 +202,6 @@ Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
|
||||
}
|
||||
}
|
||||
|
||||
class CEvent : public internal::EventInterface {
|
||||
public:
|
||||
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
event_handle_(nullptr) {}
|
||||
~CEvent() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->create_event(device_, &event_handle_, c_status.get());
|
||||
return StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
port::Status Record(SP_Stream stream_handle) {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->record_event(device_, stream_handle, event_handle_,
|
||||
c_status.get());
|
||||
return StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (event_handle_ != nullptr) {
|
||||
stream_executor_->destroy_event(device_, event_handle_);
|
||||
event_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Event Handle() { return event_handle_; }
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Event event_handle_;
|
||||
};
|
||||
|
||||
class CTimer : public internal::TimerInterface {
|
||||
public:
|
||||
CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
|
||||
SP_TimerFns* timer_fns)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
timer_handle_(nullptr),
|
||||
timer_fns_(timer_fns) {}
|
||||
~CTimer() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
|
||||
return StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (timer_handle_ != nullptr) {
|
||||
stream_executor_->destroy_timer(device_, timer_handle_);
|
||||
timer_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Timer Handle() { return timer_handle_; }
|
||||
|
||||
uint64 Microseconds() const override {
|
||||
return timer_fns_->nanoseconds(timer_handle_) / 1000;
|
||||
}
|
||||
|
||||
uint64 Nanoseconds() const override {
|
||||
return timer_fns_->nanoseconds(timer_handle_);
|
||||
}
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Timer timer_handle_;
|
||||
SP_TimerFns* timer_fns_;
|
||||
};
|
||||
|
||||
// Converts DeviceMemoryBase to a C struct.
|
||||
SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
|
||||
SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
|
||||
@ -321,14 +210,12 @@ SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
|
||||
device_memory_base.opaque = const_cast<void*>(mem->opaque());
|
||||
device_memory_base.size = mem->size();
|
||||
device_memory_base.payload = mem->payload();
|
||||
// TODO(annarev): Add `ext` field to DeviceMemoryBase and set it here.
|
||||
return device_memory_base;
|
||||
}
|
||||
|
||||
DeviceMemoryBase DeviceMemoryBaseFromC(const SP_DeviceMemoryBase& mem) {
|
||||
DeviceMemoryBase base(mem.opaque, mem.size);
|
||||
base.SetPayload(mem.payload);
|
||||
// TODO(annarev): Add `ext` field to DeviceMemoryBase and set it here.
|
||||
return base;
|
||||
}
|
||||
|
||||
@ -426,7 +313,6 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
|
||||
LOG(ERROR) << status.error_message();
|
||||
return absl::nullopt;
|
||||
}
|
||||
// TODO(annarev): validate SP_AllocatorStats.
|
||||
::stream_executor::AllocatorStats stats;
|
||||
stats.num_allocs = c_stats.num_allocs;
|
||||
stats.bytes_in_use = c_stats.bytes_in_use;
|
||||
|
@ -140,8 +140,9 @@ typedef enum SE_EventStatus {
|
||||
// https://cs.opensource.google/tensorflow/tensorflow/+/refs/tags/v2.3.0:tensorflow/stream_executor/device_memory.h;l=57
|
||||
typedef struct SP_DeviceMemoryBase {
|
||||
size_t struct_size;
|
||||
void* ext; // free-form data set by plugin
|
||||
void* ext; // Reserved for future use
|
||||
// Platform-dependent value representing allocated memory.
|
||||
// Note that the pointer does not have to be to the virtual address itself.
|
||||
void* opaque;
|
||||
uint64_t size; // Size in bytes of this allocation.
|
||||
uint64_t payload; // Value for plugin's use
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/stream_executor/executor_cache.h"
|
||||
#include "tensorflow/stream_executor/lib/status.h"
|
||||
#include "tensorflow/stream_executor/platform.h"
|
||||
@ -37,6 +38,10 @@ port::Status InitStreamExecutorPlugin(void* dso_handle);
|
||||
// testing).
|
||||
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn);
|
||||
|
||||
struct TFStatusDeleter {
|
||||
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
|
||||
};
|
||||
|
||||
// This file implements core stream executor base classes in terms of
|
||||
// the C API defined in stream_executor.h. A class "CSomething" represents a
|
||||
// "Something" that can be manipulated via calls in the C interface.
|
||||
@ -86,5 +91,111 @@ class CPlatform : public Platform {
|
||||
stream_executor::ExecutorCache executor_cache_;
|
||||
};
|
||||
|
||||
class CStream : public internal::StreamInterface {
|
||||
public:
|
||||
CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
stream_handle_(nullptr) {}
|
||||
~CStream() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
|
||||
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
|
||||
port::Status s = tensorflow::StatusFromTF_Status(c_status.get());
|
||||
return s;
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (stream_handle_ != nullptr) {
|
||||
stream_executor_->destroy_stream(device_, stream_handle_);
|
||||
stream_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Stream Handle() { return stream_handle_; }
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Stream stream_handle_;
|
||||
};
|
||||
|
||||
class CEvent : public internal::EventInterface {
|
||||
public:
|
||||
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
event_handle_(nullptr) {}
|
||||
~CEvent() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
|
||||
stream_executor_->create_event(device_, &event_handle_, c_status.get());
|
||||
return tensorflow::StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
port::Status Record(SP_Stream stream_handle) {
|
||||
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
|
||||
stream_executor_->record_event(device_, stream_handle, event_handle_,
|
||||
c_status.get());
|
||||
return tensorflow::StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (event_handle_ != nullptr) {
|
||||
stream_executor_->destroy_event(device_, event_handle_);
|
||||
event_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Event Handle() { return event_handle_; }
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Event event_handle_;
|
||||
};
|
||||
|
||||
class CTimer : public internal::TimerInterface {
|
||||
public:
|
||||
CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
|
||||
SP_TimerFns* timer_fns)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
timer_handle_(nullptr),
|
||||
timer_fns_(timer_fns) {}
|
||||
~CTimer() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
|
||||
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
|
||||
return tensorflow::StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (timer_handle_ != nullptr) {
|
||||
stream_executor_->destroy_timer(device_, timer_handle_);
|
||||
timer_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Timer Handle() { return timer_handle_; }
|
||||
|
||||
uint64 Microseconds() const override {
|
||||
return timer_fns_->nanoseconds(timer_handle_) / 1000;
|
||||
}
|
||||
|
||||
uint64 Nanoseconds() const override {
|
||||
return timer_fns_->nanoseconds(timer_handle_);
|
||||
}
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Timer timer_handle_;
|
||||
SP_TimerFns* timer_fns_;
|
||||
};
|
||||
|
||||
} // namespace stream_executor
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
|
||||
|
@ -24,7 +24,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
// Required for IS_MOBILE_PLATFORM definition
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
|
||||
#include "tensorflow/stream_executor/stream.h"
|
||||
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
|
||||
using tensorflow::errors::InvalidArgument;
|
||||
// This file forms the basis of a stable ABI for third-party kernel
|
||||
@ -185,6 +191,35 @@ void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
// This function is only for pluggable device.
|
||||
// It will return nullptr in all other cases.
|
||||
// This function is experimental and subject to change.
|
||||
SP_Stream TF_GetStream(TF_OpKernelContext* ctx, TF_Status* status) {
|
||||
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
|
||||
status->status = tensorflow::errors::Unimplemented(
|
||||
"Accessing device stream is not supported on mobile. File a bug at "
|
||||
"https://github.com/tensorflow/tensorflow/issues if this feature is "
|
||||
"important to you");
|
||||
return nullptr;
|
||||
#else
|
||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
|
||||
if (cc_ctx->op_device_context() == nullptr) { // CPU Device
|
||||
status->status = tensorflow::errors::FailedPrecondition(
|
||||
"Accessing device stream is not supported for a CPU device.");
|
||||
return nullptr;
|
||||
} else if (!cc_ctx->op_device_context()->IsPluggableDevice()) {
|
||||
status->status = tensorflow::errors::FailedPrecondition(
|
||||
"Accessing device stream is only supported for pluggable devices.");
|
||||
return nullptr;
|
||||
} else { // Is a PluggableDevice
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
auto c_stream = static_cast<stream_executor::CStream*>(
|
||||
cc_ctx->op_device_context()->stream()->implementation());
|
||||
return c_stream->Handle();
|
||||
}
|
||||
#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
|
||||
}
|
||||
|
||||
int TF_NumInputs(TF_OpKernelContext* ctx) {
|
||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
|
||||
return cc_ctx->num_inputs();
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
@ -65,6 +66,11 @@ typedef struct TF_KernelBuilder TF_KernelBuilder;
|
||||
typedef struct TF_OpKernelConstruction TF_OpKernelConstruction;
|
||||
typedef struct TF_OpKernelContext TF_OpKernelContext;
|
||||
|
||||
// TF_InitKernel to do op/kernel registration.
|
||||
// Plugin should implement TF_InitKernel to register kernels. This function
|
||||
// should register all kernels in a plugin.
|
||||
void TF_InitKernel();
|
||||
|
||||
// Allocates a new kernel builder and returns a pointer to it.
|
||||
//
|
||||
// If non-null, TensorFlow will call create_func when it needs to instantiate
|
||||
@ -128,6 +134,16 @@ TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
|
||||
// --------------------------------------------------------------------------
|
||||
// OpKernelContext routines
|
||||
|
||||
// TF_GetStream returns the SP_Stream available in ctx.
|
||||
// This function returns a stream only for devices registered using the
|
||||
// StreamExecutor C API
|
||||
// (tensorflow/c/experimental/stream_executor/stream_executor.h). It will return
|
||||
// nullptr and set error status in all other cases.
|
||||
// Experimental: this function doesn't have compatibility guarantees and subject
|
||||
// to change at any time.
|
||||
TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx,
|
||||
TF_Status* status);
|
||||
|
||||
// TF_NumInputs returns the number of inputs available in ctx.
|
||||
TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx);
|
||||
|
||||
|
@ -148,7 +148,7 @@ void RegisterBitcastOpKernel() {
|
||||
<< "Error while registering bitcast kernel";
|
||||
}
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
{
|
||||
auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_GPU,
|
||||
&BitcastOp_Create, &BitcastOp_Compute,
|
||||
|
@ -49,14 +49,12 @@ Graph* BM_ScalarSummaryOp(TensorShape shape, std::string tag, float value) {
|
||||
constexpr char longTagParam[] = "LONGTAG____________________________";
|
||||
constexpr float largeValueParam = 2352352.2623433;
|
||||
|
||||
#define BM_ScalarSummaryDev(device, dims, name, tag, value) \
|
||||
void BM_ScalarSummary##name##device(int iters) { \
|
||||
testing::StopTiming(); \
|
||||
TensorShape tensorshape(DIMARGS dims); \
|
||||
auto g = BM_ScalarSummaryOp(tensorshape, #tag, value); \
|
||||
testing::StartTiming(); \
|
||||
test::Benchmark("cpu", g).Run(iters); \
|
||||
} \
|
||||
#define BM_ScalarSummaryDev(device, dims, name, tag, value) \
|
||||
void BM_ScalarSummary##name##device(::testing::benchmark::State& state) { \
|
||||
TensorShape tensorshape(DIMARGS dims); \
|
||||
auto g = BM_ScalarSummaryOp(tensorshape, #tag, value); \
|
||||
test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state); \
|
||||
} \
|
||||
BENCHMARK(BM_ScalarSummary##name##device);
|
||||
|
||||
BM_ScalarSummaryDev(Cpu, (5, 10, 100), Base, Tag, 5.2);
|
||||
|
@ -44,7 +44,7 @@ class DummyDevice : public DeviceBase {
|
||||
}
|
||||
};
|
||||
|
||||
// Helper for comparing ouput and expected output
|
||||
// Helper for comparing output and expected output
|
||||
void ExpectSummaryMatches(const Summary& actual, const string& expected_str) {
|
||||
Summary expected;
|
||||
ASSERT_TRUE(protobuf::TextFormat::ParseFromString(expected_str, &expected));
|
||||
|
@ -685,7 +685,7 @@ class DeviceKernelOpTest : public OpsTestBase {
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
std::unique_ptr<Device> device(
|
||||
DeviceFactory::NewDevice(device_name_, {}, "/job:a/replica:0/task:0"));
|
||||
OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
|
||||
@ -694,7 +694,7 @@ class DeviceKernelOpTest : public OpsTestBase {
|
||||
TF_ASSERT_OK(InitOp());
|
||||
}
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
const char* device_name_ = tensorflow::DEVICE_GPU;
|
||||
#else
|
||||
const char* device_name_ = tensorflow::DEVICE_CPU;
|
||||
@ -711,6 +711,23 @@ template <typename T>
|
||||
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
|
||||
TF_OpKernelContext* ctx);
|
||||
|
||||
REGISTER_OP("StreamOp").Output("output1: float");
|
||||
|
||||
TEST_F(DeviceKernelOpTest, TestStream) {
|
||||
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
|
||||
TF_Status* s = TF_NewStatus();
|
||||
SP_Stream stream = TF_GetStream(ctx, s);
|
||||
// Stream is always null if device is not a pluggable device. More test
|
||||
// cases will be added when pluggable device mechanism is supported.
|
||||
EXPECT_EQ(stream, nullptr);
|
||||
EXPECT_NE(TF_OK, TF_GetCode(s));
|
||||
TF_DeleteStatus(s);
|
||||
};
|
||||
|
||||
SetupOp("StreamOp", "StreamOp", my_compute_func);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
}
|
||||
|
||||
REGISTER_OP("AllocateOutputOp1").Output("output1: float");
|
||||
|
||||
TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
|
||||
@ -801,7 +818,7 @@ TEST_F(DeviceKernelOpTest, TestAllocateTempSizeOne) {
|
||||
int64_t dim = 1;
|
||||
TF_AllocatorAttributes alloc_attrs;
|
||||
alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
alloc_attrs.on_host = 0;
|
||||
#else
|
||||
alloc_attrs.on_host = 1;
|
||||
@ -838,7 +855,7 @@ TEST_F(DeviceKernelOpTest, TestAllocateTempEmpty) {
|
||||
int64_t dim = 0;
|
||||
TF_AllocatorAttributes alloc_attrs;
|
||||
alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
alloc_attrs.on_host = 0;
|
||||
#else
|
||||
alloc_attrs.on_host = 1;
|
||||
@ -871,7 +888,7 @@ TEST_F(DeviceKernelOpTest, TestAllocateTempSize2x3) {
|
||||
int64_t dim[2] = {2, 3};
|
||||
TF_AllocatorAttributes alloc_attrs;
|
||||
alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
alloc_attrs.on_host = 0;
|
||||
#else
|
||||
alloc_attrs.on_host = 1;
|
||||
@ -979,7 +996,7 @@ template <typename T>
|
||||
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
|
||||
TF_OpKernelContext* ctx) {
|
||||
T* data = reinterpret_cast<T*>(TF_TensorData(tensor));
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
|
||||
cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, values,
|
||||
tensor_size_bytes);
|
||||
|
@ -76,7 +76,7 @@ class Tensor {
|
||||
// unknown rank.
|
||||
int dims() const;
|
||||
|
||||
// Returns the number of elements in in demension `d`.
|
||||
// Returns the number of elements in dimension `d`.
|
||||
// REQUIRES: `0 <= d < dims()`
|
||||
int64_t dim_size(int d) const;
|
||||
|
||||
@ -154,7 +154,7 @@ inline Tensor Tensor::FromBuffer(TF_DataType dtype,
|
||||
// 1. Only a function pointer is sent across the C API (&DeleterFunction)
|
||||
// 2. DeleterFunction is defined in the same build artifact that constructed
|
||||
// the std::function (so there isn't confusion about std::function ABI).
|
||||
// Note that 2. is satisifed by the fact that this is a header-only API, where
|
||||
// Note that 2. is satisfied by the fact that this is a header-only API, where
|
||||
// the function implementations are inline.
|
||||
|
||||
DeleterStruct* deleter_struct = new DeleterStruct{deleter};
|
||||
|
@ -67,7 +67,7 @@ bool IsZero(const Scope& scope, const Output& grad) {
|
||||
// mat: A 2-D tensor of dimension [D0, D1]
|
||||
//
|
||||
// Returns:
|
||||
// A tensor of dimension [D0, D1], the result fo vec * mat.
|
||||
// A tensor of dimension [D0, D1], the result for vec * mat.
|
||||
Output BroadcastMul(const Scope& scope, const Output& vec, const Output& mat) {
|
||||
auto reshaped = ExpandDims(scope, vec, -1);
|
||||
return Multiply(scope, reshaped, mat);
|
||||
|
@ -84,9 +84,6 @@ class SavedModelAPI {
|
||||
SignatureDefFunction* GetSignatureDefFunction(
|
||||
const std::string& function_path, Status* status);
|
||||
|
||||
// Lists all Conrete Functions available from the SavedModel.
|
||||
std::vector<ConcreteFunction*> ListFunctions();
|
||||
|
||||
// SavedModelAPI is movable, but not copyable.
|
||||
SavedModelAPI(SavedModelAPI&&) = default;
|
||||
SavedModelAPI& operator=(SavedModelAPI&&) = default;
|
||||
@ -151,11 +148,6 @@ inline SignatureDefFunction* SavedModelAPI::GetSignatureDefFunction(
|
||||
return SignatureDefFunction::wrap(function);
|
||||
}
|
||||
|
||||
inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
|
||||
ConcreteFunctionList list(TF_ListSavedModelFunctions(saved_model_.get()));
|
||||
return list.ToVector();
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
@ -404,10 +404,12 @@ Status RestoreSession(const RunOptions& run_options,
|
||||
const uint64 read_start_microseconds = Env::Default()->NowMicros();
|
||||
std::vector<AssetFileDef> asset_file_defs;
|
||||
TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(meta_graph, &asset_file_defs));
|
||||
TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir,
|
||||
meta_graph.saver_def().restore_op_name(),
|
||||
meta_graph.saver_def().filename_tensor_name(),
|
||||
asset_file_defs, session->get()));
|
||||
if (meta_graph.has_saver_def()) {
|
||||
TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir,
|
||||
meta_graph.saver_def().restore_op_name(),
|
||||
meta_graph.saver_def().filename_tensor_name(),
|
||||
asset_file_defs, session->get()));
|
||||
}
|
||||
// Record walltime spent in restoring graph from disk, but postpone metric
|
||||
// increments until graph init finishes.
|
||||
const uint64 restore_graph_walltime =
|
||||
|
@ -138,7 +138,7 @@ class FreezeTest : public ::testing::Test {
|
||||
}
|
||||
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
|
||||
// "c" isnt dependent on the variable, so nothing should be frozen.
|
||||
// "c" isn't dependent on the variable, so nothing should be frozen.
|
||||
TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
|
||||
graph_def, {"c:0"}, "assign", &saved_model_bundle));
|
||||
|
||||
@ -183,7 +183,7 @@ class FreezeTest : public ::testing::Test {
|
||||
}
|
||||
Output c = ops::Mul(scope.WithOpName("c"), a, read_var);
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
|
||||
// "c" isnt dependent on the variable, so nothing should be frozen.
|
||||
// "c" isn't dependent on the variable, so nothing should be frozen.
|
||||
TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
|
||||
graph_def, {"c:0"}, "assign", &saved_model_bundle));
|
||||
|
||||
@ -244,7 +244,7 @@ class FreezeTest : public ::testing::Test {
|
||||
|
||||
Output c = ops::Mul(scope.WithOpName("c"), a, read_var);
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
|
||||
// "c" isnt dependent on the variable, so nothing should be frozen.
|
||||
// "c" isn't dependent on the variable, so nothing should be frozen.
|
||||
TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
|
||||
graph_def, {"c:0"}, "assign", &saved_model_bundle));
|
||||
|
||||
|
@ -115,8 +115,8 @@ genrule(
|
||||
# have control of the full GPU.
|
||||
cmd = "CUDA_VISIBLE_DEVICES='' " +
|
||||
"$(location :make_test_graphs) --out_dir $(@D)",
|
||||
exec_tools = [":make_test_graphs"],
|
||||
tags = ["manual"],
|
||||
tools = [":make_test_graphs"],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
|
@ -127,7 +127,7 @@ def tf_library(
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --config=$(location " + config + ")" +
|
||||
" --dump_fetch_nodes > $@"),
|
||||
exec_tools = [tfcompile_tool],
|
||||
tools = [tfcompile_tool],
|
||||
# Run tfcompile on the build host, rather than forge, since it's
|
||||
# typically way faster on the local machine.
|
||||
local = 1,
|
||||
@ -162,7 +162,7 @@ def tf_library(
|
||||
"//tensorflow/python/tools:freeze_graph)" +
|
||||
freeze_args
|
||||
),
|
||||
exec_tools = ["//tensorflow/python/tools:freeze_graph"],
|
||||
tools = ["//tensorflow/python/tools:freeze_graph"],
|
||||
tags = tags,
|
||||
)
|
||||
tfcompile_graph = freeze_file
|
||||
@ -242,7 +242,7 @@ def tf_library(
|
||||
" --out_function_object=$(@D)/" + function_object_file +
|
||||
" " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag
|
||||
),
|
||||
exec_tools = [tfcompile_tool],
|
||||
tools = [tfcompile_tool],
|
||||
visibility = visibility,
|
||||
testonly = testonly,
|
||||
# Run tfcompile on the build host since it's typically faster on the
|
||||
@ -281,7 +281,7 @@ def tf_library(
|
||||
" --out_session_module=$(@D)/" + session_module_pb +
|
||||
" " + flags
|
||||
),
|
||||
exec_tools = [tfcompile_tool],
|
||||
tools = [tfcompile_tool],
|
||||
visibility = visibility,
|
||||
testonly = testonly,
|
||||
local = 1,
|
||||
|
@ -7,6 +7,9 @@ load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_
|
||||
load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts")
|
||||
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "filegroup")
|
||||
|
||||
@ -283,6 +286,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:parse_flags_from_env",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
@ -291,7 +295,7 @@ cc_library(
|
||||
# Header-only version of "flags" library, for linking from the shared object
|
||||
# without ODR violations.
|
||||
cc_library(
|
||||
name = "flags_headers_only",
|
||||
name = "flags_headers",
|
||||
hdrs = ["flags.h"],
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
@ -302,6 +306,11 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_header_only_library(
|
||||
name = "flags_headers_only",
|
||||
deps = [":flags_headers"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "common",
|
||||
srcs = [
|
||||
@ -361,6 +370,7 @@ cc_library(
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"//tensorflow/compiler/mlir:mlir_bridge_rollout_policy",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_context",
|
||||
@ -447,8 +457,8 @@ cc_library(
|
||||
# Header-only version of "flags" library, for linking from the shared object
|
||||
# without ODR violations.
|
||||
cc_library(
|
||||
name = "get_compiler_ir_hdrs_only",
|
||||
hdrs = ["get_compiler_ir.h"],
|
||||
name = "get_compiler_ir_hdrs",
|
||||
textual_hdrs = ["get_compiler_ir.h"],
|
||||
visibility = [
|
||||
":internal",
|
||||
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
|
||||
@ -463,6 +473,23 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_header_only_library(
|
||||
name = "get_compiler_ir_hdrs_only",
|
||||
deps = [":get_compiler_ir_hdrs"],
|
||||
)
|
||||
|
||||
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
|
||||
cc_header_only_library(
|
||||
name = "xla_jit_headers_lib",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":xla_cpu_device",
|
||||
":xla_cpu_jit",
|
||||
":xla_gpu_device",
|
||||
":xla_gpu_jit",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_kernel_creator",
|
||||
srcs = [
|
||||
@ -481,6 +508,7 @@ cc_library(
|
||||
":flags",
|
||||
":jit_compilation_passes",
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration",
|
||||
"//tensorflow/compiler/tf2xla:mlir_bridge_pass",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
@ -520,8 +548,8 @@ cc_library(
|
||||
hdrs = ["resource_operation_safety_analysis.h"],
|
||||
deps = [
|
||||
":xla_cluster_util",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/tf2xla:resource_operation_table",
|
||||
"//tensorflow/compiler/xla/service/graphcycles",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
@ -692,7 +720,6 @@ cc_library(
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/cc:scope_internal",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/jit/ops:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla:resource_operation_table",
|
||||
"//tensorflow/compiler/tf2xla:side_effect_util",
|
||||
@ -705,6 +732,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:union_find",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service/graphcycles",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
@ -732,9 +760,9 @@ cc_library(
|
||||
deps = [
|
||||
":flags",
|
||||
":xla_activity_proto_cc",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/service/graphcycles",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
@ -842,9 +870,12 @@ tf_cc_test(
|
||||
"partially_decluster_pass_test.cc",
|
||||
"rearrange_function_argument_pass_test.cc",
|
||||
],
|
||||
# TODO(b/141643254) Re-enable msan after fixing use-of-uninitialized-value
|
||||
# error.
|
||||
tags = ["nomsan"] + tf_cuda_tests_tags(),
|
||||
tags = [
|
||||
# TODO(b/141643254) Re-enable msan after fixing
|
||||
# use-of-uninitialized-value error.
|
||||
"nomsan",
|
||||
"no_cuda_asan", # TODO(b/171317460): re-enable.
|
||||
] + tf_cuda_tests_tags(),
|
||||
deps = [
|
||||
":common",
|
||||
":compilability_check_util",
|
||||
@ -965,13 +996,13 @@ cc_library(
|
||||
":xla_activity_listener",
|
||||
":xla_activity_proto_cc",
|
||||
":xla_cluster_util",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/tf2xla:resource_operation_table",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:union_find",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service/graphcycles",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
@ -1075,15 +1106,3 @@ cc_library(
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
|
||||
cc_header_only_library(
|
||||
name = "xla_jit_headers_lib",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":xla_cpu_device",
|
||||
":xla_cpu_jit",
|
||||
":xla_gpu_device",
|
||||
":xla_gpu_jit",
|
||||
],
|
||||
)
|
||||
|
@ -34,7 +34,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/device_util.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity.pb.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
@ -42,6 +41,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/union_find.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
@ -24,11 +24,11 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/device_util.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/union_find.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
@ -27,11 +27,11 @@ limitations under the License.
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
|
||||
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
|
@ -167,8 +167,16 @@ void AllocateAndParseFlags() {
|
||||
jitter_flags = new IntroduceFloatingPointJitterPassFlags;
|
||||
jitter_flags->jitter_amount = 1e-5;
|
||||
|
||||
mlir_flags = new MlirCommonFlags;
|
||||
mlir_flags->tf_mlir_enable_mlir_bridge = false;
|
||||
// The `enable_mlir_bridge` flag allows the user to explicitly request that
|
||||
// their program is (or isn't) compiled using the MLIR-based TF-to-XLA bridge.
|
||||
//
|
||||
// The `enable_mlir_bridge_is_explicit` variable tracks whether or not the
|
||||
// user has made an explicit request. That is, if this variable is set to
|
||||
// true, the program honors the user's request as per `enable_mlir_bridge`; if
|
||||
// it's set to false, the default behavior is used (which may run either
|
||||
// bridge, on a per-graph basis).
|
||||
bool enable_mlir_bridge = false;
|
||||
bool enable_mlir_bridge_is_explicit = false;
|
||||
|
||||
auto setter_for_jitter_tensor_names = [](string sequence) {
|
||||
jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
|
||||
@ -184,11 +192,11 @@ void AllocateAndParseFlags() {
|
||||
"XLA clusters."),
|
||||
Flag("tf_xla_check_cluster_input_numerics",
|
||||
&build_ops_flags->tf_xla_check_cluster_input_numerics,
|
||||
"If true then insert CheckNumerics nodes to to check all cluster "
|
||||
"If true then insert CheckNumerics nodes to check all cluster "
|
||||
"inputs."),
|
||||
Flag("tf_xla_check_cluster_output_numerics",
|
||||
&build_ops_flags->tf_xla_check_cluster_output_numerics,
|
||||
"If true then insert CheckNumerics nodes to to check all cluster "
|
||||
"If true then insert CheckNumerics nodes to check all cluster "
|
||||
"outputs."),
|
||||
Flag("tf_xla_disable_constant_folding",
|
||||
&build_ops_flags->tf_xla_disable_constant_folding,
|
||||
@ -217,12 +225,24 @@ void AllocateAndParseFlags() {
|
||||
"The amount of jitter to introduce. This amount is added to each "
|
||||
"element in the tensors named in `tensor_names."),
|
||||
|
||||
Flag("tf_mlir_enable_mlir_bridge",
|
||||
&mlir_flags->tf_mlir_enable_mlir_bridge,
|
||||
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});
|
||||
Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge,
|
||||
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.",
|
||||
&enable_mlir_bridge_is_explicit)});
|
||||
|
||||
AppendMarkForCompilationPassFlagsInternal(flag_list);
|
||||
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
|
||||
|
||||
mlir_flags = new MlirCommonFlags;
|
||||
if (!enable_mlir_bridge_is_explicit) {
|
||||
mlir_flags->tf_mlir_enable_mlir_bridge =
|
||||
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
|
||||
} else if (enable_mlir_bridge) {
|
||||
mlir_flags->tf_mlir_enable_mlir_bridge =
|
||||
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
|
||||
} else {
|
||||
mlir_flags->tf_mlir_enable_mlir_bridge =
|
||||
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -38,7 +39,7 @@ struct XlaAutoJitFlag {
|
||||
int32 optimization_level_general;
|
||||
};
|
||||
|
||||
// Sets the xla_auto_jit_flag based on the given flag sting. Supported syntax
|
||||
// Sets the xla_auto_jit_flag based on the given flag string. Supported syntax
|
||||
// is:
|
||||
// <number>: sets general and single_gpu setting to the provided number.
|
||||
// single-gpu(<number>): sets the single_gpu setting to the provided number.
|
||||
@ -135,7 +136,7 @@ struct IntroduceFloatingPointJitterPassFlags {
|
||||
|
||||
// Flags for common MLIR configurations.
|
||||
struct MlirCommonFlags {
|
||||
bool tf_mlir_enable_mlir_bridge;
|
||||
ConfigProto::Experimental::MlirBridgeRollout tf_mlir_enable_mlir_bridge;
|
||||
};
|
||||
|
||||
// Return a pointer to the DumpGraphFlags struct;
|
||||
|
@ -115,7 +115,7 @@ xla::StatusOr<std::string> GetCompilerIr(
|
||||
|
||||
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
|
||||
XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constant_arg_indices, inputs, variable_infos);
|
||||
constant_arg_indices, inputs, variable_infos, dev);
|
||||
TF_RETURN_IF_ERROR(args.status());
|
||||
|
||||
switch (stage) {
|
||||
|
@ -206,8 +206,9 @@ static Status CompileToLocalExecutable(
|
||||
may_alias_resource_update;
|
||||
|
||||
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
|
||||
XlaComputationLaunchContext::BuildXlaCompilerArguments(constants, inputs,
|
||||
variable_infos);
|
||||
XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constants, inputs, variable_infos,
|
||||
static_cast<Device*>(ctx->device()));
|
||||
TF_RETURN_IF_ERROR(args.status());
|
||||
return cache->Compile(options, function, *args, compile_options,
|
||||
lazy ? XlaCompilationCache::CompileMode::kLazy
|
||||
@ -246,8 +247,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
|
||||
VLOG(1) << "Executing XLA Computation...";
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
se::DeviceMemoryAllocator* allocator = GetAllocator(
|
||||
&tf_allocator_adapter, ctx->device(),
|
||||
@ -274,18 +273,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
run_options.set_allocator(allocator);
|
||||
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
|
||||
run_options.set_rng_seed(GetXLARandomSeed());
|
||||
xla::ThenExecuteFunction then_execute;
|
||||
if (ctx->op_device_context()) {
|
||||
then_execute = [&](se::Stream* stream, std::function<void()> fn) {
|
||||
Status status = ctx->op_device_context()->ThenExecute(
|
||||
down_cast<Device*>(ctx->device()), stream, std::move(fn));
|
||||
if (!status.ok()) {
|
||||
// This should never happen.
|
||||
LOG(ERROR) << "ThenExecute failed " << status;
|
||||
}
|
||||
};
|
||||
run_options.set_then_execute_function(&then_execute);
|
||||
}
|
||||
Env* env = Env::Default();
|
||||
auto start_time = env->NowMicros();
|
||||
|
||||
@ -522,18 +509,6 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
|
||||
run_options.set_allocator(allocator);
|
||||
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
|
||||
run_options.set_rng_seed(GetXLARandomSeed());
|
||||
xla::ThenExecuteFunction then_execute;
|
||||
if (ctx->op_device_context()) {
|
||||
then_execute = [&](se::Stream* stream, std::function<void()> fn) {
|
||||
Status status = ctx->op_device_context()->ThenExecute(
|
||||
down_cast<Device*>(ctx->device()), stream, std::move(fn));
|
||||
if (!status.ok()) {
|
||||
// This should never happen.
|
||||
LOG(ERROR) << "ThenExecute failed " << status;
|
||||
}
|
||||
};
|
||||
run_options.set_then_execute_function(&then_execute);
|
||||
}
|
||||
Env* env = Env::Default();
|
||||
auto start_time = env->NowMicros();
|
||||
|
||||
|
@ -30,12 +30,12 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/device_util.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/union_find.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -1801,11 +1801,11 @@ absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable() {
|
||||
"Range", "Rank", "Reshape", "Shape", "ShapeN", "Size", "Squeeze",
|
||||
"Transpose", "ZerosLike", "OnesLike", "BiasAdd" /*PW + Broadcast*/,
|
||||
"BroadcastArgs", "BroadcastGradientArgs", "OneHot", "Concat", "ConcatV2",
|
||||
"ConcatOffset", "Const", "MirrorPad", "Pack", "Pad", "PadV2", "Reverse",
|
||||
"ReverseV2", "ReverseSequence", "Slice", "Split", "SplitV",
|
||||
"StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign",
|
||||
"Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex",
|
||||
"TensorStridedSliceUpdate",
|
||||
"ConcatOffset", "Const", "MirrorPad", "MirrorPadGrad", "Pack", "Pad",
|
||||
"PadV2", "Reverse", "ReverseV2", "ReverseSequence", "Slice", "Split",
|
||||
"SplitV", "StridedSlice", "StridedSliceGrad",
|
||||
"ResourceStridedSliceAssign", "Tile", "Transpose", "InvertPermutation",
|
||||
"Unpack", "DeviceIndex", "TensorStridedSliceUpdate",
|
||||
}}};
|
||||
// clang-format on
|
||||
return result;
|
||||
@ -1990,6 +1990,8 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
||||
"StatelessCase",
|
||||
"StatelessIf",
|
||||
"StatelessMultinomial",
|
||||
"StatelessRandomGetAlg",
|
||||
"StatelessRandomGetKeyCounter",
|
||||
"StatelessRandomGetKeyCounterAlg",
|
||||
"StatelessRandomNormal",
|
||||
"StatelessRandomNormalV2",
|
||||
@ -2040,6 +2042,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
||||
"UnsortedSegmentSum",
|
||||
"VarIsInitializedOp",
|
||||
"VariableShape",
|
||||
"Where",
|
||||
"While",
|
||||
"XlaBroadcastHelper",
|
||||
"XlaConv",
|
||||
@ -2061,6 +2064,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
||||
"XlaSelfAdjointEig",
|
||||
"XlaSend",
|
||||
"XlaSetBound",
|
||||
"XlaSetDynamicDimensionSize",
|
||||
"XlaSharding",
|
||||
"XlaSort",
|
||||
"XlaSpmdFullToShardShape",
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_
|
||||
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
|
||||
|
@ -21,8 +21,8 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user