Merge branch 'master' into interface_16x8
This commit is contained in:
commit
706dc11f1d
18
.bazelrc
18
.bazelrc
@ -143,6 +143,11 @@ build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
|
||||
build:mkl --define=build_with_mkl_dnn_v1_only=true
|
||||
build:mkl -c opt
|
||||
|
||||
# config to build OneDNN backend with a user specified threadpool.
|
||||
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
|
||||
build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
|
||||
build:mkl_threadpool --define=build_with_mkldnn_threadpool=true
|
||||
build:mkl_threadpool -c opt
|
||||
# This config refers to building with CUDA available. It does not necessarily
|
||||
# mean that we build CUDA op kernels.
|
||||
build:using_cuda --define=using_cuda=true
|
||||
@ -163,6 +168,8 @@ build:cuda_clang --action_env TF_CUDA_CLANG=1
|
||||
build:dbg --config=opt -c dbg
|
||||
# for now, disable arm_neon. see: https://github.com/tensorflow/tensorflow/issues/33360
|
||||
build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
|
||||
# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
|
||||
build:dbg --copt -DDEBUG_BUILD
|
||||
|
||||
build:tensorrt --action_env TF_NEED_TENSORRT=1
|
||||
|
||||
@ -233,10 +240,15 @@ build:c++17 --cxxopt=-std=c++1z
|
||||
build:c++17 --cxxopt=-stdlib=libc++
|
||||
build:c++1z --config=c++17
|
||||
|
||||
# Enable using platform specific build settings
|
||||
# Enable using platform specific build settings, except when cross-compiling for
|
||||
# mobile platforms.
|
||||
build --enable_platform_specific_config
|
||||
build:android --noenable_platform_specific_config
|
||||
build:ios --noenable_platform_specific_config
|
||||
|
||||
# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
|
||||
build:android --copt=-w
|
||||
build:ios --copt=-w
|
||||
build:linux --copt=-w
|
||||
build:macos --copt=-w
|
||||
build:windows --copt=/w
|
||||
@ -256,6 +268,10 @@ build:macos --define=INCLUDEDIR=$(PREFIX)/include
|
||||
# TF_SYSTEM_LIBS do not work on windows.
|
||||
|
||||
# By default, build TF in C++ 14 mode.
|
||||
build:android --cxxopt=-std=c++14
|
||||
build:android --host_cxxopt=-std=c++14
|
||||
build:ios --cxxopt=-std=c++14
|
||||
build:ios --host_cxxopt=-std=c++14
|
||||
build:linux --cxxopt=-std=c++14
|
||||
build:linux --host_cxxopt=-std=c++14
|
||||
build:macos --cxxopt=-std=c++14
|
||||
|
87
.github/bot_config.yml
vendored
Normal file
87
.github/bot_config.yml
vendored
Normal file
@ -0,0 +1,87 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
#
|
||||
# THIS IS A GENERATED DOCKERFILE.
|
||||
#
|
||||
# This file was assembled from multiple pieces, whose use is documented
|
||||
# throughout. Please refer to the TensorFlow dockerfiles documentation
|
||||
# for more information.
|
||||
|
||||
# A list of assignees
|
||||
assignees:
|
||||
- amahendrakar
|
||||
- ravikyram
|
||||
- Saduf2019
|
||||
# A list of assignees for compiler folder
|
||||
compiler_assignees:
|
||||
- joker-eph
|
||||
# Cuda Comment
|
||||
cuda_comment: >
|
||||
From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries:
|
||||
* For TF-GPU - See point 1
|
||||
* For TF-CPU - See point 2
|
||||
-----------------------------------------------------------------------------------------------
|
||||
|
||||
**1. Installing **TensorFlow-GPU** (TF) prebuilt binaries**
|
||||
|
||||
|
||||
Make sure you are using compatible TF and CUDA versions.
|
||||
Please refer following TF version and CUDA version compatibility table.
|
||||
|
||||
| TF | CUDA |
|
||||
|
||||
| :-------------: | :-------------: |
|
||||
|
||||
| 2.1.0 - 2.2.0 | 10.1 |
|
||||
|
||||
| 1.13.1 - 2.0 | 10.0 |
|
||||
|
||||
| 1.5.0 - 1.12.0 | 9.0 |
|
||||
|
||||
* If you have above configuration and using _**Windows**_ platform -
|
||||
* Try adding the CUDA, CUPTI, and cuDNN installation directories to the %PATH% environment variable.
|
||||
* Refer [windows setup guide](https://www.tensorflow.org/install/gpu#windows_setup).
|
||||
* If you have above configuration and using _**Ubuntu/Linux**_ platform -
|
||||
* Try adding the CUDA, CUPTI, and cuDNN installation directories to the $LD_LIBRARY_PATH environment variable.
|
||||
* Refer [linux setup guide](https://www.tensorflow.org/install/gpu#linux_setup).
|
||||
* If error still persists then, apparently your CPU model does not support AVX instruction sets.
|
||||
* Refer [hardware requirements](https://www.tensorflow.org/install/pip#hardware-requirements).
|
||||
|
||||
-----------------------------------------------------------------------------------------------
|
||||
|
||||
**2. Installing **TensorFlow** (TF) CPU prebuilt binaries**
|
||||
|
||||
|
||||
*TensorFlow release binaries version 1.6 and higher are prebuilt with AVX instruction sets.*
|
||||
|
||||
|
||||
Therefore on any CPU that does not have these instruction sets, either CPU or GPU version of TF will fail to load.
|
||||
|
||||
Apparently, your CPU model does not support AVX instruction sets. You can still use TensorFlow with the alternatives given below:
|
||||
|
||||
* Try Google Colab to use TensorFlow.
|
||||
* The easiest way to use TF will be to switch to [google colab](https://colab.sandbox.google.com/notebooks/welcome.ipynb#recent=true). You get pre-installed latest stable TF version. Also you can use ```pip install``` to install any other preferred TF version.
|
||||
* It has an added advantage since you can you easily switch to different hardware accelerators (cpu, gpu, tpu) as per the task.
|
||||
* All you need is a good internet connection and you are all set.
|
||||
* Try to build TF from sources by changing CPU optimization flags.
|
||||
|
||||
*Please let us know if this helps.*
|
||||
|
||||
windows_comment: >
|
||||
From the stack trace it looks like you are hitting windows path length limit.
|
||||
* Try to disable path length limit on Windows 10.
|
||||
* Refer [disable path length limit instructions guide.](https://mspoweruser.com/ntfs-260-character-windows-10/)
|
||||
|
||||
Please let us know if this helps.
|
23
README.md
23
README.md
@ -103,17 +103,17 @@ open-source software development:
|
||||
|
||||
### Official Builds
|
||||
|
||||
Build Type | Status | Artifacts
|
||||
------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
|
||||
**Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Linux XLA** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
|
||||
**macOS** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Android** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
|
||||
**Raspberry Pi 0 and 1** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.html) [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv6l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
|
||||
**Raspberry Pi 2 and 3** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.html) [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv7l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
|
||||
Build Type | Status | Artifacts
|
||||
------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
|
||||
**Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Linux XLA** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
|
||||
**macOS** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Android** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
|
||||
**Raspberry Pi 0 and 1** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
|
||||
**Raspberry Pi 2 and 3** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
|
||||
|
||||
### Community Supported Builds
|
||||
|
||||
@ -142,6 +142,7 @@ Build Type | Status
|
||||
* [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2)
|
||||
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
|
||||
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
|
||||
* [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp)
|
||||
* [TensorFlow Blog](https://blog.tensorflow.org)
|
||||
* [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
|
||||
* [TensorFlow Twitter](https://twitter.com/tensorflow)
|
||||
|
182
RELEASE.md
182
RELEASE.md
@ -1,3 +1,185 @@
|
||||
# Release 2.3.0
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
* `tf.image.extract_glimpse` has been updated to correctly process the case
|
||||
where `centered=False` and `normalized=False`. This is a breaking change as
|
||||
the output is different from (incorrect) previous versions. Note this
|
||||
breaking change only impacts `tf.image.extract_glimpse` and
|
||||
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
|
||||
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
|
||||
exsiting C++ kernel `ExtractGlimpse` does not change as well, so saved
|
||||
models will not be impacted.
|
||||
|
||||
# Release 2.1.1
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
|
||||
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
|
||||
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
|
||||
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
|
||||
* Fixes a versioning bug which causes Keras layers from TF 1.x to be used instead of those from TF 2.x
|
||||
|
||||
# Release 2.0.2
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
|
||||
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
|
||||
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
|
||||
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
|
||||
|
||||
# Release 1.15.3
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
|
||||
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
|
||||
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
|
||||
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
|
||||
|
||||
# Release 2.2.0
|
||||
|
||||
TensorFlow 2.2 discontinues support for Python 2, [previously announced](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ) as following [Python 2's EOL on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update).
|
||||
|
||||
Coinciding with this change, new releases of [TensorFlow's Docker images](https://hub.docker.com/r/tensorflow/tensorflow/) provide Python 3 exclusively. Because all images now use Python 3, Docker tags containing `-py3` will no longer be provided and existing `-py3` tags like `latest-py3` will not be updated.
|
||||
|
||||
## Major Features and Improvements
|
||||
|
||||
* Replaced the scalar type for string tensors from `std::string` to `tensorflow::tstring` which is now ABI stable.
|
||||
* A new Profiler for TF 2 for CPU/GPU/TPU. It offers both device and host performance analysis, including input pipeline and TF Ops. Optimization advisory is provided whenever possible. Please see [this tutorial](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras) and [guide](https://www.tensorflow.org/guide/profiler) for usage guidelines.
|
||||
* Export C++ functions to Python using `pybind11` as opposed to `SWIG` as a part of our [deprecation of swig efforts](https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md).
|
||||
* `tf.distribute`:
|
||||
* Support added for global sync `BatchNormalization` by using the newly added `tf.keras.layers.experimental.SyncBatchNormalization` layer. This layer will sync `BatchNormalization` statistics every step across all replicas taking part in sync training.
|
||||
* Performance improvements for GPU multi-worker distributed training using `tf.distribute.experimental.MultiWorkerMirroredStrategy`
|
||||
* Update NVIDIA `NCCL` to `2.5.7-1` for better performance and performance tuning. Please see [nccl developer guide](https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/env.html) for more information on this.
|
||||
* Support gradient `allreduce` in `float16`. See this [example](https://github.com/tensorflow/models/blob/master/official/staging/training/grad_utils.py) usage.
|
||||
* Experimental support of [all reduce gradient packing](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/CollectiveHints) to allow overlapping gradient aggregation with backward path computation.
|
||||
* Deprecated `experimental_run_v2` method for distribution strategies and renamed the method `run` as it is no longer experimental.
|
||||
* Add CompositeTensor support for DistributedIterators. This should help prevent unnecessary function retracing and memory leaks.
|
||||
* `tf.keras`:
|
||||
* `Model.fit` major improvements:
|
||||
* You can now use custom training logic with `Model.fit` by overriding `Model.train_step`.
|
||||
* Easily write state-of-the-art training loops without worrying about all of the features `Model.fit` handles for you (distribution strategies, callbacks, data formats, looping logic, etc)
|
||||
* See the default [`Model.train_step`](https://github.com/tensorflow/tensorflow/blob/1381fc8e15e22402417b98e3881dfd409998daea/tensorflow/python/keras/engine/training.py#L540) for an example of what this function should look like. Same applies for validation and inference via `Model.test_step` and `Model.predict_step`.
|
||||
* SavedModel uses its own `Model._saved_model_inputs_spec` attr now instead of
|
||||
relying on `Model.inputs` and `Model.input_names`, which are no longer set for subclass Models.
|
||||
This attr is set in eager, `tf.function`, and graph modes. This gets rid of the need for users to
|
||||
manually call `Model._set_inputs` when using Custom Training Loops(CTLs).
|
||||
* Dynamic shapes are supported for generators by calling the Model on the first batch we "peek" from the generator.
|
||||
This used to happen implicitly in `Model._standardize_user_data`. Long-term, a solution where the
|
||||
`DataAdapter` doesn't need to call the Model is probably preferable.
|
||||
* The SavedModel format now supports all Keras built-in layers (including metrics, preprocessing layers, and stateful RNN layers)
|
||||
* Update Keras batch normalization layer to use the running mean and average computation in the `fused_batch_norm`. You should see significant performance improvements when using `fused_batch_norm` in Eager mode.
|
||||
|
||||
* `tf.lite`:
|
||||
* Enable TFLite experimental new converter by default.
|
||||
* XLA
|
||||
* XLA now builds and works on windows. All prebuilt packages come with XLA available.
|
||||
* XLA can be [enabled for a `tf.function`](https://www.tensorflow.org/xla#explicit_compilation_with_tffunction
|
||||
) with “compile or throw exception” semantics on CPU and GPU.
|
||||
|
||||
## Breaking Changes
|
||||
* `tf.keras`:
|
||||
* In `tf.keras.applications` the name of the "top" layer has been standardized to "predictions". This is only a problem if your code relies on the exact name of the layer.
|
||||
* Huber loss function has been updated to be consistent with other Keras losses. It now computes mean over the last axis of per-sample losses before applying the reduction function.
|
||||
* AutoGraph no longer converts functions passed to `tf.py_function`, `tf.py_func` and `tf.numpy_function`.
|
||||
* Deprecating `XLA_CPU` and `XLA_GPU` devices with this release.
|
||||
* Increasing the minimum bazel version to build TF to 2.0.0 to use Bazel's `cc_experimental_shared_library`.
|
||||
* Keras compile/fit behavior for functional and subclassed models have been unified. Model properties such as `metrics`, `metrics_names` will now be available only after **training/evaluating the model on actual data** for functional models. `metrics` will **now include** model `loss` and output losses.`loss_functions` property has been removed from the model. This was an undocumented property that was accidentally public and has now been removed.
|
||||
|
||||
## Known Caveats
|
||||
* The current TensorFlow release now **requires** [gast](https://pypi.org/project/gast/) version 0.3.3.
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* `tf.data`:
|
||||
* Removed `autotune_algorithm` from experimental optimization options.
|
||||
* TF Core:
|
||||
* `tf.constant` always creates CPU tensors irrespective of the current device context.
|
||||
* Eager `TensorHandles` maintain a list of mirrors for any copies to local or remote devices. This avoids any redundant copies due to op execution.
|
||||
* For `tf.Tensor` & `tf.Variable`, `.experimental_ref()` is no longer experimental and is available as simply `.ref()`.
|
||||
* `pfor/vectorized_map`: Added support for vectorizing 56 more ops. Vectorizing `tf.cond` is also supported now.
|
||||
* Set as much partial shape as we can infer statically within the gradient impl of the gather op.
|
||||
* Gradient of `tf.while_loop` emits `StatelessWhile` op if `cond` and body functions are stateless. This allows multiple gradients while ops to run in parallel under distribution strategy.
|
||||
* Speed up `GradientTape` in eager mode by auto-generating list of op inputs/outputs which are unused and hence not cached for gradient functions.
|
||||
* Support `back_prop=False` in `while_v2` but mark it as deprecated.
|
||||
* Improve error message when attempting to use `None` in data-dependent control flow.
|
||||
* Add `RaggedTensor.numpy()`.
|
||||
* Update `RaggedTensor.__getitem__` to preserve uniform dimensions & allow indexing into uniform dimensions.
|
||||
* Update `tf.expand_dims` to always insert the new dimension as a non-ragged dimension.
|
||||
* Update `tf.embedding_lookup` to use `partition_strategy` and `max_norm` when `ids` is ragged.
|
||||
* Allow `batch_dims==rank(indices)` in `tf.gather`.
|
||||
* Add support for bfloat16 in `tf.print`.
|
||||
* `tf.distribute`:
|
||||
* Support `embedding_column` with variable-length input features for `MultiWorkerMirroredStrategy`.
|
||||
* `tf.keras`:
|
||||
* Added `experimental_aggregate_gradients` argument to `tf.keras.optimizer.Optimizer.apply_gradients`. This allows custom gradient aggregation and processing aggregated gradients in custom training loop.
|
||||
* Allow `pathlib.Path` paths for loading models via Keras API.
|
||||
* `tf.function`/AutoGraph:
|
||||
* AutoGraph is now available in `ReplicaContext.merge_call`, `Strategy.extended.update` and `Strategy.extended.update_non_slot`.
|
||||
* Experimental support for shape invariants has been enabled in `tf.function`. See the API docs for `tf.autograph.experimental.set_loop_options` for additonal info.
|
||||
* AutoGraph error messages now exclude frames corresponding to APIs internal to AutoGraph.
|
||||
* Improve shape inference for `tf.function` input arguments to unlock more Grappler optimizations in TensorFlow 2.x.
|
||||
* Improve automatic control dependency management of resources by allowing resource reads to occur in parallel and synchronizing only on writes.
|
||||
* Fix execution order of multiple stateful calls to `experimental_run_v2` in `tf.function`.
|
||||
* You can now iterate over `RaggedTensors` using a for loop inside `tf.function`.
|
||||
* `tf.lite`:
|
||||
* Migrated the `tf.lite` C inference API out of experimental into lite/c.
|
||||
* Add an option to disallow `NNAPI` CPU / partial acceleration on Android 10
|
||||
* TFLite Android AARs now include the C headers and APIs are required to use TFLite from native code.
|
||||
* Refactors the delegate and delegate kernel sources to allow usage in the linter.
|
||||
* Limit delegated ops to actually supported ones if a device name is specified or `NNAPI` CPU Fallback is disabled.
|
||||
* TFLite now supports `tf.math.reciprocal1` op by lowering to `tf.div op`.
|
||||
* TFLite's unpack op now supports boolean tensor inputs.
|
||||
* Microcontroller and embedded code moved from experimental to main TensorFlow Lite folder
|
||||
* Check for large TFLite tensors.
|
||||
* Fix GPU delegate crash with C++17.
|
||||
* Add 5D support to TFLite `strided_slice`.
|
||||
* Fix error in delegation of `DEPTH_TO_SPACE` to `NNAPI` causing op not to be accelerated.
|
||||
* Fix segmentation fault when running a model with LSTM nodes using `NNAPI` Delegate
|
||||
* Fix `NNAPI` delegate failure when an operand for Maximum/Minimum operation is a scalar.
|
||||
* Fix `NNAPI` delegate failure when Axis input for reduce operation is a scalar.
|
||||
* Expose option to limit the number of partitions that will be delegated to `NNAPI`.
|
||||
* If a target accelerator is specified, use its feature level to determine operations to delegate instead of SDK version.
|
||||
* `tf.random`:
|
||||
* Various random number generation improvements:
|
||||
* Add a fast path for default `random_uniform`
|
||||
* `random_seed` documentation improvement.
|
||||
* `RandomBinomial` broadcasts and appends the sample shape to the left rather than the right.
|
||||
* Added `tf.random.stateless_binomial`, `tf.random.stateless_gamma`, `tf.random.stateless_poisson`
|
||||
* `tf.random.stateless_uniform` now supports unbounded sampling of `int` types.
|
||||
* Math and Linear Algebra:
|
||||
* Add `tf.linalg.LinearOperatorTridiag`.
|
||||
* Add `LinearOperatorBlockLowerTriangular`
|
||||
* Add broadcasting support to tf.linalg.triangular_solve[#26204](https://github.com/tensorflow/tensorflow/issues/26204), tf.math.invert_permutation.
|
||||
* Add `tf.math.sobol_sample` op.
|
||||
* Add `tf.math.xlog1py`.
|
||||
* Add `tf.math.special.{dawsn,expi,fresnel_cos,fresnel_sin,spence}`.
|
||||
* Add a Modified Discrete Cosine Transform (MDCT) and its inverse to `tf.signal`.
|
||||
* TPU Enhancements:
|
||||
* Refactor `TpuClusterResolver` to move shared logic to a separate pip package.
|
||||
* Support configuring TPU software version from cloud tpu client.
|
||||
* Allowed TPU embedding weight decay factor to be multiplied by learning rate.
|
||||
* XLA Support:
|
||||
* Add standalone XLA AOT runtime target + relevant .cc sources to pip package.
|
||||
* Add check for memory alignment to MemoryAllocation::MemoryAllocation() on 32-bit ARM. This ensures a deterministic early exit instead of a hard to debug bus error later.
|
||||
* `saved_model_cli aot_compile_cpu` allows you to compile saved models to XLA header+object files and include them in your C++ programs.
|
||||
* Enable `Igamma`, `Igammac` for XLA.
|
||||
* Deterministic Op Functionality:
|
||||
* XLA reduction emitter is deterministic when the environment variable `TF_DETERMINISTIC_OPS` is set to "true" or "1". This extends deterministic `tf.nn.bias_add` back-prop functionality (and therefore also deterministic back-prop of bias-addition in Keras layers) to include when XLA JIT complilation is enabled.
|
||||
* Fix problem, when running on a CUDA GPU and when either environment variable `TF_DETERMINSTIC_OPS` or environment variable `TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer configurations led to an exception with the message "No algorithm worked!"
|
||||
* Tracing and Debugging:
|
||||
* Add source, destination name to `_send` traceme to allow easier debugging.
|
||||
* Add traceme event to `fastpathexecute`.
|
||||
* Other:
|
||||
* Fix an issue with AUC.reset_states for multi-label AUC [#35852](https://github.com/tensorflow/tensorflow/issues/35852)
|
||||
* Fix the TF upgrade script to not delete files when there is a parsing error and the output mode is `in-place`.
|
||||
* Move `tensorflow/core:framework/*_pyclif` rules to `tensorflow/core/framework:*_pyclif`.
|
||||
|
||||
## Thanks to our Contributors
|
||||
|
||||
This release contains contributions from many people at Google, as well as:
|
||||
|
||||
372046933, 8bitmp3, aaronhma, Abin Shahab, Aditya Patwardhan, Agoniii, Ahti Kitsik, Alan Yee, Albin Joy, Alex Hoffman, Alexander Grund, Alexandre E. Eichenberger, Amit Kumar Jaiswal, amoitra, Andrew Anderson, Angus-Luo, Anthony Barbier, Anton Kachatkou, Anuj Rawat, archis, Arpan-Dhatt, Arvind Sundararajan, Ashutosh Hathidara, autoih, Bairen Yi, Balint Cristian, Bas Aarts, BashirSbaiti, Basit Ayantunde, Ben Barsdell, Benjamin Gaillard, boron, Brett Koonce, Bryan Cutler, Christian Goll, Christian Sachs, Clayne Robison, comet, Daniel Falbel, Daria Zhuravleva, darsh8200, David Truby, Dayananda-V, deepakm, Denis Khalikov, Devansh Singh, Dheeraj R Reddy, Diederik Van Liere, Diego Caballero, Dominic Jack, dothinking, Douman, Drake Gens, Duncan Riach, Ehsan Toosi, ekuznetsov139, Elena Zhelezina, elzino, Ending2015a, Eric Schweitz, Erik Zettel, Ethan Saadia, Eugene Kuznetsov, Evgeniy Zheltonozhskiy, Ewout Ter Hoeven, exfalso, FAIJUL, Fangjun Kuang, Fei Hu, Frank Laub, Frederic Bastien, Fredrik Knutsson, frreiss, Frédéric Rechtenstein, fsx950223, Gaurav Singh, gbaned, George Grzegorz Pawelczak, George Sterpu, Gian Marco Iodice, Giorgio Arena, Hans Gaiser, Hans Pabst, Haoyu Wu, Harry Slatyer, hsahovic, Hugo, Hugo Sjöberg, IrinaM21, jacco, Jake Tae, Jean-Denis Lesage, Jean-Michel Gorius, Jeff Daily, Jens Elofsson, Jerry Shih, jerryyin, Jin Mingjian, Jinjing Zhou, JKIsaacLee, jojimonv, Jonathan Dekhtiar, Jose Ignacio Gomez, Joseph-Rance, Judd, Julian Gross, Kaixi Hou, Kaustubh Maske Patil, Keunwoo Choi, Kevin Hanselman, Khor Chean Wei, Kilaru Yasaswi Sri Chandra Gandhi, Koan-Sin Tan, Koki Ibukuro, Kristian Holsheimer, kurileo, Lakshay Tokas, Lee Netherton, leike666666, Leslie-Fang-Intel, Li, Guizi, LIUJIAN435, Lukas Geiger, Lyo Nguyen, madisetti, Maher Jendoubi, Mahmoud Abuzaina, Manuel Freiberger, Marcel Koester, Marco Jacopo Ferrarotti, Markus Franke, marload, Mbah-Javis, mbhuiyan, Meng Zhang, Michael Liao, MichaelKonobeev, Michal Tarnowski, Milan Straka, minoring, Mohamed Nour Abouelseoud, MoussaMM, Mrinal Jain, mrTsjolder, Måns Nilsson, Namrata Bhave, Nicholas Gao, Niels Ole Salscheider, nikochiko, Niranjan Hasabnis, Nishidha Panpaliya, nmostafa, Noah Trenaman, nuka137, Officium, Owen L - Sfe, Pallavi G, Paul Andrey, Peng Sun, Peng Wu, Phil Pearl, PhilipMay, pingsutw, Pooya Davoodi, PragmaTwice, pshiko, Qwerty71, R Gomathi, Rahul Huilgol, Richard Xiao, Rick Wierenga, Roberto Rosmaninho, ruchit2801, Rushabh Vasani, Sami, Sana Damani, Sarvesh Dubey, Sasan Jafarnejad, Sergii Khomenko, Shane Smiskol, Shaochen Shi, sharkdtu, Shawn Presser, ShengYang1, Shreyash Patodia, Shyam Sundar Dhanabalan, Siju Samuel, Somyajit Chakraborty Sam, Srihari Humbarwadi, srinivasan.narayanamoorthy, Srishti Yadav, Steph-En-M, Stephan Uphoff, Stephen Mugisha, SumanSudhir, Taehun Kim, Tamas Bela Feher, TengLu, Tetragramm, Thierry Herrmann, Tian Jin, tigertang, Tom Carchrae, Tom Forbes, Trent Lo, Victor Peng, vijayphoenix, Vincent Abriou, Vishal Bhola, Vishnuvardhan Janapati, vladbataev, VoVAllen, Wallyss Lima, Wen-Heng (Jack) Chung, wenxizhu, William D. Irons, William Zhang, Xiaoming (Jason) Cui, Xiaoquan Kong, Xinan Jiang, Yasir Modak, Yasuhiro Matsumoto, Yaxun (Sam) Liu, Yong Tang, Ytyt-Yt, yuan, Yuan Mingshuai, Yuan Tang, Yuki Ueda, Yusup, zhangshijin, zhuwenxi
|
||||
|
||||
# Release 2.0.1
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
@ -64,7 +64,7 @@ your model, and we recommend you run the TensorFlow process in a sandbox.
|
||||
|
||||
It is possible to write models that are secure in a sense that they can safely
|
||||
process untrusted inputs assuming there are no bugs. There are two main reasons
|
||||
to not rely on this: first, it is easy to write models which must not be exposed
|
||||
to not rely on this: First, it is easy to write models which must not be exposed
|
||||
to untrusted inputs, and second, there are bugs in any software system of
|
||||
sufficient complexity. Letting users control inputs could allow them to trigger
|
||||
bugs either in TensorFlow or in dependent libraries.
|
||||
@ -149,7 +149,7 @@ attack (or worse). Because TensorFlow behaves correctly, this is not a
|
||||
vulnerability in TensorFlow (although it would be a vulnerability of this
|
||||
hypothetical system).
|
||||
|
||||
As a general rule, it is incorrect behavior for Tensorflow to access memory it
|
||||
As a general rule, it is incorrect behavior for TensorFlow to access memory it
|
||||
does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to
|
||||
such behaviors constitute a vulnerability.
|
||||
|
||||
|
@ -114,6 +114,14 @@ http_archive(
|
||||
],
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "person_detect_data",
|
||||
sha256 = "170542270da256994ce24d1e357f6e84a54fdaf7d28ff2b74725a40b70b082cf",
|
||||
urls = [
|
||||
"https://storage.googleapis.com/download.tensorflow.org/data/tf_lite_micro_person_data_grayscale_2020_05_24.zip",
|
||||
],
|
||||
)
|
||||
|
||||
# Required for dependency @com_github_grpc_grpc
|
||||
|
||||
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
|
||||
|
88
configure.py
88
configure.py
@ -144,7 +144,7 @@ def write_to_bazelrc(line):
|
||||
|
||||
|
||||
def write_action_env_to_bazelrc(var_name, var):
|
||||
write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var)))
|
||||
write_to_bazelrc('build --action_env {}="{}"'.format(var_name, str(var)))
|
||||
|
||||
|
||||
def run_shell(cmd, allow_non_zero=False, stderr=None):
|
||||
@ -205,7 +205,7 @@ def setup_python(environ_cp):
|
||||
# Get PYTHON_BIN_PATH, default is the current running python.
|
||||
default_python_bin_path = sys.executable
|
||||
ask_python_bin_path = ('Please specify the location of python. [Default is '
|
||||
'%s]: ') % default_python_bin_path
|
||||
'{}]: ').format(default_python_bin_path)
|
||||
while True:
|
||||
python_bin_path = get_from_env_or_user_or_default(environ_cp,
|
||||
'PYTHON_BIN_PATH',
|
||||
@ -215,9 +215,10 @@ def setup_python(environ_cp):
|
||||
if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK):
|
||||
break
|
||||
elif not os.path.exists(python_bin_path):
|
||||
print('Invalid python path: %s cannot be found.' % python_bin_path)
|
||||
print('Invalid python path: {} cannot be found.'.format(python_bin_path))
|
||||
else:
|
||||
print('%s is not executable. Is it the python binary?' % python_bin_path)
|
||||
print('{} is not executable. Is it the python binary?'.format(
|
||||
python_bin_path))
|
||||
environ_cp['PYTHON_BIN_PATH'] = ''
|
||||
|
||||
# Convert python path to Windows style before checking lib and version
|
||||
@ -236,7 +237,7 @@ def setup_python(environ_cp):
|
||||
default_python_lib_path = python_lib_paths[0]
|
||||
python_lib_path = get_input(
|
||||
'Please input the desired Python library path to use. '
|
||||
'Default is [%s]\n' % python_lib_paths[0])
|
||||
'Default is [{}]\n'.format(python_lib_paths[0]))
|
||||
if not python_lib_path:
|
||||
python_lib_path = default_python_lib_path
|
||||
environ_cp['PYTHON_LIB_PATH'] = python_lib_path
|
||||
@ -252,7 +253,7 @@ def setup_python(environ_cp):
|
||||
# Set-up env variables used by python_configure.bzl
|
||||
write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path)
|
||||
write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path)
|
||||
write_to_bazelrc('build --python_path=\"%s"' % python_bin_path)
|
||||
write_to_bazelrc('build --python_path=\"{}"'.format(python_bin_path))
|
||||
environ_cp['PYTHON_BIN_PATH'] = python_bin_path
|
||||
|
||||
# If choosen python_lib_path is from a path specified in the PYTHONPATH
|
||||
@ -266,7 +267,7 @@ def setup_python(environ_cp):
|
||||
with open(
|
||||
os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'),
|
||||
'w') as f:
|
||||
f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path)
|
||||
f.write('export PYTHON_BIN_PATH="{}"'.format(python_bin_path))
|
||||
|
||||
|
||||
def reset_tf_configure_bazelrc():
|
||||
@ -320,11 +321,12 @@ def get_var(environ_cp,
|
||||
Raise the error to avoid infinitely looping.
|
||||
"""
|
||||
if not question:
|
||||
question = 'Do you wish to build TensorFlow with %s support?' % query_item
|
||||
question = 'Do you wish to build TensorFlow with {} support?'.format(
|
||||
query_item)
|
||||
if not yes_reply:
|
||||
yes_reply = '%s support will be enabled for TensorFlow.' % query_item
|
||||
yes_reply = '{} support will be enabled for TensorFlow.'.format(query_item)
|
||||
if not no_reply:
|
||||
no_reply = 'No %s' % yes_reply
|
||||
no_reply = 'No {}'.format(yes_reply)
|
||||
|
||||
yes_reply += '\n'
|
||||
no_reply += '\n'
|
||||
@ -368,7 +370,7 @@ def get_var(environ_cp,
|
||||
print(no_reply)
|
||||
var = False
|
||||
else:
|
||||
print('Invalid selection: %s' % user_input_origin)
|
||||
print('Invalid selection: {}'.format(user_input_origin))
|
||||
return var
|
||||
|
||||
|
||||
@ -479,13 +481,13 @@ def check_bazel_version(min_version, max_version):
|
||||
if which('bazel') is None:
|
||||
print('Cannot find bazel. Please install bazel.')
|
||||
sys.exit(0)
|
||||
curr_version = run_shell(
|
||||
['bazel', '--batch', '--bazelrc=/dev/null', 'version'])
|
||||
|
||||
for line in curr_version.split('\n'):
|
||||
if 'Build label: ' in line:
|
||||
curr_version = line.split('Build label: ')[1]
|
||||
break
|
||||
stderr = open(os.devnull, 'wb')
|
||||
curr_version = run_shell(['bazel', '--version'],
|
||||
allow_non_zero=True,
|
||||
stderr=stderr)
|
||||
if curr_version.startswith('bazel '):
|
||||
curr_version = curr_version.split('bazel ')[1]
|
||||
|
||||
min_version_int = convert_version_to_int(min_version)
|
||||
curr_version_int = convert_version_to_int(curr_version)
|
||||
@ -1009,17 +1011,15 @@ def set_tf_cuda_compute_capabilities(environ_cp):
|
||||
default_cuda_compute_capabilities = native_cuda_compute_capabilities
|
||||
|
||||
ask_cuda_compute_capabilities = (
|
||||
'Please specify a list of comma-separated '
|
||||
'CUDA compute capabilities you want to '
|
||||
'build with.\nYou can find the compute '
|
||||
'capability of your device at: '
|
||||
'https://developer.nvidia.com/cuda-gpus.\nPlease'
|
||||
' note that each additional compute '
|
||||
'capability significantly increases your '
|
||||
'build time and binary size, and that '
|
||||
'TensorFlow only supports compute '
|
||||
'capabilities >= 3.5 [Default is: %s]: ' %
|
||||
default_cuda_compute_capabilities)
|
||||
'Please specify a list of comma-separated CUDA compute capabilities '
|
||||
'you want to build with.\nYou can find the compute capability of your '
|
||||
'device at: https://developer.nvidia.com/cuda-gpus. Each capability '
|
||||
'can be specified as "x.y" or "compute_xy" to include both virtual and'
|
||||
' binary GPU code, or as "sm_xy" to only include the binary '
|
||||
'code.\nPlease note that each additional compute capability '
|
||||
'significantly increases your build time and binary size, and that '
|
||||
'TensorFlow only supports compute capabilities >= 3.5 [Default is: '
|
||||
'%s]: ' % default_cuda_compute_capabilities)
|
||||
tf_cuda_compute_capabilities = get_from_env_or_user_or_default(
|
||||
environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES',
|
||||
ask_cuda_compute_capabilities, default_cuda_compute_capabilities)
|
||||
@ -1031,8 +1031,23 @@ def set_tf_cuda_compute_capabilities(environ_cp):
|
||||
for compute_capability in tf_cuda_compute_capabilities.split(','):
|
||||
m = re.match('[0-9]+.[0-9]+', compute_capability)
|
||||
if not m:
|
||||
print('Invalid compute capability: %s' % compute_capability)
|
||||
all_valid = False
|
||||
# We now support sm_35,sm_50,sm_60,compute_70.
|
||||
sm_compute_match = re.match('(sm|compute)_?([0-9]+[0-9]+)',
|
||||
compute_capability)
|
||||
if not sm_compute_match:
|
||||
print('Invalid compute capability: %s' % compute_capability)
|
||||
all_valid = False
|
||||
else:
|
||||
ver = int(sm_compute_match.group(2))
|
||||
if ver < 30:
|
||||
print(
|
||||
'ERROR: TensorFlow only supports small CUDA compute'
|
||||
' capabilities of sm_30 and higher. Please re-specify the list'
|
||||
' of compute capabilities excluding version %s.' % ver)
|
||||
all_valid = False
|
||||
if ver < 35:
|
||||
print('WARNING: XLA does not support CUDA compute capabilities '
|
||||
'lower than sm_35. Disable XLA when running on older GPUs.')
|
||||
else:
|
||||
ver = float(m.group(0))
|
||||
if ver < 3.0:
|
||||
@ -1223,7 +1238,8 @@ def is_reduced_optimize_huge_functions_available(environ_cp):
|
||||
only, as of 2019-11-19). TensorFlow needs this flag to massively reduce
|
||||
compile times, but until 16.4 is officially released, we can't depend on it.
|
||||
|
||||
See also https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
|
||||
See also
|
||||
https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
|
||||
|
||||
Because it's very annoying to check this manually (to check the MSVC installed
|
||||
versions, you need to use the registry, and it's not clear if Bazel will be
|
||||
@ -1366,8 +1382,13 @@ def main():
|
||||
# environment variables.
|
||||
environ_cp = dict(os.environ)
|
||||
|
||||
current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION,
|
||||
_TF_MAX_BAZEL_VERSION)
|
||||
try:
|
||||
current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION,
|
||||
_TF_MAX_BAZEL_VERSION)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print('Error checking bazel version: ', e.output.decode('UTF-8').strip())
|
||||
raise e
|
||||
|
||||
_TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)
|
||||
|
||||
reset_tf_configure_bazelrc()
|
||||
@ -1385,7 +1406,6 @@ def main():
|
||||
# Windows.
|
||||
environ_cp['TF_DOWNLOAD_CLANG'] = '0'
|
||||
environ_cp['TF_NEED_MPI'] = '0'
|
||||
environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0'
|
||||
|
||||
if is_macos():
|
||||
environ_cp['TF_NEED_TENSORRT'] = '0'
|
||||
|
@ -517,18 +517,29 @@ package_group(
|
||||
"//perftools/accelerators/xprof/api/...",
|
||||
"//third_party/py/autograph/...",
|
||||
"//third_party/swift/tensorflow/x10/...",
|
||||
"//third_party/swift/tensorflow_apis/...",
|
||||
"//tensorflow/...",
|
||||
"//tensorflow_estimator/python/estimator/...",
|
||||
"//tensorflow_models/official/...",
|
||||
],
|
||||
)
|
||||
|
||||
package_group(name = "ndarray_tensor_allow_list")
|
||||
package_group(
|
||||
name = "ndarray_tensor_allow_list",
|
||||
packages = ["//learning/pathways/..."],
|
||||
)
|
||||
|
||||
# Packages that use composite tensors or dispatch.
|
||||
# TODO(b/154762408) Remove this package group once it's no longer needed.
|
||||
package_group(name = "composite_tensor_whitelist")
|
||||
|
||||
# Packages that use private types symbols, until they are exported.
|
||||
# TODO(b/154650521) Remove.
|
||||
package_group(
|
||||
name = "types_whitelist",
|
||||
packages = ["//learning/deepmind/tensorflow/replicator/..."],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "intel_binary_blob",
|
||||
data = if_mkl_ml(
|
||||
|
@ -85,7 +85,7 @@ tf_cuda_library(
|
||||
],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//tensorflow:chromiumos": [
|
||||
":tf_attrtype",
|
||||
@ -182,7 +182,7 @@ tf_cuda_library(
|
||||
":tf_status_internal",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tf_status",
|
||||
@ -216,10 +216,11 @@ tf_cuda_library(
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/c:__subpackages__",
|
||||
"//tensorflow/compiler/mlir/tensorflow/c:__subpackages__",
|
||||
],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:lib",
|
||||
@ -232,12 +233,13 @@ cc_library(
|
||||
srcs = ["tf_status.cc"],
|
||||
hdrs = ["tf_status.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
deps = [
|
||||
":tf_status_internal",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tf_status_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
}),
|
||||
@ -259,10 +261,15 @@ cc_library(
|
||||
name = "tensor_interface",
|
||||
hdrs = ["tensor_interface.h"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -272,7 +279,7 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
@ -286,16 +293,17 @@ cc_library(
|
||||
srcs = ["tf_tensor.cc"],
|
||||
hdrs = ["tf_tensor.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
deps = [
|
||||
":tensor_interface",
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
":tf_status_helper",
|
||||
":tf_tensor_internal",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tensor_interface",
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
":tf_status_helper",
|
||||
":tf_tensor_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -311,14 +319,15 @@ tf_cuda_library(
|
||||
"tf_tensor_internal.h",
|
||||
],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = select({
|
||||
deps = [
|
||||
":tensor_interface",
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tensor_interface",
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:casts",
|
||||
@ -386,8 +395,14 @@ tf_cuda_library(
|
||||
deps = [
|
||||
":tf_status",
|
||||
":tf_status_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
@ -426,7 +441,7 @@ tf_cuda_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
@ -457,7 +472,7 @@ tf_cuda_library(
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
":c_api_internal",
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":c_api_internal",
|
||||
@ -484,7 +499,7 @@ tf_cuda_library(
|
||||
":tf_status_helper",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
|
@ -325,205 +325,6 @@ TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
TFE_Context* TFE_CreateContextFromSession(TF_Session* session,
|
||||
TF_Status* status) {
|
||||
auto* opts = TFE_NewContextOptions();
|
||||
|
||||
// Reduce GPU memory allocation, and set appropriate config options for TFE
|
||||
// context.
|
||||
auto* config = TF_CreateConfig(
|
||||
/*xla*/ false, /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
|
||||
10);
|
||||
TFE_ContextOptionsSetConfig(opts, config->data, config->length, status);
|
||||
if (!status->status.ok()) {
|
||||
CHECK(!config);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto* ctx = TFE_NewContextFromSession(opts, session, status);
|
||||
TF_DeleteBuffer(config);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// TODO: retrieve the device string via TFE_ContextListDevices()
|
||||
static const char DEFAULT_CPU_DEVICE[] =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
|
||||
static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType,
|
||||
int tensor_id, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> queueOp(
|
||||
TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp);
|
||||
TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
// TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler.
|
||||
TFE_OpSetAttrInt(queueOp.get(), "capacity", 1);
|
||||
TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1);
|
||||
auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id);
|
||||
TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(),
|
||||
shared_name.size());
|
||||
TFE_OpSetAttrString(queueOp.get(), "container", "", 0);
|
||||
|
||||
// TODO: consider making this an unknown shape.
|
||||
const int64_t* dims_ptr = nullptr;
|
||||
int num_dims = 0;
|
||||
TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims,
|
||||
/*num_values*/ 0, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
int num_retvals = 1;
|
||||
TFE_TensorHandle* queue = nullptr;
|
||||
TFE_Execute(queueOp.get(), &queue, &num_retvals, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
CHECK_EQ(num_retvals, 1);
|
||||
|
||||
return queue;
|
||||
}
|
||||
|
||||
static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType,
|
||||
TFE_TensorHandle* queue, TFE_TensorHandle* tensor,
|
||||
TF_Status* status) {
|
||||
TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
|
||||
TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
|
||||
if (!status->status.ok()) return;
|
||||
TFE_OpAddInput(op, queue, status);
|
||||
if (!status->status.ok()) return;
|
||||
TFE_OpAddInput(op, tensor, status);
|
||||
if (!status->status.ok()) return;
|
||||
TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1);
|
||||
TFE_OpSetAttrInt(op, "timeout_ms", -1);
|
||||
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status);
|
||||
if (!status->status.ok()) return;
|
||||
CHECK_EQ(num_retvals, 0);
|
||||
}
|
||||
|
||||
static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx,
|
||||
TF_DataType inputType,
|
||||
TFE_TensorHandle* queue,
|
||||
TF_Status* status) {
|
||||
TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
|
||||
TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
TFE_OpAddInput(op, queue, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1);
|
||||
TFE_OpSetAttrInt(op, "timeout_ms", -1);
|
||||
TFE_TensorHandle* ret;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &ret, &num_retvals, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
CHECK_EQ(num_retvals, 1);
|
||||
return ret;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id,
|
||||
TF_DataType inputType,
|
||||
TF_Status* status) {
|
||||
assert(session);
|
||||
VLOG(1) << "Dequeuing data tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
auto* ret = createTFEDequeue(ctx, inputType, queue, status);
|
||||
return ret;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
|
||||
TF_DataType inputType,
|
||||
TF_Status* status) {
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
auto* ret = createTFEDequeue(ctx, inputType, queue, status);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id,
|
||||
TFE_TensorHandle* tensor, TF_Status* status) {
|
||||
assert(session);
|
||||
VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TF_DataType inputType = TFE_TensorHandleDataType(tensor);
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
createTFEEnqueue(ctx, inputType, queue, tensor, status);
|
||||
}
|
||||
|
||||
void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status) {
|
||||
VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
|
||||
|
||||
TF_DataType inputType = TFE_TensorHandleDataType(tensor);
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
createTFEEnqueue(ctx, inputType, queue, tensor, status);
|
||||
}
|
||||
|
||||
void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id,
|
||||
TFE_TensorHandle* tensor, TF_Status* status) {
|
||||
VLOG(1) << "Enqueuing variant tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
|
||||
TF_Status* status) {
|
||||
VLOG(1) << "Dequeuing variant tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
return createTFEDequeue(ctx, TF_VARIANT, queue, status);
|
||||
}
|
||||
|
||||
void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
|
||||
status->status = tensorflow::errors::Internal(errMsg);
|
||||
}
|
||||
@ -622,10 +423,9 @@ void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
|
||||
void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
|
||||
const TF_DataType* values, int num_values) {
|
||||
auto iter = builder->attr_names.insert(attr_name).first;
|
||||
builder->Set(
|
||||
(*iter).c_str(),
|
||||
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
|
||||
reinterpret_cast<const tensorflow::DataType*>(values), num_values));
|
||||
builder->Set(*iter, tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
|
||||
reinterpret_cast<const tensorflow::DataType*>(values),
|
||||
num_values));
|
||||
}
|
||||
|
||||
void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,
|
||||
|
@ -146,48 +146,6 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
|
||||
// Create a serialized tensorflow.ServerDef proto.
|
||||
TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status);
|
||||
|
||||
// TODO: remove this API in favor of the next one.
|
||||
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession(
|
||||
const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status);
|
||||
|
||||
// Creates from `session` a new eager context to run a graph function or
|
||||
// sends/recvs, so that these concurrent TFE executions can share (via
|
||||
// `session` and its associated device mgr) the same set of fifo queue resource
|
||||
// ops, used for host<->TF tensor transfers. This way the sends/recvs calls and
|
||||
// graph function execution can access the same fifo queue resource handles
|
||||
// (associated with devices managed by the device manager, which can be obtained
|
||||
// from `session`).
|
||||
//
|
||||
// TODO: Remove this function once we migrate away from using session.
|
||||
TF_CAPI_EXPORT extern TFE_Context* TFE_CreateContextFromSession(
|
||||
TF_Session* session, TF_Status* status);
|
||||
|
||||
// TODO: Retire this API in favor of the next one.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensor(
|
||||
TF_Session* session, int tensor_id, TF_DataType inputType,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(
|
||||
TFE_Context* ctx, int tensor_id, TF_DataType inputType, TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensor(TF_Session* session,
|
||||
int tensor_id,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensorFromCtx(
|
||||
TFE_Context* ctx, int tensor_id, TFE_TensorHandle* tensor,
|
||||
TF_Status* status);
|
||||
|
||||
// TODO: consider folding the 2 APIs below into the ones above.
|
||||
TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session,
|
||||
int tensor_id,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
|
||||
TF_Session* session, int tensor_id, TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
|
||||
const char* errMsg);
|
||||
|
||||
|
@ -16,7 +16,6 @@ load(
|
||||
"//tensorflow/core/platform:build_config_root.bzl",
|
||||
"tf_cuda_tests_tags",
|
||||
)
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
@ -36,7 +35,7 @@ tf_cuda_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":context_interface",
|
||||
@ -145,6 +144,24 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "c_api_unified_internal",
|
||||
hdrs = [
|
||||
"c_api_unified_experimental_internal.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/core/platform:casts",
|
||||
"//tensorflow/core/platform:types",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_handle_interface",
|
||||
hdrs = ["tensor_handle_interface.h"],
|
||||
@ -320,6 +337,7 @@ tf_cuda_cc_test(
|
||||
tags = [
|
||||
"noguitar", # TODO(b/155445984): flaky
|
||||
#"guitar",
|
||||
"notap", # TODO(b/156981931): flaky
|
||||
"multi_gpu",
|
||||
],
|
||||
deps = [
|
||||
@ -350,6 +368,38 @@ tf_cuda_cc_test(
|
||||
# TODO(b/136478427): Figure out how to correctly shut the server down
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = [
|
||||
"noasan", # leaks gRPC server instances
|
||||
],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
":c_api_test_util",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime:function_optimization_registry",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "c_api_distributed_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"c_api_distributed_test.cc",
|
||||
],
|
||||
# TODO(b/136478427): Figure out how to correctly shut the server down
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = ["noasan"], # leaks gRPC server instances
|
||||
deps = [
|
||||
":c_api",
|
||||
@ -358,10 +408,13 @@ tf_cuda_cc_test(
|
||||
":c_api_test_util",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime:function_optimization_registry",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
@ -413,7 +466,7 @@ tf_cuda_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":c_api",
|
||||
@ -449,6 +502,8 @@ tf_cuda_library(
|
||||
"//conditions:default": [],
|
||||
}) + [
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core/distributed_runtime/eager:eager_client",
|
||||
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
|
||||
@ -506,6 +561,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
@ -609,7 +665,6 @@ filegroup(
|
||||
],
|
||||
exclude = [
|
||||
"c_api_experimental.cc",
|
||||
"*c_api_tfrt*",
|
||||
"*test*",
|
||||
"*dlpack*",
|
||||
],
|
||||
|
@ -38,7 +38,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
#include "tensorflow/c/eager/c_api_tfrt.h"
|
||||
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
|
||||
#endif
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
@ -102,6 +102,15 @@ string DeviceName(const tensorflow::Device* d) {
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context,
|
||||
const tensorflow::ServerDef& server_def) {
|
||||
if (server_def.job_name() != context->HostCPU()->parsed_name().job) {
|
||||
return false;
|
||||
}
|
||||
return server_def.default_session_config().SerializeAsString() ==
|
||||
context->session_options().config.SerializeAsString();
|
||||
}
|
||||
|
||||
tensorflow::Status AddRemoteDevicesToMgr(
|
||||
const std::vector<string>& added_remote_workers,
|
||||
tensorflow::WorkerCacheInterface* worker_cache,
|
||||
@ -469,10 +478,15 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server;
|
||||
if (reset_context) {
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
|
||||
const tensorflow::DeviceMgr* device_mgr =
|
||||
AreLocalDevicesCompatible(context, server_def)
|
||||
? context->local_device_mgr()
|
||||
: nullptr;
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
|
||||
server_def, {device_mgr}, &new_server));
|
||||
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
|
||||
ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
|
||||
} else {
|
||||
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
|
||||
&curr_remote_workers));
|
||||
@ -727,24 +741,6 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
tensorflow::GetDefaultCustomKernelCreator()));
|
||||
}
|
||||
|
||||
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
|
||||
TF_Session* sess, TF_Status* status) {
|
||||
const tensorflow::DeviceMgr* device_mgr = nullptr;
|
||||
status->status = sess->session->LocalDeviceManager(&device_mgr);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr);
|
||||
|
||||
return tensorflow::wrap(new tensorflow::EagerContext(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
|
||||
/*device_mgr_owned*/ false, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator()));
|
||||
}
|
||||
|
||||
void TFE_DeleteContext(TFE_Context* ctx) {
|
||||
if (ctx == nullptr) {
|
||||
return;
|
||||
@ -899,9 +895,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||
#if defined(IS_MOBILE_PLATFORM)
|
||||
status->status = tensorflow::Status::OK();
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->SyncExecutors();
|
||||
status->status = tensorflow::unwrap(ctx)->AsyncWait();
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
@ -924,7 +918,7 @@ extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
||||
context->GetDevicePlacementPolicy());
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
|
||||
tensorflow::Tensor tensor;
|
||||
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
@ -137,7 +137,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
// placed in memory of different devices or remote address spaces.
|
||||
typedef struct TFE_TensorHandle TFE_TensorHandle;
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t,
|
||||
TF_Status* status);
|
||||
// Indicates that the caller will not be using `h` any more.
|
||||
TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
|
||||
|
@ -30,24 +30,11 @@ namespace {
|
||||
|
||||
using ::tensorflow::string;
|
||||
|
||||
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
|
||||
tensorflow::ServerDef server_def;
|
||||
server_def.set_protocol("grpc");
|
||||
server_def.set_job_name(job_name);
|
||||
server_def.set_task_index(0);
|
||||
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
||||
tensorflow::JobDef* job_def = cluster_def->add_job();
|
||||
job_def->set_name(job_name);
|
||||
for (int i = 0; i < num_tasks; i++) {
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{i, tensorflow::strings::StrCat("localhost:", port)});
|
||||
}
|
||||
return server_def;
|
||||
}
|
||||
|
||||
tensorflow::ServerDef GetServerDef(int num_tasks) {
|
||||
return GetServerDef("localhost", num_tasks);
|
||||
void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
|
||||
tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->at(task_index) =
|
||||
tensorflow::strings::StrCat("localhost:", port);
|
||||
}
|
||||
|
||||
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
|
||||
@ -101,6 +88,22 @@ void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
// Read the value of variable `var` and save it into `out_value`.
|
||||
void ReadVariable(TFE_Context* ctx, TFE_TensorHandle* var,
|
||||
TFE_TensorHandle** out_value) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op, var, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, out_value, &num_retvals, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(op);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteChangeServerDef(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
@ -243,6 +246,102 @@ TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
|
||||
TestRemoteExecuteUpdateServerDef(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteUpdateServerDefResourceAccess(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
const char dev0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char dev1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
|
||||
TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
|
||||
EXPECT_NE(var_handle0, nullptr);
|
||||
TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
|
||||
EXPECT_NE(var_handle1, nullptr);
|
||||
|
||||
TFE_TensorHandle* value_handle = nullptr;
|
||||
ReadVariable(ctx, var_handle1, &value_handle);
|
||||
CheckTFE_TensorHandleHasFloats(value_handle, {2});
|
||||
TFE_DeleteTensorHandle(value_handle);
|
||||
|
||||
// Start a new worker to replace task:1
|
||||
ReplaceTaskInServerDef(&server_def, 1);
|
||||
server_def.set_task_index(1);
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
// Update server def to replace the remote device with the device info on the
|
||||
// new worker (different incarnation ID).
|
||||
server_def.set_task_index(0);
|
||||
string serialized_update = server_def.SerializeAsString();
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
|
||||
serialized_update.size(), status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// The device of var_handle0 is local device which is the same before and
|
||||
// after cluster update. Remove resource with valid device should succeed.
|
||||
TFE_Op* op = TFE_NewOp(ctx, "DestroyResourceOp", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(op, var_handle0, status);
|
||||
TFE_OpSetDevice(op, dev0_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(op);
|
||||
|
||||
// The device of var_handle1 is remote device, which was replaced during
|
||||
// cluster update. Removing resource with invalid device should fail
|
||||
// gracefully (i.e., with error status) instead of crashing with segfaults.
|
||||
op = TFE_NewOp(ctx, "DestroyResourceOp", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(op, var_handle1, status);
|
||||
TFE_OpSetDevice(op, dev1_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(op);
|
||||
|
||||
TFE_DeleteTensorHandle(var_handle0);
|
||||
TFE_DeleteTensorHandle(var_handle1);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccess) {
|
||||
TestRemoteExecuteUpdateServerDefResourceAccess(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccessAsync) {
|
||||
TestRemoteExecuteUpdateServerDefResourceAccess(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
|
||||
// Fail fast on GetStatus requests so we can get errors instead of timeout
|
||||
// when updating cluster with non-exsitent worker
|
||||
@ -282,6 +381,7 @@ void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{2, tensorflow::strings::StrCat("localhost:", port)});
|
||||
server_def.set_task_index(0);
|
||||
string serialized_update = server_def.SerializeAsString();
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
|
||||
serialized_update.size(), status);
|
||||
@ -310,4 +410,70 @@ TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
|
||||
TestRemoteExecuteUpdateServerDefWithFailures(true);
|
||||
}
|
||||
|
||||
void TestConnectToCluster(bool keep_localhost_for_first_connect) {
|
||||
// Fail fast on GetStatus requests so we can get errors instead of timeout
|
||||
// when updating cluster with non-exsitent worker
|
||||
tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
|
||||
|
||||
const string first_name =
|
||||
keep_localhost_for_first_connect ? "localhost" : "abc";
|
||||
tensorflow::ServerDef server_def = GetServerDef(first_name, 1);
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
const string dev0_name = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
|
||||
EXPECT_NE(var_handle0, nullptr);
|
||||
|
||||
tensorflow::Status status2;
|
||||
EXPECT_EQ(tensorflow::unwrap(var_handle0)->DeviceName(&status2), dev0_name);
|
||||
|
||||
// Rename local device
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
const string dev1_name =
|
||||
absl::StrCat("/job:", first_name, "/replica:0/task:0/device:CPU:0");
|
||||
TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
|
||||
EXPECT_NE(var_handle1, nullptr);
|
||||
EXPECT_EQ(tensorflow::unwrap(var_handle1)->DeviceName(&status2), dev1_name);
|
||||
|
||||
// Another renaming of local device
|
||||
const string second_name = "def";
|
||||
server_def.set_job_name(second_name);
|
||||
server_def.mutable_cluster()->mutable_job(0)->set_name(second_name);
|
||||
(*server_def.mutable_cluster()->mutable_job(0)->mutable_tasks())[0] =
|
||||
absl::StrCat(second_name, ":",
|
||||
tensorflow::testing::PickUnusedPortOrDie());
|
||||
|
||||
serialized = server_def.SerializeAsString();
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
const string dev2_name = "/job:def/replica:0/task:0/device:CPU:0";
|
||||
TFE_TensorHandle* var_handle2 = TestVariable(ctx, 2.0, dev2_name);
|
||||
EXPECT_NE(var_handle2, nullptr);
|
||||
EXPECT_EQ(tensorflow::unwrap(var_handle2)->DeviceName(&status2), dev2_name);
|
||||
|
||||
TFE_DeleteTensorHandle(var_handle0);
|
||||
TFE_DeleteTensorHandle(var_handle1);
|
||||
TFE_DeleteTensorHandle(var_handle2);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
tensorflow::unsetenv("GRPC_FAIL_FAST");
|
||||
}
|
||||
|
||||
TEST(CAPI, ConnectToClusterLocalhostFirst) { TestConnectToCluster(false); }
|
||||
|
||||
TEST(CAPI, ConnectToClusterRenameFirst) { TestConnectToCluster(true); }
|
||||
|
||||
} // namespace
|
||||
|
506
tensorflow/c/eager/c_api_distributed_test.cc
Normal file
506
tensorflow/c/eager/c_api_distributed_test.cc
Normal file
@ -0,0 +1,506 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using ::tensorflow::string;
|
||||
|
||||
// Add the values of three variables on three different tasks.
|
||||
string AddVariablesFunction() {
|
||||
tensorflow::FunctionDef def;
|
||||
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
" signature {"
|
||||
" name: 'AddVariablesFunction'"
|
||||
" input_arg {"
|
||||
" name: 'var'"
|
||||
" type: DT_RESOURCE"
|
||||
" }"
|
||||
" output_arg {"
|
||||
" name: 'sum'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read0'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var'"
|
||||
" device: '/job:localhost/replica:0/task:0/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read1'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var'"
|
||||
" device: '/job:localhost/replica:0/task:1/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read2'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var'"
|
||||
" device: '/job:localhost/replica:0/task:2/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'add1'"
|
||||
" op: 'Add'"
|
||||
" input: 'read0:value:0'"
|
||||
" input: 'read1:value:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'add2'"
|
||||
" op: 'Add'"
|
||||
" input: 'add1:z:0'"
|
||||
" input: 'read2:value:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" ret {"
|
||||
" key: 'sum'"
|
||||
" value: 'add2:z:0'"
|
||||
" }",
|
||||
&def));
|
||||
return def.SerializeAsString();
|
||||
}
|
||||
|
||||
void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(op, var_handle, status);
|
||||
TFE_TensorHandle* is_initialized[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &is_initialized[0], &num_retvals, status);
|
||||
CHECK_EQ(1, num_retvals);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
|
||||
bool initialized = false;
|
||||
memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t));
|
||||
EXPECT_EQ(initialized, true);
|
||||
TF_DeleteTensor(t);
|
||||
TFE_DeleteTensorHandle(is_initialized[0]);
|
||||
TFE_DeleteOp(op);
|
||||
delete status;
|
||||
}
|
||||
|
||||
void TestFunctionWithPackedInput(const bool remote) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server1->Start().ok());
|
||||
|
||||
server_def.set_task_index(2);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server2->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||
|
||||
// Create one variable per task.
|
||||
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
|
||||
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
|
||||
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
|
||||
|
||||
// Add a sync point in order to make sure that variables have been initialized
|
||||
// before the function execution starts.
|
||||
// TODO(b/155789951): Remove once b/155789951 is fixed.
|
||||
VarIsInitialized(ctx, h1);
|
||||
VarIsInitialized(ctx, h2);
|
||||
|
||||
// Pack 3 variable handles into one TFE_TensorHandle.
|
||||
int num_replicas = 3;
|
||||
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
|
||||
TFE_TensorHandle* packed_handle =
|
||||
TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
|
||||
EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
|
||||
EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);
|
||||
|
||||
const string composite_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
|
||||
EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
|
||||
composite_device_name);
|
||||
EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
|
||||
composite_device_name);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// Register and run a function which returns the sum of 3 variables.
|
||||
const string function_def = AddVariablesFunction();
|
||||
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
||||
status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(func, packed_handle, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
if (remote) {
|
||||
TFE_OpSetDevice(func, task1_name, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(func, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
ASSERT_EQ(1, num_retvals);
|
||||
TFE_DeleteOp(func);
|
||||
TFE_DeleteTensorHandle(packed_handle);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
float sum = 0;
|
||||
EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
|
||||
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
EXPECT_EQ(sum, 6.0);
|
||||
|
||||
TFE_DeleteTensorHandle(h0);
|
||||
TFE_DeleteTensorHandle(h1);
|
||||
TFE_DeleteTensorHandle(h2);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server1.release();
|
||||
worker_server2.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, TestLocalFunctionWithPackedInput) {
|
||||
TestFunctionWithPackedInput(/*remote=*/false);
|
||||
}
|
||||
|
||||
TEST(CAPI, TestRemoteFunctionWithPackedInput) {
|
||||
TestFunctionWithPackedInput(/*remote=*/true);
|
||||
}
|
||||
|
||||
string VariableAddFunction() {
|
||||
tensorflow::FunctionDef def;
|
||||
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
" signature {"
|
||||
" name: 'VariableAddFunction'"
|
||||
" input_arg {"
|
||||
" name: 'var0'"
|
||||
" type: DT_RESOURCE"
|
||||
" }"
|
||||
" output_arg {"
|
||||
" name: 'var0_value'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read0'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'add'"
|
||||
" op: 'Add'"
|
||||
" input: 'read0:value:0'"
|
||||
" input: 'read0:value:0'"
|
||||
" device: '/job:localhost/task:1/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'identity'"
|
||||
" op: 'Identity'"
|
||||
" input: 'add:z:0'"
|
||||
" device: '/job:localhost/task:0/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" ret {"
|
||||
" key: 'var0_value'"
|
||||
" value: 'identity:output:0'"
|
||||
" }",
|
||||
&def));
|
||||
return def.SerializeAsString();
|
||||
}
|
||||
|
||||
class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
|
||||
public:
|
||||
FunctionErrorInjectionPass(string error_node, string error_device)
|
||||
: error_node_(error_node), error_device_(error_device) {}
|
||||
tensorflow::Status Run(const tensorflow::DeviceSet& device_set,
|
||||
const tensorflow::ConfigProto& config_proto,
|
||||
std::unique_ptr<tensorflow::Graph>* graph,
|
||||
tensorflow::FunctionLibraryDefinition* flib_def,
|
||||
std::vector<std::string>* control_ret_node_names,
|
||||
bool* control_rets_updated) override {
|
||||
// Inject failure to function instantiation if finding a node that contains
|
||||
// the given node name (error_node_) and requested device (error_device_).
|
||||
for (const auto node : graph->get()->nodes()) {
|
||||
if (node->name().find(error_node_) != string::npos &&
|
||||
node->requested_device() == error_device_) {
|
||||
return tensorflow::errors::Internal("Injected graph pass error.");
|
||||
}
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
const string error_node_;
|
||||
const string error_device_;
|
||||
};
|
||||
|
||||
void TestDistributedFunctionCancellation(bool inject_error) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server1->Start().ok());
|
||||
server_def.set_task_index(2);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server2->Start().ok());
|
||||
const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||
|
||||
if (inject_error) {
|
||||
// Inject a function optimization pass failure when it sees the 'read0' op
|
||||
// having a requested device `dev2_name`. During execution:
|
||||
// * task:0 processes the main function `VariableAddFunction` and places
|
||||
// the read0 op on task:2
|
||||
// * task:0 partitions the main function with a subgraph containing read0
|
||||
// sent to task:2
|
||||
// * task:2 graph pass reports an error when it sees read0 with dev2_name
|
||||
tensorflow::function_optimization_registration::
|
||||
FunctionOptimizationPassRegistration register_test_pass(
|
||||
std::make_unique<FunctionErrorInjectionPass>("read0", dev2_name));
|
||||
}
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
|
||||
EXPECT_NE(var_handle, nullptr);
|
||||
|
||||
const string function_def = VariableAddFunction();
|
||||
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
||||
status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(func, var_handle, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(func, &retvals[0], &num_retvals, status);
|
||||
|
||||
if (inject_error) {
|
||||
ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
|
||||
} else {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(1, num_retvals);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
float sum = 0;
|
||||
ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
|
||||
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
ASSERT_EQ(sum, 4.0);
|
||||
}
|
||||
|
||||
TFE_DeleteOp(func);
|
||||
TFE_DeleteTensorHandle(var_handle);
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server1.release();
|
||||
worker_server2.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, DistributedFunctionNoError) {
|
||||
TestDistributedFunctionCancellation(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, DistributedFunctionCancelledOnError) {
|
||||
TestDistributedFunctionCancellation(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
|
||||
TFE_DEVICE_PLACEMENT_EXPLICIT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Use large matrices so that RPCs don't return before we get a chance
|
||||
// to call TFE_DeleteContext.
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
auto* h0_task1 =
|
||||
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
auto* h1_task1 =
|
||||
TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
|
||||
TFE_OpSetDevice(matmul, remote_device_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
TFE_DeleteTensorHandle(h0_task0);
|
||||
TFE_DeleteTensorHandle(h1_task0);
|
||||
TFE_DeleteTensorHandle(h0_task1);
|
||||
TFE_DeleteTensorHandle(h1_task1);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
|
||||
TestRemoteExecuteDeleteContextWithOutstandingRPC(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
|
||||
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
|
||||
}
|
||||
} // namespace
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
@ -638,3 +639,35 @@ TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
|
||||
return tensorflow::wrap(
|
||||
tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
|
||||
TFE_TensorHandle** handles,
|
||||
int* num_handles,
|
||||
TF_Status* status) {
|
||||
std::vector<tensorflow::TensorHandle*> tensor_handles;
|
||||
tensor_handles.reserve(*num_handles);
|
||||
for (int i = 0; i < *num_handles; ++i) {
|
||||
tensor_handles.push_back(
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(handles[i])));
|
||||
}
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::TensorHandle* handle = nullptr;
|
||||
status->status = tensorflow::TensorHandle::CreatePackedHandle(
|
||||
std::move(tensor_handles), context, &handle);
|
||||
return tensorflow::wrap(handle);
|
||||
}
|
||||
|
||||
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetAllowSoftPlacement(enable);
|
||||
}
|
||||
|
||||
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetLogDevicePlacement(enable);
|
||||
}
|
||||
|
@ -541,6 +541,26 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx,
|
||||
TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor(
|
||||
TFE_Context* ctx, TF_Tensor* t, TF_Status* status);
|
||||
|
||||
// Create a packed TensorHandle with the given list of TensorHandles.
|
||||
// If `handles` are on the same device, assign the same device to the packed
|
||||
// handle; if `handles` are on different deivces, assign a CompositeDevice to
|
||||
// it.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
|
||||
TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
|
||||
TF_Status* status);
|
||||
|
||||
// Configure soft device placement policy for the eager executor. Note this
|
||||
// policy is applied to any subsequent op executions.
|
||||
TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx,
|
||||
unsigned char enable,
|
||||
TF_Status* status);
|
||||
|
||||
// Configure device placement policy logging for the eager executor. Note this
|
||||
// policy is applied to any subsequent op executions.
|
||||
TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
|
||||
unsigned char enable,
|
||||
TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -19,37 +19,22 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using ::tensorflow::string;
|
||||
|
||||
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
|
||||
tensorflow::ServerDef server_def;
|
||||
server_def.set_protocol("grpc");
|
||||
server_def.set_job_name(job_name);
|
||||
server_def.set_task_index(0);
|
||||
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
||||
tensorflow::JobDef* job_def = cluster_def->add_job();
|
||||
job_def->set_name(job_name);
|
||||
for (int i = 0; i < num_tasks; i++) {
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{i, tensorflow::strings::StrCat("localhost:", port)});
|
||||
}
|
||||
return server_def;
|
||||
}
|
||||
|
||||
tensorflow::ServerDef GetServerDef(int num_tasks) {
|
||||
return GetServerDef("localhost", num_tasks);
|
||||
}
|
||||
|
||||
void TestRemoteExecute(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
@ -351,74 +336,4 @@ TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
|
||||
/*heavy_load_on_streaming_rpc=*/true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
|
||||
TFE_DEVICE_PLACEMENT_EXPLICIT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Use large matrices so that RPCs don't return before we get a chance
|
||||
// to call TFE_DeleteContext.
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
auto* h0_task1 =
|
||||
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
auto* h1_task1 =
|
||||
TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
|
||||
TFE_OpSetDevice(matmul, remote_device_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
TFE_DeleteTensorHandle(h0_task0);
|
||||
TFE_DeleteTensorHandle(h1_task0);
|
||||
TFE_DeleteTensorHandle(h0_task1);
|
||||
TFE_DeleteTensorHandle(h1_task1);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
|
||||
TestRemoteExecuteDeleteContextWithOutstandingRPC(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
|
||||
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
|
||||
}
|
||||
} // namespace
|
||||
|
@ -1132,51 +1132,6 @@ void BM_ExecuteFunction(int iters, int async) {
|
||||
}
|
||||
BENCHMARK(BM_ExecuteFunction)->Arg(0)->Arg(1);
|
||||
|
||||
TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
|
||||
TF_Status* status) {
|
||||
// Create the variable handle.
|
||||
TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpSetAttrShape(op, "shape", {}, 0, status);
|
||||
TFE_OpSetAttrString(op, "container", "", 0);
|
||||
TFE_OpSetAttrString(op, "shared_name", "", 0);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &var_handle, &num_retvals, status);
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
CHECK_EQ(1, num_retvals);
|
||||
|
||||
// Assign 'value' to it.
|
||||
op = TFE_NewOp(ctx, "AssignVariableOp", status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op, var_handle, status);
|
||||
|
||||
// Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
|
||||
TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
|
||||
memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
|
||||
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
value_handle(TFE_NewTensorHandle(t.get(), status),
|
||||
TFE_DeleteTensorHandle);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
TFE_OpAddInput(op, value_handle.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
CHECK_EQ(0, num_retvals);
|
||||
|
||||
return var_handle;
|
||||
}
|
||||
|
||||
TEST(CAPI, Variables) {
|
||||
// Variables use resource handles, so this is really a test for resource
|
||||
// tensor handling.
|
||||
@ -1186,7 +1141,7 @@ TEST(CAPI, Variables) {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status);
|
||||
TFE_TensorHandle* var_handle = TestVariable(ctx, 12.0);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
|
||||
@ -1227,7 +1182,7 @@ void BM_ReadVariable(int iters) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status);
|
||||
TFE_TensorHandle* var_handle = TestVariable(ctx, 5.0);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
|
||||
@ -1248,6 +1203,8 @@ void BM_ReadVariable(int iters) {
|
||||
CHECK_EQ(0, TFE_TensorHandleNumDims(h, status));
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
h = nullptr;
|
||||
TFE_OpAddInput(op, var_handle, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
}
|
||||
tensorflow::testing::StopTiming();
|
||||
TFE_DeleteOp(op);
|
||||
|
@ -18,7 +18,9 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
@ -133,6 +135,58 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx) {
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
|
||||
const tensorflow::string& device_name) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
// Create the variable handle.
|
||||
TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpSetAttrShape(op, "shape", {}, 0, status);
|
||||
TFE_OpSetAttrString(op, "container", "", 0);
|
||||
TFE_OpSetAttrString(op, "shared_name", "", 0);
|
||||
if (!device_name.empty()) {
|
||||
TFE_OpSetDevice(op, device_name.c_str(), status);
|
||||
}
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &var_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
CHECK_EQ(1, num_retvals);
|
||||
|
||||
// Assign 'value' to it.
|
||||
op = TFE_NewOp(ctx, "AssignVariableOp", status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op, var_handle, status);
|
||||
|
||||
// Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
|
||||
TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
|
||||
memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
|
||||
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
value_handle(TFE_NewTensorHandle(t.get(), status),
|
||||
TFE_DeleteTensorHandle);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
TFE_OpAddInput(op, value_handle.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
CHECK_EQ(0, num_retvals);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
return var_handle;
|
||||
}
|
||||
|
||||
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
||||
@ -244,3 +298,23 @@ bool GetDeviceName(TFE_Context* ctx, string* device_name,
|
||||
TF_DeleteDeviceList(devices);
|
||||
return false;
|
||||
}
|
||||
|
||||
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
|
||||
tensorflow::ServerDef server_def;
|
||||
server_def.set_protocol("grpc");
|
||||
server_def.set_job_name(job_name);
|
||||
server_def.set_task_index(0);
|
||||
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
||||
tensorflow::JobDef* job_def = cluster_def->add_job();
|
||||
job_def->set_name(job_name);
|
||||
for (int i = 0; i < num_tasks; i++) {
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{i, tensorflow::strings::StrCat("localhost:", port)});
|
||||
}
|
||||
return server_def;
|
||||
}
|
||||
|
||||
tensorflow::ServerDef GetServerDef(int num_tasks) {
|
||||
return GetServerDef("localhost", num_tasks);
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
// Return a tensor handle containing a float scalar
|
||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value);
|
||||
@ -42,6 +43,11 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx);
|
||||
// Return a tensor handle containing a 3x2 matrix of floats
|
||||
TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx);
|
||||
|
||||
// Return a variable handle referring to a variable with the given initial value
|
||||
// on the given device.
|
||||
TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
|
||||
const tensorflow::string& device_name = "");
|
||||
|
||||
// Return an add op multiplying `a` by `b`.
|
||||
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
|
||||
|
||||
@ -67,4 +73,11 @@ TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
|
||||
bool GetDeviceName(TFE_Context* ctx, tensorflow::string* device_name,
|
||||
const char* device_type);
|
||||
|
||||
// Create a ServerDef with the given `job_name` and add `num_tasks` tasks in it.
|
||||
tensorflow::ServerDef GetServerDef(const tensorflow::string& job_name,
|
||||
int num_tasks);
|
||||
|
||||
// Create a ServerDef with job name "localhost" and add `num_tasks` tasks in it.
|
||||
tensorflow::ServerDef GetServerDef(int num_tasks);
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
|
||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
@ -26,6 +28,51 @@ using tensorflow::string;
|
||||
using tensorflow::internal::OutputList;
|
||||
using tensorflow::internal::unwrap;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
typedef absl::flat_hash_map<std::string, FactoryFunction> FactoriesMap;
|
||||
|
||||
static FactoriesMap& GetFactories() {
|
||||
static FactoriesMap* factories = new FactoriesMap;
|
||||
return *factories;
|
||||
}
|
||||
|
||||
static const char* default_factory = "<unset>";
|
||||
|
||||
void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
|
||||
assert((!GetFactories().count(name)) ||
|
||||
(GetFactories()[name] == factory) &&
|
||||
"Duplicate tracing factory registration");
|
||||
GetFactories()[name] = factory;
|
||||
}
|
||||
|
||||
void SetDefaultTracingEngine(const char* name) { default_factory = name; }
|
||||
|
||||
static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
|
||||
TF_Status* s) {
|
||||
auto entry = GetFactories().find(default_factory);
|
||||
if (entry != GetFactories().end()) return entry->second(fn_name, s);
|
||||
string msg = absl::StrCat(
|
||||
"No tracing engine factory has been registered with the key '",
|
||||
default_factory, "' (available: ");
|
||||
// Ensure deterministic (sorted) order in the error message
|
||||
std::set<string> factories_sorted;
|
||||
for (const auto& factory : GetFactories())
|
||||
factories_sorted.insert(factory.first);
|
||||
const char* comma = "";
|
||||
for (const string& factory : factories_sorted) {
|
||||
msg += comma + factory;
|
||||
comma = ", ";
|
||||
}
|
||||
msg += ")";
|
||||
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // end namespace internal
|
||||
} // end namespace tensorflow
|
||||
|
||||
// =============================================================================
|
||||
// Public C API entry points
|
||||
//
|
||||
@ -36,6 +83,28 @@ using tensorflow::internal::unwrap;
|
||||
//
|
||||
// =============================================================================
|
||||
|
||||
void TF_SetTracingImplementation(const char* name) {
|
||||
tensorflow::internal::SetDefaultTracingEngine(name);
|
||||
}
|
||||
|
||||
// Creates a new TensorFlow function, it is an execution context attached to a
|
||||
// given tracing context.
|
||||
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) {
|
||||
return wrap(tensorflow::internal::CreateTracingExecutionContext(fn_name, s));
|
||||
}
|
||||
|
||||
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
|
||||
TF_OutputList* outputs, TF_Status* s) {
|
||||
auto* func = wrap(unwrap(ctx)->Finalize(unwrap(outputs), s));
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
return func;
|
||||
}
|
||||
|
||||
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
|
||||
TF_DataType dtype, TF_Status* s) {
|
||||
return wrap(unwrap(func)->AddParameter(dtype, s));
|
||||
}
|
||||
|
||||
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
|
||||
|
||||
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
|
||||
@ -58,6 +127,10 @@ int TF_OutputListNumOutputs(TF_OutputList* o) {
|
||||
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
|
||||
return wrap(unwrap(o)->outputs[i]);
|
||||
}
|
||||
void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
|
||||
TF_Status* s) {
|
||||
unwrap(o)->outputs.push_back(unwrap(tensor));
|
||||
}
|
||||
|
||||
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
|
||||
TF_Status* s) {
|
||||
|
@ -49,15 +49,26 @@ typedef struct TF_AbstractOp TF_AbstractOp;
|
||||
// setting functional attributes of other composite ops e.g. control flow.
|
||||
typedef struct TF_AbstractFunction TF_AbstractFunction;
|
||||
|
||||
// Creates a context for tracing the execution of operations into a function.
|
||||
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s);
|
||||
// This allows the client to swap the implementation of the tracing engine.
|
||||
// Any future call to TF_CreateFunction will use the implementation defined
|
||||
// here.
|
||||
void TF_SetTracingImplementation(const char* name);
|
||||
|
||||
// Creates a new TensorFlow function. A Function is an execution context, and as
|
||||
// such it can trace operations through TF_ExecuteOperation. After completing
|
||||
// tracing, a function can be obtained by TF_FinalizeFunction.
|
||||
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* status);
|
||||
|
||||
// Creates a context for eager execution of operations.
|
||||
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
|
||||
TF_Status* s);
|
||||
|
||||
void TF_DeleteExecutionContext(TF_ExecutionContext*);
|
||||
|
||||
// Add a new parameter to a TensorFlow Function.
|
||||
// TODO(aminim): what about shape?
|
||||
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
|
||||
TF_DataType dtype, TF_Status* s);
|
||||
|
||||
// Create an operation suitable to use with the provided context. The operation
|
||||
// requires its type (e.g. "AddV2") to be set independently.
|
||||
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
|
||||
@ -77,19 +88,21 @@ void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
|
||||
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
|
||||
|
||||
// TF_OutputList holds the list of TF_AbstractTensor that results from executing
|
||||
// an operation.
|
||||
// It just lets us not specify the number of outputs of an operation
|
||||
// beforehand. This forces a memory allocation in the runtime, which is bad, but
|
||||
// it allows for generic code.
|
||||
// TODO(aminim): the description above isn't clear with respect to
|
||||
// TF_OutputListNumOutputs and the current eager implementation which requires
|
||||
// the number of outputs to be set by the client.
|
||||
// an operation, or provided to create a function.
|
||||
// When executing an operation in an eager context, the expected number of
|
||||
// outputs must be set beforehand with `TF_OutputListSetNumOutputs`.
|
||||
typedef struct TF_OutputList TF_OutputList;
|
||||
TF_OutputList* TF_NewOutputList();
|
||||
void TF_DeleteOutputList(TF_OutputList* o);
|
||||
void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
|
||||
// Prepare tracing to the expected number of output for an operation.
|
||||
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, TF_Status*);
|
||||
// Return the number of outputs in the list.
|
||||
int TF_OutputListNumOutputs(TF_OutputList* o);
|
||||
// Return the `i`th output in the list.
|
||||
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
|
||||
// Append a tensor at the end of the output list, growing its size by one.
|
||||
void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
|
||||
TF_Status*);
|
||||
|
||||
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
|
||||
// capture some inputs and then add a node in the graph. The output tensors are
|
||||
@ -100,13 +113,12 @@ void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
||||
TF_ExecutionContext* ctx, TF_Status* s);
|
||||
|
||||
// Creates a new TF_AbstractFunction from the current tracing states in the
|
||||
// context. The returned TF_GraphToFunction must be deleted by the client.
|
||||
// context. The provided `ctx` is consumed by this API call and deleted.
|
||||
// The returned TF_AbstractFunction must be deleted by the client,
|
||||
// TODO(aminim): clarify the contract on the state of the context after this
|
||||
// call.
|
||||
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
||||
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||
const TF_AbstractTensor* inputs, int num_outputs,
|
||||
const TF_AbstractTensor* outputs, TF_Status* status);
|
||||
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
|
||||
TF_OutputList*, TF_Status*);
|
||||
|
||||
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
|
||||
|
||||
|
@ -123,6 +123,17 @@ class EagerContext : public ExecutionContext {
|
||||
}
|
||||
}
|
||||
|
||||
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Can't add function parameter on an eager context.");
|
||||
return nullptr;
|
||||
}
|
||||
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Can't use finalize function on an eager context.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
|
||||
auto* func = afunc->GetTfFunction(s);
|
||||
if (!func) {
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
@ -114,12 +115,14 @@ struct GraphFunction : public AbstractFunction {
|
||||
static constexpr AbstractFunctionKind kKind = kGraphFunc;
|
||||
};
|
||||
|
||||
// GraphContext wraps a TF_Graph and manages the "execution" of operation, i.e.
|
||||
// adding them to the graph.
|
||||
// GraphContext wraps a TF_Graph modeling a single function and manages the
|
||||
// "execution" of operation, i.e. adding them to the function.
|
||||
class GraphContext : public ExecutionContext {
|
||||
public:
|
||||
GraphContext()
|
||||
: ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
|
||||
explicit GraphContext(const char* name)
|
||||
: ExecutionContext(kKind),
|
||||
graph_(new TF_Graph(), TF_DeleteGraph),
|
||||
name_(name) {}
|
||||
|
||||
AbstractOp* CreateOperation() override {
|
||||
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
||||
@ -136,6 +139,10 @@ class GraphContext : public ExecutionContext {
|
||||
return;
|
||||
}
|
||||
auto* tf_opdesc = graph_op->op_.release();
|
||||
if (tf_opdesc == nullptr) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, "AbstractOp is incomplete.");
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
auto* graph_tensor = dyncast<GraphTensor>(inputs[i]);
|
||||
if (!graph_tensor) {
|
||||
@ -164,24 +171,38 @@ class GraphContext : public ExecutionContext {
|
||||
}
|
||||
}
|
||||
|
||||
TF_Function* ToFunction(const char* fn_name, int num_inputs,
|
||||
const GraphTensor* inputs, int num_outputs,
|
||||
const GraphTensor* outputs, TF_Status* status) const {
|
||||
std::vector<TF_Output> graph_inputs;
|
||||
graph_inputs.resize(num_inputs);
|
||||
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
|
||||
TF_OperationDescription* opdesc =
|
||||
TF_NewOperation(graph_.get(), "Placeholder",
|
||||
absl::StrCat("_input_", inputs_.size()).c_str());
|
||||
TF_SetAttrType(opdesc, "dtype", dtype);
|
||||
auto* operation = TF_FinishOperation(opdesc, s);
|
||||
if (!s->status.ok()) return nullptr;
|
||||
|
||||
inputs_.push_back(TF_Output{operation, 0});
|
||||
return new GraphTensor(inputs_.back(), this);
|
||||
}
|
||||
|
||||
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
|
||||
std::unique_ptr<GraphFunction> func(new GraphFunction);
|
||||
std::vector<TF_Output> graph_outputs;
|
||||
graph_outputs.resize(num_outputs);
|
||||
for (int i = 0; i < num_inputs; i++) {
|
||||
graph_inputs[i] = inputs[i].output;
|
||||
}
|
||||
for (int i = 0; i < num_outputs; i++) {
|
||||
graph_outputs[i] = outputs[i].output;
|
||||
graph_outputs.reserve(outputs->outputs.size());
|
||||
for (AbstractTensor* abstract_output : outputs->outputs) {
|
||||
GraphTensor* output = dyncast<GraphTensor>(abstract_output);
|
||||
if (!output) {
|
||||
TF_SetStatus(s, TF_UNIMPLEMENTED,
|
||||
"Returning a non-graph tensor from a function has not "
|
||||
"been implemented yet.");
|
||||
return nullptr;
|
||||
}
|
||||
graph_outputs.push_back(output->output);
|
||||
}
|
||||
|
||||
return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr,
|
||||
graph_inputs.size(), graph_inputs.data(),
|
||||
graph_outputs.size(), graph_outputs.data(),
|
||||
nullptr, nullptr, fn_name, status);
|
||||
func->func = TF_GraphToFunction(
|
||||
graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
|
||||
graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
|
||||
if (TF_GetCode(s) != TF_OK) return nullptr;
|
||||
return func.release();
|
||||
}
|
||||
|
||||
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
|
||||
@ -195,54 +216,20 @@ class GraphContext : public ExecutionContext {
|
||||
|
||||
private:
|
||||
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
|
||||
std::vector<TF_Output> inputs_;
|
||||
const char* name_;
|
||||
};
|
||||
|
||||
// Helper that converts the graph currently held in the context into a function.
|
||||
static AbstractFunction* ExecutionContextToFunction(
|
||||
const ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||
const AbstractTensor* inputs, int num_outputs,
|
||||
const AbstractTensor* outputs, TF_Status* status) {
|
||||
auto* graph_ctx = dyncast<const GraphContext>(fn_body);
|
||||
if (graph_ctx == nullptr) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"fn_body is not a TF_GraphContext.");
|
||||
return nullptr;
|
||||
}
|
||||
auto* graph_inputs = dyncast<const GraphTensor>(inputs);
|
||||
if (!graph_inputs) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "inputs aren't GraphTensors.");
|
||||
return nullptr;
|
||||
}
|
||||
auto* graph_outputs = dyncast<const GraphTensor>(outputs);
|
||||
if (!graph_outputs) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors.");
|
||||
return nullptr;
|
||||
}
|
||||
GraphFunction* func = new GraphFunction;
|
||||
func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs,
|
||||
num_outputs, graph_outputs, status);
|
||||
return func;
|
||||
static ExecutionContext* GraphTracingFactory(const char* name, TF_Status* s) {
|
||||
return new GraphContext(name);
|
||||
}
|
||||
|
||||
// Register the tracing implemented in this file as the default tracing engine.
|
||||
static bool register_tracing = [] {
|
||||
RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
|
||||
SetDefaultTracingEngine("graphdef");
|
||||
return true;
|
||||
}();
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
// =============================================================================
|
||||
// Public C API entry points
|
||||
// These are only the entry points specific to the Graph API.
|
||||
// =============================================================================
|
||||
|
||||
using tensorflow::internal::unwrap;
|
||||
|
||||
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) {
|
||||
return wrap(new tensorflow::internal::GraphContext());
|
||||
}
|
||||
|
||||
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
||||
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||
const TF_AbstractTensor* inputs, int num_outputs,
|
||||
const TF_AbstractTensor* outputs, TF_Status* status) {
|
||||
return wrap(ExecutionContextToFunction(unwrap(fn_body), fn_name, num_inputs,
|
||||
unwrap(inputs), num_outputs,
|
||||
unwrap(outputs), status));
|
||||
}
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
@ -57,7 +58,7 @@ T* dyncast(S source) {
|
||||
// GraphContext and vice-versa).
|
||||
class AbstractTensor {
|
||||
protected:
|
||||
enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor };
|
||||
enum AbstractTensorKind { kMlirTensor, kGraphTensor, kEagerTensor };
|
||||
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
|
||||
|
||||
public:
|
||||
@ -100,7 +101,7 @@ class AbstractFunction {
|
||||
// on a given context, with the same or different input tensors.
|
||||
class AbstractOp {
|
||||
protected:
|
||||
enum AbstractOpKind { kGraphOp, kEagerOp };
|
||||
enum AbstractOpKind { kMlirOp, kGraphOp, kEagerOp };
|
||||
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
|
||||
|
||||
public:
|
||||
@ -128,7 +129,7 @@ class AbstractOp {
|
||||
// eager implementation or to a graph implementation.
|
||||
struct ExecutionContext {
|
||||
protected:
|
||||
enum ExecutionContextKind { kGraphContext, kEagerContext };
|
||||
enum ExecutionContextKind { kMlirContext, kGraphContext, kEagerContext };
|
||||
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
|
||||
|
||||
public:
|
||||
@ -148,6 +149,17 @@ struct ExecutionContext {
|
||||
// Creates an empty AbstractOperation suitable to use with this context.
|
||||
virtual AbstractOp* CreateOperation() = 0;
|
||||
|
||||
// Add a function parameter and return the corresponding tensor.
|
||||
// This is only valid with an ExecutionContext obtained from a TracingContext,
|
||||
// it'll always error out with an eager context.
|
||||
virtual AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) = 0;
|
||||
|
||||
// Finalize this context and make a function out of it. The context is in a
|
||||
// invalid state after this call and must be destroyed.
|
||||
// This is only valid with an ExecutionContext obtained from a TracingContext,
|
||||
// it'll always error out with an eager context.
|
||||
virtual AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) = 0;
|
||||
|
||||
// Registers a functions with this context, after this the function is
|
||||
// available to be called/referenced by its name in this context.
|
||||
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
|
||||
@ -156,6 +168,11 @@ struct ExecutionContext {
|
||||
const ExecutionContextKind k;
|
||||
};
|
||||
|
||||
typedef ExecutionContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
|
||||
void SetDefaultTracingEngine(const char* name);
|
||||
void RegisterTracingEngineFactory(const ::tensorflow::string& name,
|
||||
FactoryFunction factory);
|
||||
|
||||
// Create utilities to wrap/unwrap: this convert from the C opaque types to the
|
||||
// C++ implementation, and back.
|
||||
#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \
|
||||
|
@ -29,7 +29,12 @@ using tensorflow::string;
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(UnifedCAPI, TestBasicEager) {
|
||||
class UnifiedCAPI : public ::testing::TestWithParam<const char*> {
|
||||
protected:
|
||||
void SetUp() override { TF_SetTracingImplementation(GetParam()); }
|
||||
};
|
||||
|
||||
TEST_P(UnifiedCAPI, TestBasicEager) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -81,33 +86,18 @@ TEST(UnifedCAPI, TestBasicEager) {
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TestBasicGraph) {
|
||||
TEST_P(UnifiedCAPI, TestBasicGraph) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
// Start a new function / execution context.
|
||||
string fn_name = "double";
|
||||
TF_ExecutionContext* graph_ctx =
|
||||
TF_CreateFunction(fn_name.c_str(), status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
|
||||
auto* placeholder_t =
|
||||
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build inputs and outputs.
|
||||
TF_OutputList* placeholder_outputs = TF_NewOutputList();
|
||||
|
||||
// Execute.
|
||||
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
|
||||
graph_ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
|
||||
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
|
||||
|
||||
// Delete placeholder op.
|
||||
TF_DeleteAbstractOp(placeholder_op);
|
||||
|
||||
// Build an abstract operation.
|
||||
auto* add_op = TF_NewAbstractOp(graph_ctx);
|
||||
@ -123,16 +113,13 @@ TEST(UnifedCAPI, TestBasicGraph) {
|
||||
// Execute.
|
||||
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractTensor* output_t = TF_OutputListGet(add_outputs, 0);
|
||||
|
||||
// Clean up operation and inputs.
|
||||
TF_DeleteAbstractOp(add_op);
|
||||
|
||||
string fn_name = "double";
|
||||
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
|
||||
graph_ctx, fn_name.c_str(), 1, placeholder_t, 1, output_t, status.get());
|
||||
TF_DeleteAbstractTensor(placeholder_t);
|
||||
TF_DeleteAbstractTensor(output_t);
|
||||
TF_AbstractFunction* func =
|
||||
TF_FinalizeFunction(graph_ctx, add_outputs, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build eager context.
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -173,18 +160,161 @@ TEST(UnifedCAPI, TestBasicGraph) {
|
||||
ASSERT_EQ(*f_value, 4.0);
|
||||
|
||||
TF_DeleteOutputList(add_outputs);
|
||||
TF_DeleteOutputList(placeholder_outputs);
|
||||
TF_DeleteAbstractOp(fn_op);
|
||||
TF_DeleteAbstractTensor(input_t);
|
||||
TF_DeleteAbstractTensor(final_result);
|
||||
TF_DeleteTensor(f_t);
|
||||
TF_DeleteAbstractFunction(func);
|
||||
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
|
||||
TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_Status* s = status.get();
|
||||
|
||||
// Start a new function / execution context.
|
||||
string fn_name = "two_adds";
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Create a first "Add" computing `arg0 + arg1`.
|
||||
TF_AbstractTensor* add_output1;
|
||||
{
|
||||
// Build an abstract operation, inputs and output.
|
||||
auto* add_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(add_op, "Add", s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_AbstractOpSetOpName(add_op, "my_add1", s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_AbstractTensor* inputs[2] = {arg0, arg1};
|
||||
TF_OutputList* add_outputs = TF_NewOutputList();
|
||||
// Trace the operation now (create a node in the graph).
|
||||
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_DeleteAbstractOp(add_op);
|
||||
// Extract the resulting tensor.
|
||||
add_output1 = TF_OutputListGet(add_outputs, 0);
|
||||
TF_DeleteOutputList(add_outputs);
|
||||
}
|
||||
|
||||
// Same with a second "Add" computing `arg1 + arg1`.
|
||||
TF_AbstractTensor* add_output2;
|
||||
{
|
||||
// Build an abstract operation, inputs and output.
|
||||
auto* add_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(add_op, "Add", s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_AbstractOpSetOpName(add_op, "my_add2", s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_AbstractTensor* inputs[2] = {arg1, arg1};
|
||||
TF_OutputList* add_outputs = TF_NewOutputList();
|
||||
// Trace the operation now (create a node in the graph).
|
||||
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_DeleteAbstractOp(add_op);
|
||||
// Extract the resulting tensor.
|
||||
add_output2 = TF_OutputListGet(add_outputs, 0);
|
||||
TF_DeleteOutputList(add_outputs);
|
||||
}
|
||||
|
||||
// Finalize the function by providing the returned values.
|
||||
TF_AbstractFunction* func;
|
||||
{
|
||||
// We want to return the output of both add operations, create a new list
|
||||
// and populate it.
|
||||
TF_OutputList* func_outputs = TF_NewOutputList();
|
||||
TF_OutputListPushBack(func_outputs, add_output1, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_OutputListPushBack(func_outputs, add_output2, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
func = TF_FinalizeFunction(graph_ctx, func_outputs, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_DeleteOutputList(func_outputs);
|
||||
}
|
||||
|
||||
/**
|
||||
* We traced so far this function:
|
||||
*
|
||||
* def two_adds(a, b):
|
||||
* my_add1 = a + b
|
||||
* my_add2 = b + b
|
||||
* return my_add1, my_add2
|
||||
*
|
||||
* Now we will execute this function with an eager context:
|
||||
*
|
||||
* output1, output2 = two_adds(2.0, 3.0)
|
||||
*
|
||||
* and check that we got 5.0 and 6.0 as results.
|
||||
*/
|
||||
|
||||
// Build eager context.
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TF_ExecutionContext* eager_execution_ctx =
|
||||
TF_NewEagerExecutionContext(opts, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Build the abstract op to run the function.
|
||||
TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
|
||||
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Build two abstract input tensors as function arguments.
|
||||
std::vector<TF_AbstractTensor*> func_args;
|
||||
{
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(eager_execution_ctx);
|
||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
|
||||
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
input_eager = TestScalarTensorHandle(eager_ctx, 3.0f);
|
||||
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
}
|
||||
|
||||
TF_OutputList* func_outputs = TF_NewOutputList();
|
||||
TF_OutputListSetNumOutputs(func_outputs, 2, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs,
|
||||
eager_execution_ctx, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_DeleteAbstractOp(fn_op);
|
||||
for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t);
|
||||
|
||||
ASSERT_EQ(2, TF_OutputListNumOutputs(func_outputs));
|
||||
float results[2];
|
||||
for (int idx = 0; idx < 2; ++idx) {
|
||||
TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
|
||||
TFE_TensorHandle* handle = TF_AbstractTensorGetEagerTensor(result, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_Tensor* f_t = TFE_TensorHandleResolve(handle, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
results[idx] = *static_cast<float*>(TF_TensorData(f_t));
|
||||
TF_DeleteTensor(f_t);
|
||||
}
|
||||
ASSERT_EQ(results[0], 5.0);
|
||||
ASSERT_EQ(results[1], 6.0);
|
||||
|
||||
for (int idx = 0; idx < 2; ++idx) {
|
||||
TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
|
||||
TF_DeleteAbstractTensor(result);
|
||||
}
|
||||
TF_DeleteOutputList(func_outputs);
|
||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||
TF_DeleteAbstractFunction(func);
|
||||
}
|
||||
|
||||
TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -192,18 +322,15 @@ TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
|
||||
ctx, nullptr, 0, nullptr, 0, nullptr, status.get());
|
||||
TF_AbstractFunction* func = TF_FinalizeFunction(ctx, nullptr, status.get());
|
||||
ASSERT_EQ(nullptr, func);
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
|
||||
TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
@ -221,10 +348,10 @@ TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
|
||||
TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
@ -242,7 +369,7 @@ TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
TEST_P(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
// Build an Eager context.
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
@ -272,7 +399,8 @@ TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build a Graph context.
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Execute eager op using graph context.
|
||||
@ -288,10 +416,11 @@ TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||
TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
@ -348,5 +477,8 @@ TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
|
||||
::testing::Values("graphdef", "mlir"));
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -59,6 +59,20 @@ class AbstractContextInterface {
|
||||
virtual AbstractTensorInterface* CreateTensor(
|
||||
DataType dtype, absl::Span<const int64> dim_sizes) = 0;
|
||||
|
||||
typedef void (*MemoryReleaser)(void* data, size_t len, void* arg);
|
||||
|
||||
// Create a tensor instance from the given data buffer and description.
|
||||
// `memory_releaser` will be called on destruction, and it's responsible for
|
||||
// cleaning up the underlying buffer. `convert_string` indicates whether it
|
||||
// has to handle tstring conversion. Expected to be removed once tstring
|
||||
// migration is done.
|
||||
virtual AbstractTensorInterface* CreateTensor(DataType dtype,
|
||||
const int64_t* dims,
|
||||
int num_dims, void* data,
|
||||
size_t len, bool convert_string,
|
||||
MemoryReleaser memory_releaser,
|
||||
void* memory_releaser_arg) = 0;
|
||||
|
||||
// Create a handle to wrap and manage a Tensor
|
||||
virtual AbstractTensorHandleInterface* CreateLocalHandle(
|
||||
AbstractTensorInterface* t) = 0;
|
||||
@ -87,6 +101,9 @@ class AbstractContextInterface {
|
||||
// Destroy the step resource container for a training step.
|
||||
virtual void EndStep() = 0;
|
||||
|
||||
// Block until all pending nodes are finished.
|
||||
virtual Status AsyncWait() = 0;
|
||||
|
||||
protected:
|
||||
virtual ~AbstractContextInterface() {}
|
||||
};
|
||||
|
@ -27,6 +27,7 @@ cc_library(
|
||||
name = "parallel_device",
|
||||
srcs = [":sources"],
|
||||
hdrs = [":headers"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
@ -38,11 +39,30 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "parallel_device_testlib",
|
||||
testonly = 1,
|
||||
srcs = ["parallel_device_testlib.cc"],
|
||||
hdrs = ["parallel_device_testlib.h"],
|
||||
deps = [
|
||||
":parallel_device",
|
||||
":parallel_device_ops",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_experimental",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "parallel_device_test",
|
||||
srcs = ["parallel_device_test.cc"],
|
||||
deps = [
|
||||
":parallel_device",
|
||||
":parallel_device_ops",
|
||||
":parallel_device_testlib",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_experimental",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
@ -52,3 +72,40 @@ tf_cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "parallel_device_remote_test",
|
||||
srcs = ["parallel_device_remote_test.cc"],
|
||||
# TODO(b/136478427): Enable global heap checking when servers shut down
|
||||
# cleanly.
|
||||
args = ["--heap_check=local"],
|
||||
deps = [
|
||||
":parallel_device",
|
||||
":parallel_device_ops",
|
||||
":parallel_device_testlib",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_experimental",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
],
|
||||
)
|
||||
|
||||
# Note: ParallelDevice-specific ops are experimental and not currently linked in
|
||||
# to TensorFlow by default, just used in a few tests.
|
||||
filegroup(
|
||||
name = "parallel_device_ops_srcs",
|
||||
srcs = ["parallel_device_ops.cc"],
|
||||
visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "parallel_device_ops",
|
||||
srcs = [":parallel_device_ops_srcs"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = ["//tensorflow/core:framework"],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -92,6 +92,10 @@ class ParallelDevice {
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status) const;
|
||||
|
||||
// A parallel tensor with scalar integers numbering component devices.
|
||||
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
|
||||
TF_Status* status) const;
|
||||
|
||||
// Takes a description of a single operation being executed on the
|
||||
// ParallelDevice, and in turn runs one operation per component device with
|
||||
// its corresponding inputs from the input ParallelTensors (or
|
||||
@ -208,6 +212,46 @@ std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
|
||||
status);
|
||||
}
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
|
||||
TFE_Context* context, TF_Status* status) const {
|
||||
// TODO(allenl): We could cache DeviceIDs (keyed by context).
|
||||
std::vector<TensorHandlePtr> components;
|
||||
components.reserve(underlying_devices_.size());
|
||||
for (int device_index = 0; device_index < underlying_devices_.size();
|
||||
++device_index) {
|
||||
int64_t* device_id = new int64_t;
|
||||
*device_id = device_index;
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||
TF_NewTensor(
|
||||
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
|
||||
sizeof(int64_t),
|
||||
[](void* data, size_t, void* arg) {
|
||||
delete reinterpret_cast<int64_t*>(data);
|
||||
},
|
||||
nullptr),
|
||||
TF_DeleteTensor);
|
||||
// TODO(allenl): Here and when executing regular operations, we could hold
|
||||
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
|
||||
// device names repeatedly.
|
||||
OpPtr const_op(TFE_NewOp(context, "Const", status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
|
||||
status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
|
||||
TFE_TensorHandle* device_handle;
|
||||
int num_outputs = 1;
|
||||
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
components.emplace_back(device_handle);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
return ParallelTensor::FromTensorHandles(*this, std::move(components),
|
||||
status);
|
||||
}
|
||||
|
||||
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
|
||||
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
@ -275,6 +319,11 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
|
||||
std::vector<MaybeParallelTensorOwned> outputs;
|
||||
outputs.reserve(t->num_tensors());
|
||||
for (int i = 0; i < t->num_tensors(); ++i) {
|
||||
// TODO(b/157523095): Syncing the executor here shouldn't be
|
||||
// necessary. Currently async+remote is missing cross-executor
|
||||
// coordination.
|
||||
TFE_ExecutorWaitForAllPendingNodes(executors_[i].get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
TensorHandlePtr this_output(
|
||||
TFE_TensorHandleCopySharingTensor(t->tensor(i), status));
|
||||
outputs.emplace_back(std::move(this_output));
|
||||
@ -282,6 +331,13 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
|
||||
}
|
||||
result.emplace(std::move(outputs));
|
||||
return result;
|
||||
} else if (operation_name == std::string("DeviceID")) {
|
||||
std::vector<MaybeParallelTensorOwned> result_content;
|
||||
result_content.reserve(1);
|
||||
result_content.push_back(DeviceIDs(context, status));
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
result.emplace(std::move(result_content));
|
||||
return result;
|
||||
}
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
|
||||
maybe_parallel_results(
|
||||
|
26
tensorflow/c/eager/parallel_device/parallel_device_ops.cc
Normal file
26
tensorflow/c/eager/parallel_device/parallel_device_ops.cc
Normal file
@ -0,0 +1,26 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
||||
// TODO(allenl): Figure out if we need this op, and if so whether we should move
|
||||
// it to core TF. Right now the eager C API does some checking of op
|
||||
// registrations before calling into custom devices, but we may be able to avoid
|
||||
// that.
|
||||
REGISTER_OP("DeviceID")
|
||||
.Output("device_id: int64")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(tensorflow::shape_inference::ScalarShape);
|
@ -0,0 +1,147 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
|
||||
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
tensorflow::ServerDef GetServerDef(const std::string& job_name, int num_tasks) {
|
||||
tensorflow::ServerDef server_def;
|
||||
server_def.set_protocol("grpc");
|
||||
server_def.set_job_name(job_name);
|
||||
server_def.set_task_index(0);
|
||||
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
||||
tensorflow::JobDef* job_def = cluster_def->add_job();
|
||||
job_def->set_name(job_name);
|
||||
for (int i = 0; i < num_tasks; i++) {
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{i, tensorflow::strings::StrCat("localhost", ":", port)});
|
||||
}
|
||||
return server_def;
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestRemoteBasic) {
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
tensorflow::ServerDef server_def = GetServerDef("worker", 3);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
std::string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server1->Start().ok());
|
||||
|
||||
server_def.set_task_index(2);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server2->Start().ok());
|
||||
|
||||
TFE_ContextSetServerDef(context.get(), 0, serialized.data(),
|
||||
serialized.size(), status.get());
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
BasicTestsForTwoDevices(context.get(),
|
||||
"/job:worker/replica:0/task:1/device:CPU:0",
|
||||
"/job:worker/replica:0/task:2/device:CPU:0");
|
||||
|
||||
worker_server1.release();
|
||||
worker_server2.release();
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestAsyncCopyOff) {
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
tensorflow::ServerDef server_def = GetServerDef("worker", 3);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
std::string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server1->Start().ok());
|
||||
|
||||
server_def.set_task_index(2);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server2->Start().ok());
|
||||
|
||||
TFE_ContextSetServerDef(context.get(), 0, serialized.data(),
|
||||
serialized.size(), status.get());
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
const char* first_device = "/job:worker/replica:0/task:1/device:CPU:0";
|
||||
const char* second_device = "/job:worker/replica:0/task:2/device:CPU:0";
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::array<const char*, 2> underlying_devices{first_device, second_device};
|
||||
RegisterParallelDevice(context.get(), device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::array<TFE_TensorHandle*, 2> in_components{value_one.get(),
|
||||
value_two.get()};
|
||||
TensorHandlePtr combined_value = CreatePerDeviceValues(
|
||||
context.get(), in_components, device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Loop to make synchronization failures more deterministic
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
TensorHandlePtr multiply_result(
|
||||
Multiply(context.get(), combined_value.get(), combined_value.get(),
|
||||
status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::array<TensorHandlePtr, 2> out_components;
|
||||
ExtractPerDeviceValues(context.get(), multiply_result.get(),
|
||||
&out_components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
ExpectScalarEq<float>(out_components[0].get(), 9.);
|
||||
ExpectScalarEq<float>(out_components[1].get(), 4.);
|
||||
}
|
||||
|
||||
worker_server1.release();
|
||||
worker_server2.release();
|
||||
}
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
// NOTE(allenl): These tests currently go through TFE_Execute and so are
|
||||
@ -28,363 +29,6 @@ limitations under the License.
|
||||
// correspond fairly well to the implementation, but testing the C++ directly is
|
||||
// another option.
|
||||
|
||||
// Functor for making unique_ptr to TFE_TensorHandle slightly more
|
||||
// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
|
||||
// template argument requires passing a function pointer to
|
||||
// TFE_DeleteTensorHandle when constructing the unique_ptr.
|
||||
class TensorHandleDeleter {
|
||||
public:
|
||||
void operator()(TFE_TensorHandle* to_delete) {
|
||||
TFE_DeleteTensorHandle(to_delete);
|
||||
}
|
||||
};
|
||||
|
||||
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
|
||||
|
||||
// A helper for performing common operations on variables. A much more
|
||||
// restricted stand-in for tf.Variable in Python.
|
||||
class Variable {
|
||||
public:
|
||||
// Construct a Variable from a resource-dtype TFE_TensorHandle and an
|
||||
// indication of the dtype of the variable's value.
|
||||
//
|
||||
// Note that creating this resource-dtype handle can fail, so `Create` is a
|
||||
// separate static method which returns a status.
|
||||
Variable(TFE_TensorHandle* handle, TF_DataType type)
|
||||
: handle_(handle), type_(type) {}
|
||||
|
||||
// Helper for constructing a resource handle and wrapping it in a `Variable`
|
||||
// object.
|
||||
static Variable* Create(TFE_Context* context, TF_DataType type,
|
||||
const int64_t* dims, const int num_dims,
|
||||
const char* device, TF_Status* status);
|
||||
// Dereferences the backing buffer for the variable. Note that since this can
|
||||
// fail (it runs operations), it must be called explicitly and the resulting
|
||||
// `status` checked.
|
||||
void Destroy(TFE_Context* context, TF_Status* status);
|
||||
|
||||
// Reads from the variable.
|
||||
TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
|
||||
// Assigns a new value to the variable.
|
||||
void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
|
||||
// Adds `value` to the existing value of the variable.
|
||||
void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
|
||||
TF_Status* status);
|
||||
|
||||
private:
|
||||
// Helper for running any single-argument assignment ops (Assign, AssignAdd,
|
||||
// AssignSub, ...).
|
||||
void GeneralAssignment(const char* op_name, TFE_Context* context,
|
||||
TFE_TensorHandle* value, TF_Status* status);
|
||||
|
||||
// The a handle for the resource-dtype tensor pointing to the variable's
|
||||
// buffer.
|
||||
TFE_TensorHandle* handle_;
|
||||
// The dtype of the variable's buffer (input dtype for assignments, output
|
||||
// dtype of read operations).
|
||||
TF_DataType type_;
|
||||
};
|
||||
|
||||
Variable* Variable::Create(TFE_Context* context, TF_DataType type,
|
||||
const int64_t* dims, const int num_dims,
|
||||
const char* device, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op.get(), "dtype", type);
|
||||
TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
|
||||
TFE_OpSetAttrString(op.get(), "container", "", 0);
|
||||
// Use the special GUID for no buffer sharing
|
||||
//
|
||||
// TODO(allenl): Should we provide a better API for this? AFAIK this is the
|
||||
// only reasonable way to make variables with no aliasing using the eager C
|
||||
// API.
|
||||
std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
|
||||
TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
|
||||
no_sharing.length());
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &var_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return new Variable(var_handle, type);
|
||||
}
|
||||
|
||||
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
|
||||
// Free the backing buffer for the variable.
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpAddInput(op.get(), handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
// Delete the variable handle itself.
|
||||
TFE_DeleteTensorHandle(handle_);
|
||||
}
|
||||
|
||||
TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpAddInput(op.get(), handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op.get(), "dtype", type_);
|
||||
int num_retvals = 1;
|
||||
TFE_TensorHandle* var_value = nullptr;
|
||||
TFE_Execute(op.get(), &var_value, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return TensorHandlePtr(var_value);
|
||||
}
|
||||
|
||||
void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
|
||||
TFE_TensorHandle* value, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetAttrType(op.get(), "dtype", type_);
|
||||
TFE_OpAddInput(op.get(), handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpAddInput(op.get(), value, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
}
|
||||
|
||||
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
|
||||
TF_Status* status) {
|
||||
GeneralAssignment("AssignAddVariableOp", context, value, status);
|
||||
}
|
||||
|
||||
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
|
||||
TF_Status* status) {
|
||||
GeneralAssignment("AssignVariableOp", context, value, status);
|
||||
}
|
||||
|
||||
// Passed to `TF_NewTensor` to indicate how an array of floats should be
|
||||
// deleted.
|
||||
static void FloatDeallocator(void* data, size_t, void* arg) {
|
||||
delete[] static_cast<float*>(data);
|
||||
}
|
||||
|
||||
// Creates a TFE_TensorHandle with value `v`.
|
||||
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
|
||||
const int num_bytes = sizeof(float);
|
||||
float* values = new float[1];
|
||||
values[0] = v;
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||
TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
|
||||
nullptr),
|
||||
TF_DeleteTensor);
|
||||
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
|
||||
}
|
||||
|
||||
// Creates a rank-one TFE_TensorHandle with value `v`.
|
||||
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
|
||||
TF_Status* status) {
|
||||
const int num_bytes = v.size() * sizeof(float);
|
||||
float* values = new float[v.size()];
|
||||
memcpy(values, v.data(), num_bytes);
|
||||
int64_t dims = v.size();
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||
TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
|
||||
&FloatDeallocator, nullptr),
|
||||
TF_DeleteTensor);
|
||||
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
|
||||
}
|
||||
|
||||
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
|
||||
template <std::size_t num_replicas>
|
||||
void ExtractPerDeviceValues(
|
||||
TFE_Context* context, TFE_TensorHandle* input,
|
||||
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
|
||||
TFE_OpAddInput(op.get(), input, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* device = TFE_TensorHandleDeviceName(input, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
TFE_TensorHandle* result_handles[num_replicas];
|
||||
int num_retvals = num_replicas;
|
||||
TFE_Execute(op.get(), result_handles, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
for (int i = 0; i < num_replicas; ++i) {
|
||||
(*components)[i].reset(result_handles[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
|
||||
template <std::size_t num_replicas>
|
||||
TensorHandlePtr CreatePerDeviceValues(
|
||||
TFE_Context* context,
|
||||
const std::array<TFE_TensorHandle*, num_replicas>& components,
|
||||
const char* device, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrInt(op.get(), "N", num_replicas);
|
||||
for (int i = 0; i < num_replicas; ++i) {
|
||||
TFE_OpAddInput(op.get(), components[i], status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
TFE_TensorHandle* result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return TensorHandlePtr(result_handle);
|
||||
}
|
||||
|
||||
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
|
||||
TFE_TensorHandle* second, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpAddInput(op.get(), first, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpAddInput(op.get(), second, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
const char* first_device = TFE_TensorHandleDeviceName(first, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetDevice(op.get(), first_device, status);
|
||||
|
||||
TFE_TensorHandle* result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return TensorHandlePtr(result_handle);
|
||||
}
|
||||
|
||||
// Assert that `handle` is equal to `expected_value`.
|
||||
void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
|
||||
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_EQ(expected_value,
|
||||
*static_cast<float*>(TF_TensorData(value_zero.get())));
|
||||
}
|
||||
|
||||
template <std::size_t num_devices>
|
||||
void RegisterParallelDevice(
|
||||
TFE_Context* context, const char* device_name,
|
||||
const std::array<const char*, num_devices>& underlying_devices,
|
||||
TF_Status* status) {
|
||||
TFE_CustomDevice device;
|
||||
void* device_info;
|
||||
tensorflow::eager::AllocateParallelDevice(
|
||||
device_name, underlying_devices.data(), underlying_devices.size(),
|
||||
&device, &device_info);
|
||||
TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
|
||||
}
|
||||
|
||||
// Create and modify a variable placed on a parallel device which composes
|
||||
// `first_device` and `second_device`.
|
||||
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
const char* second_device) {
|
||||
// Register the custom device
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::array<const char*, 2> underlying_devices{first_device, second_device};
|
||||
RegisterParallelDevice(context, device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a variable handle (uninitialized to start) placed on the parallel
|
||||
// device.
|
||||
std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
|
||||
to_delete->Destroy(context, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
delete to_delete;
|
||||
};
|
||||
std::unique_ptr<Variable, decltype(variable_deleter)> variable(
|
||||
Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
|
||||
status.get()),
|
||||
variable_deleter);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Assign an initial value to the variable, implicitly mirroring it to each
|
||||
// component device.
|
||||
{
|
||||
TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
variable->Assign(context, initial_value.get(), status.get());
|
||||
}
|
||||
|
||||
// Read from the variable and verify that we have a parallel tensor.
|
||||
{
|
||||
TensorHandlePtr read = variable->Read(context, status.get());
|
||||
std::array<TensorHandlePtr, 2> components;
|
||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
AssertScalarFloatEq(components[0].get(), 20.);
|
||||
AssertScalarFloatEq(components[1].get(), 20.);
|
||||
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
|
||||
// Add a parallel tensor with different values on each device to the variable.
|
||||
{
|
||||
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
|
||||
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
|
||||
std::array<TFE_TensorHandle*, 2> components{value_one.get(),
|
||||
value_two.get()};
|
||||
TensorHandlePtr combined_value =
|
||||
CreatePerDeviceValues(context, components, device_name, status.get());
|
||||
variable->AssignAdd(context, combined_value.get(), status.get());
|
||||
}
|
||||
|
||||
// Read the variable and verify that each component has the right modified
|
||||
// value.
|
||||
{
|
||||
TensorHandlePtr read = variable->Read(context, status.get());
|
||||
std::array<TensorHandlePtr, 2> components;
|
||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
AssertScalarFloatEq(components[0].get(), 23.);
|
||||
AssertScalarFloatEq(components[1].get(), 18.);
|
||||
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestBasicCPU) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
@ -498,8 +142,8 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// The value of the original tensor is replicated on each device.
|
||||
AssertScalarFloatEq(components[0].get(), 3.);
|
||||
AssertScalarFloatEq(components[1].get(), 3.);
|
||||
ExpectScalarEq<float>(components[0].get(), 3.);
|
||||
ExpectScalarEq<float>(components[1].get(), 3.);
|
||||
|
||||
// Verify that the mirrors are placed on the component devices.
|
||||
std::string first_device =
|
||||
@ -630,7 +274,7 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
|
||||
&second_components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
AssertScalarFloatEq(second_components[1].get(), 9.);
|
||||
ExpectScalarEq<float>(second_components[1].get(), 9.);
|
||||
|
||||
// Verify that the mirrors are placed on the component devices.
|
||||
std::string first_device = TFE_TensorHandleBackingDeviceName(
|
||||
@ -644,8 +288,8 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
|
||||
std::array<TensorHandlePtr, 2> first_components;
|
||||
ExtractPerDeviceValues(context.get(), second_components[0].get(),
|
||||
&first_components, status.get());
|
||||
AssertScalarFloatEq(first_components[0].get(), 3.);
|
||||
AssertScalarFloatEq(first_components[1].get(), 6.);
|
||||
ExpectScalarEq<float>(first_components[0].get(), 3.);
|
||||
ExpectScalarEq<float>(first_components[1].get(), 6.);
|
||||
|
||||
first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
|
||||
status.get());
|
||||
@ -806,8 +450,8 @@ TEST(PARALLEL_DEVICE, TestCollective) {
|
||||
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
AssertScalarFloatEq(result_components[0].get(), 3.);
|
||||
AssertScalarFloatEq(result_components[1].get(), 3.);
|
||||
ExpectScalarEq<float>(result_components[0].get(), 3.);
|
||||
ExpectScalarEq<float>(result_components[1].get(), 3.);
|
||||
}
|
||||
|
||||
void RegisterCollectiveMulFunction(TFE_Context* context,
|
||||
@ -909,8 +553,8 @@ TEST(PARALLEL_DEVICE, TestFunction) {
|
||||
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
AssertScalarFloatEq(result_components[0].get(), 7. * 9.);
|
||||
AssertScalarFloatEq(result_components[1].get(), 7. * 9.);
|
||||
ExpectScalarEq<float>(result_components[0].get(), 7. * 9.);
|
||||
ExpectScalarEq<float>(result_components[1].get(), 7. * 9.);
|
||||
|
||||
std::string first_device = TFE_TensorHandleBackingDeviceName(
|
||||
result_components[0].get(), status.get());
|
||||
|
308
tensorflow/c/eager/parallel_device/parallel_device_testlib.cc
Normal file
308
tensorflow/c/eager/parallel_device/parallel_device_testlib.cc
Normal file
@ -0,0 +1,308 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
// NOTE(allenl): These tests currently go through TFE_Execute and so are
|
||||
// integration testing rather than purely testing the parallel device. They
|
||||
// correspond fairly well to the implementation, but testing the C++ directly is
|
||||
// another option.
|
||||
|
||||
|
||||
Variable* Variable::Create(TFE_Context* context, TF_DataType type,
|
||||
const int64_t* dims, const int num_dims,
|
||||
const char* device, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op.get(), "dtype", type);
|
||||
TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
|
||||
TFE_OpSetAttrString(op.get(), "container", "", 0);
|
||||
// Use the special GUID for no buffer sharing
|
||||
//
|
||||
// TODO(allenl): Should we provide a better API for this? AFAIK this is the
|
||||
// only reasonable way to make variables with no aliasing using the eager C
|
||||
// API.
|
||||
std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
|
||||
TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
|
||||
no_sharing.length());
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &var_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return new Variable(var_handle, type);
|
||||
}
|
||||
|
||||
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
|
||||
// Free the backing buffer for the variable.
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpAddInput(op.get(), handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
// Delete the variable handle itself.
|
||||
TFE_DeleteTensorHandle(handle_);
|
||||
}
|
||||
|
||||
TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpAddInput(op.get(), handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op.get(), "dtype", type_);
|
||||
int num_retvals = 1;
|
||||
TFE_TensorHandle* var_value = nullptr;
|
||||
TFE_Execute(op.get(), &var_value, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return TensorHandlePtr(var_value);
|
||||
}
|
||||
|
||||
void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
|
||||
TFE_TensorHandle* value, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetAttrType(op.get(), "dtype", type_);
|
||||
TFE_OpAddInput(op.get(), handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpAddInput(op.get(), value, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
}
|
||||
|
||||
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
|
||||
TF_Status* status) {
|
||||
GeneralAssignment("AssignAddVariableOp", context, value, status);
|
||||
}
|
||||
|
||||
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
|
||||
TF_Status* status) {
|
||||
GeneralAssignment("AssignVariableOp", context, value, status);
|
||||
}
|
||||
|
||||
// Passed to `TF_NewTensor` to indicate how an array of floats should be
|
||||
// deleted.
|
||||
static void FloatDeallocator(void* data, size_t, void* arg) {
|
||||
delete[] static_cast<float*>(data);
|
||||
}
|
||||
|
||||
// Creates a TFE_TensorHandle with value `v`.
|
||||
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
|
||||
const int num_bytes = sizeof(float);
|
||||
float* values = new float[1];
|
||||
values[0] = v;
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||
TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
|
||||
nullptr),
|
||||
TF_DeleteTensor);
|
||||
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
|
||||
}
|
||||
|
||||
// Creates a rank-one TFE_TensorHandle with value `v`.
|
||||
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
|
||||
TF_Status* status) {
|
||||
const int num_bytes = v.size() * sizeof(float);
|
||||
float* values = new float[v.size()];
|
||||
memcpy(values, v.data(), num_bytes);
|
||||
int64_t dims = v.size();
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||
TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
|
||||
&FloatDeallocator, nullptr),
|
||||
TF_DeleteTensor);
|
||||
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
|
||||
}
|
||||
|
||||
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
|
||||
template <std::size_t num_replicas>
|
||||
void ExtractPerDeviceValues(
|
||||
TFE_Context* context, TFE_TensorHandle* input,
|
||||
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
|
||||
TFE_OpAddInput(op.get(), input, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* device = TFE_TensorHandleDeviceName(input, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
TFE_TensorHandle* result_handles[num_replicas];
|
||||
int num_retvals = num_replicas;
|
||||
TFE_Execute(op.get(), result_handles, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
for (int i = 0; i < num_replicas; ++i) {
|
||||
(*components)[i].reset(result_handles[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
|
||||
TFE_TensorHandle* second, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpAddInput(op.get(), first, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpAddInput(op.get(), second, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
const char* first_device = TFE_TensorHandleDeviceName(first, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetDevice(op.get(), first_device, status);
|
||||
|
||||
TFE_TensorHandle* result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return TensorHandlePtr(result_handle);
|
||||
}
|
||||
|
||||
// Create and modify a variable placed on a parallel device which composes
|
||||
// `first_device` and `second_device`.
|
||||
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
const char* second_device) {
|
||||
// Register the custom device
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::array<const char*, 2> underlying_devices{first_device, second_device};
|
||||
RegisterParallelDevice(context, device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a variable handle (uninitialized to start) placed on the parallel
|
||||
// device.
|
||||
std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
|
||||
to_delete->Destroy(context, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
delete to_delete;
|
||||
};
|
||||
std::unique_ptr<Variable, decltype(variable_deleter)> variable(
|
||||
Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
|
||||
status.get()),
|
||||
variable_deleter);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Assign an initial value to the variable, implicitly mirroring it to each
|
||||
// component device.
|
||||
{
|
||||
TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
variable->Assign(context, initial_value.get(), status.get());
|
||||
}
|
||||
|
||||
// Read from the variable and verify that we have a parallel tensor.
|
||||
{
|
||||
TensorHandlePtr read = variable->Read(context, status.get());
|
||||
std::array<TensorHandlePtr, 2> components;
|
||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
ExpectScalarEq<float>(components[0].get(), 20.);
|
||||
ExpectScalarEq<float>(components[1].get(), 20.);
|
||||
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
|
||||
// Add a parallel tensor with different values on each device to the variable.
|
||||
{
|
||||
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
|
||||
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
|
||||
std::array<TFE_TensorHandle*, 2> components{value_one.get(),
|
||||
value_two.get()};
|
||||
TensorHandlePtr combined_value =
|
||||
CreatePerDeviceValues(context, components, device_name, status.get());
|
||||
variable->AssignAdd(context, combined_value.get(), status.get());
|
||||
}
|
||||
|
||||
// Read the variable and verify that each component has the right modified
|
||||
// value.
|
||||
{
|
||||
TensorHandlePtr read = variable->Read(context, status.get());
|
||||
std::array<TensorHandlePtr, 2> components;
|
||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
ExpectScalarEq<float>(components[0].get(), 23.);
|
||||
ExpectScalarEq<float>(components[1].get(), 18.);
|
||||
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
// Compute the device ID twice and verify the result
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetDevice(op.get(), device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::array<TensorHandlePtr, 2> components;
|
||||
ExtractPerDeviceValues(context, result_handle, &components, status.get());
|
||||
TFE_DeleteTensorHandle(result_handle);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
ExpectScalarEq<int64_t>(components[0].get(), 0);
|
||||
ExpectScalarEq<int64_t>(components[1].get(), 1);
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
}
|
174
tensorflow/c/eager/parallel_device/parallel_device_testlib.h
Normal file
174
tensorflow/c/eager/parallel_device/parallel_device_testlib.h
Normal file
@ -0,0 +1,174 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
|
||||
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
|
||||
|
||||
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
|
||||
// Functor for making unique_ptr to TFE_TensorHandle slightly more
|
||||
// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
|
||||
// template argument requires passing a function pointer to
|
||||
// TFE_DeleteTensorHandle when constructing the unique_ptr.
|
||||
class TensorHandleDeleter {
|
||||
public:
|
||||
void operator()(TFE_TensorHandle* to_delete) {
|
||||
TFE_DeleteTensorHandle(to_delete);
|
||||
}
|
||||
};
|
||||
|
||||
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
|
||||
|
||||
// A helper for performing common operations on variables. A much more
|
||||
// restricted stand-in for tf.Variable in Python.
|
||||
class Variable {
|
||||
public:
|
||||
// Construct a Variable from a resource-dtype TFE_TensorHandle and an
|
||||
// indication of the dtype of the variable's value.
|
||||
//
|
||||
// Note that creating this resource-dtype handle can fail, so `Create` is a
|
||||
// separate static method which returns a status.
|
||||
Variable(TFE_TensorHandle* handle, TF_DataType type)
|
||||
: handle_(handle), type_(type) {}
|
||||
|
||||
// Helper for constructing a resource handle and wrapping it in a `Variable`
|
||||
// object.
|
||||
static Variable* Create(TFE_Context* context, TF_DataType type,
|
||||
const int64_t* dims, const int num_dims,
|
||||
const char* device, TF_Status* status);
|
||||
// Dereferences the backing buffer for the variable. Note that since this can
|
||||
// fail (it runs operations), it must be called explicitly and the resulting
|
||||
// `status` checked.
|
||||
void Destroy(TFE_Context* context, TF_Status* status);
|
||||
|
||||
// Reads from the variable.
|
||||
TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
|
||||
// Assigns a new value to the variable.
|
||||
void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
|
||||
// Adds `value` to the existing value of the variable.
|
||||
void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
|
||||
TF_Status* status);
|
||||
|
||||
private:
|
||||
// Helper for running any single-argument assignment ops (Assign, AssignAdd,
|
||||
// AssignSub, ...).
|
||||
void GeneralAssignment(const char* op_name, TFE_Context* context,
|
||||
TFE_TensorHandle* value, TF_Status* status);
|
||||
|
||||
// The a handle for the resource-dtype tensor pointing to the variable's
|
||||
// buffer.
|
||||
TFE_TensorHandle* handle_;
|
||||
// The dtype of the variable's buffer (input dtype for assignments, output
|
||||
// dtype of read operations).
|
||||
TF_DataType type_;
|
||||
};
|
||||
|
||||
// Creates a TFE_TensorHandle with value `v`.
|
||||
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status);
|
||||
|
||||
// Creates a rank-one TFE_TensorHandle with value `v`.
|
||||
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
|
||||
TF_Status* status);
|
||||
|
||||
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
|
||||
template <std::size_t num_replicas>
|
||||
void ExtractPerDeviceValues(
|
||||
TFE_Context* context, TFE_TensorHandle* input,
|
||||
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status);
|
||||
|
||||
// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
|
||||
template <std::size_t num_replicas>
|
||||
TensorHandlePtr CreatePerDeviceValues(
|
||||
TFE_Context* context,
|
||||
const std::array<TFE_TensorHandle*, num_replicas>& components,
|
||||
const char* device, TF_Status* status);
|
||||
|
||||
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
|
||||
TFE_TensorHandle* second, TF_Status* status);
|
||||
|
||||
// Assert that `handle` is equal to `expected_value`.
|
||||
template <typename value_type>
|
||||
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value);
|
||||
|
||||
template <std::size_t num_devices>
|
||||
void RegisterParallelDevice(
|
||||
TFE_Context* context, const char* device_name,
|
||||
const std::array<const char*, num_devices>& underlying_devices,
|
||||
TF_Status* status);
|
||||
|
||||
// Create and modify a variable placed on a parallel device which composes
|
||||
// `first_device` and `second_device`.
|
||||
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
const char* second_device);
|
||||
|
||||
// Implementations of templated functions ******************************
|
||||
|
||||
template <std::size_t num_replicas>
|
||||
TensorHandlePtr CreatePerDeviceValues(
|
||||
TFE_Context* context,
|
||||
const std::array<TFE_TensorHandle*, num_replicas>& components,
|
||||
const char* device, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrInt(op.get(), "N", num_replicas);
|
||||
for (int i = 0; i < num_replicas; ++i) {
|
||||
TFE_OpAddInput(op.get(), components[i], status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
TFE_TensorHandle* result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return TensorHandlePtr(result_handle);
|
||||
}
|
||||
|
||||
template <typename value_type>
|
||||
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
|
||||
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
EXPECT_EQ(expected_value,
|
||||
*static_cast<value_type*>(TF_TensorData(value_zero.get())));
|
||||
}
|
||||
|
||||
template <std::size_t num_devices>
|
||||
void RegisterParallelDevice(
|
||||
TFE_Context* context, const char* device_name,
|
||||
const std::array<const char*, num_devices>& underlying_devices,
|
||||
TF_Status* status) {
|
||||
TFE_CustomDevice device;
|
||||
void* device_info;
|
||||
tensorflow::eager::AllocateParallelDevice(
|
||||
device_name, underlying_devices.data(), underlying_devices.size(),
|
||||
&device, &device_info);
|
||||
TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
|
||||
}
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
|
30
tensorflow/c/experimental/filesystem/plugins/gcs/BUILD
Normal file
30
tensorflow/c/experimental/filesystem/plugins/gcs/BUILD
Normal file
@ -0,0 +1,30 @@
|
||||
# Experimental gcs filesystem plugin.
|
||||
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
# Filesystem implementation for GCS environments
|
||||
tf_cc_shared_object(
|
||||
name = "gcs_filesystem",
|
||||
framework_so = [],
|
||||
linkstatic = False,
|
||||
per_os_targets = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":gcs_filesystem_impl"],
|
||||
)
|
||||
|
||||
# The real implementation of the filesystem.
|
||||
cc_library(
|
||||
name = "gcs_filesystem_impl",
|
||||
srcs = ["gcs_filesystem.cc"],
|
||||
copts = select({
|
||||
"//conditions:default": [],
|
||||
"//tensorflow:windows": get_win_copts(),
|
||||
}),
|
||||
deps = [
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||
],
|
||||
)
|
@ -0,0 +1,72 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
// Implementation of a filesystem for GCS environments.
|
||||
// This filesystem will support `gs://` URI schemes.
|
||||
|
||||
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
|
||||
static void plugin_memory_free(void* ptr) { free(ptr); }
|
||||
|
||||
// SECTION 1. Implementation for `TF_RandomAccessFile`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_random_access_file {
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
|
||||
} // namespace tf_random_access_file
|
||||
|
||||
// SECTION 2. Implementation for `TF_WritableFile`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_writable_file {
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
|
||||
} // namespace tf_writable_file
|
||||
|
||||
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_read_only_memory_region {
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
|
||||
} // namespace tf_read_only_memory_region
|
||||
|
||||
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_gcs_filesystem {
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
|
||||
} // namespace tf_gcs_filesystem
|
||||
|
||||
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
|
||||
const char* uri) {
|
||||
TF_SetFilesystemVersionMetadata(ops);
|
||||
ops->scheme = strdup(uri);
|
||||
}
|
||||
|
||||
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
|
||||
info->plugin_memory_allocate = plugin_memory_allocate;
|
||||
info->plugin_memory_free = plugin_memory_free;
|
||||
info->num_schemes = 1;
|
||||
info->ops = static_cast<TF_FilesystemPluginOps*>(
|
||||
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
|
||||
ProvideFilesystemSupportFor(&info->ops[0], "gs");
|
||||
}
|
@ -108,7 +108,7 @@ class CServerFactory : public ServerFactory {
|
||||
delete_function_(delete_function),
|
||||
rendezvous_builder_(rendezvous_builder) {}
|
||||
|
||||
Status NewServer(const ServerDef& server_def,
|
||||
Status NewServer(const ServerDef& server_def, const Options& options,
|
||||
std::unique_ptr<ServerInterface>* out_server) override {
|
||||
TF_RETURN_IF_ERROR(CGrpcServer::Create(
|
||||
server_def, init_function_, start_function_, stop_function_,
|
||||
|
@ -31,9 +31,6 @@ cc_library(
|
||||
"//tensorflow/c/experimental/saved_model/public:concrete_function.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
# TODO(bmzhao): Remove this as we refactor C API to granular targets,
|
||||
# so that we can depend on c/eager/c_api_unified_experimental.h.
|
||||
features = ["-layering_check"],
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
@ -41,6 +38,8 @@ cc_library(
|
||||
":concrete_function_type",
|
||||
":function_metadata",
|
||||
":function_metadata_type",
|
||||
":tensorhandle_list",
|
||||
":tensorhandle_list_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_internal",
|
||||
@ -156,10 +155,43 @@ cc_library(
|
||||
"saved_model_api_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorhandle_list",
|
||||
srcs = [
|
||||
"tensorhandle_list.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//tensorflow/c/experimental/saved_model/public:tensorhandle_list.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":tensorhandle_list_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorhandle_list_type",
|
||||
hdrs = [
|
||||
"tensorhandle_list_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "saved_model_api_test",
|
||||
size = "small",
|
||||
|
@ -15,12 +15,12 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
@ -29,10 +29,9 @@ TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
|
||||
&tensorflow::unwrap(func)->GetFunctionMetadata()));
|
||||
}
|
||||
|
||||
TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) {
|
||||
// TODO(bmzhao): Refactor TF_OutputList struct definition into a separate
|
||||
// internal header, and implement this function.
|
||||
return nullptr;
|
||||
const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures(
|
||||
TF_ConcreteFunction* func) {
|
||||
return tensorflow::wrap(&tensorflow::unwrap(func)->GetCaptures());
|
||||
}
|
||||
|
||||
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) {
|
||||
|
@ -41,7 +41,7 @@ TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TF_SavedModel{std::move(result)};
|
||||
return tensorflow::wrap(result.release());
|
||||
}
|
||||
|
||||
TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
|
||||
@ -60,17 +60,19 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TF_SavedModel{std::move(result)};
|
||||
return tensorflow::wrap(result.release());
|
||||
}
|
||||
|
||||
void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; }
|
||||
void TF_DeleteSavedModel(TF_SavedModel* model) {
|
||||
delete tensorflow::unwrap(model);
|
||||
}
|
||||
|
||||
TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
|
||||
const char* function_path,
|
||||
TF_Status* status) {
|
||||
tensorflow::ConcreteFunction* result = nullptr;
|
||||
tensorflow::Status get_function_status =
|
||||
model->saved_model->GetFunction(function_path, &result);
|
||||
tensorflow::unwrap(model)->GetFunction(function_path, &result);
|
||||
status->status.Update(get_function_status);
|
||||
if (!get_function_status.ok()) {
|
||||
return nullptr;
|
||||
@ -82,7 +84,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
|
||||
TF_SavedModel* model, const char* signature_def_key, TF_Status* status) {
|
||||
tensorflow::ConcreteFunction* result = nullptr;
|
||||
tensorflow::Status get_function_status =
|
||||
model->saved_model->GetSignatureDefFunction(signature_def_key, &result);
|
||||
tensorflow::unwrap(model)->GetSignatureDefFunction(signature_def_key,
|
||||
&result);
|
||||
status->status.Update(get_function_status);
|
||||
if (!get_function_status.ok()) {
|
||||
return nullptr;
|
||||
@ -91,7 +94,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
|
||||
}
|
||||
|
||||
TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
|
||||
return new TF_ConcreteFunctionList{model->saved_model->ListFunctions()};
|
||||
return new TF_ConcreteFunctionList{
|
||||
tensorflow::unwrap(model)->ListFunctions()};
|
||||
}
|
||||
|
||||
} // end extern "C"
|
||||
|
@ -18,13 +18,18 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
|
||||
// Internal structures used by the SavedModel C API. These are likely to change
|
||||
// and should not be depended on.
|
||||
|
||||
struct TF_SavedModel {
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> saved_model;
|
||||
};
|
||||
typedef struct TF_SavedModel TF_SavedModel;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::SavedModelAPI, TF_SavedModel)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
|
||||
|
@ -0,0 +1,36 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
size_t TF_TensorHandleListSize(const TF_TensorHandleList* list) {
|
||||
return tensorflow::unwrap(list)->size();
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TF_TensorHandleListGet(const TF_TensorHandleList* list,
|
||||
int i) {
|
||||
return tensorflow::wrap((*tensorflow::unwrap(list))[i]);
|
||||
}
|
||||
|
||||
|
||||
} // end extern "C"
|
@ -0,0 +1,37 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
|
||||
// Internal structures used by the SavedModel C API. These are likely to
|
||||
// change and should not be depended on.
|
||||
|
||||
typedef struct TF_TensorHandleList TF_TensorHandleList;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(
|
||||
std::vector<tensorflow::AbstractTensorHandleInterface*>,
|
||||
TF_TensorHandleList)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
|
@ -24,6 +24,7 @@ exports_files(
|
||||
"concrete_function_list.h",
|
||||
"function_metadata.h",
|
||||
"saved_model_api.h",
|
||||
"tensorhandle_list.h",
|
||||
],
|
||||
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
|
||||
)
|
||||
@ -39,6 +40,7 @@ cc_library(
|
||||
":concrete_function_list",
|
||||
":function_metadata",
|
||||
":saved_model_api",
|
||||
":tensorhandle_list",
|
||||
],
|
||||
)
|
||||
|
||||
@ -61,3 +63,8 @@ alias(
|
||||
name = "saved_model_api",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "tensorhandle_list",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:tensorhandle_list",
|
||||
)
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
|
||||
|
@ -17,9 +17,9 @@ limitations under the License.
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
@ -36,7 +36,7 @@ TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
|
||||
TF_ConcreteFunction* func);
|
||||
|
||||
// Returns a list of TensorHandles implicitly captured by this function.
|
||||
TF_CAPI_EXPORT extern TF_OutputList* TF_ConcreteFunctionGetCaptures(
|
||||
TF_CAPI_EXPORT extern const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures(
|
||||
TF_ConcreteFunction* func);
|
||||
|
||||
// Returns a TFE_Op suitable for executing this function.
|
||||
|
@ -21,19 +21,27 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type that is acts like a list of TF_ConcreteFunction pointers.
|
||||
typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList;
|
||||
|
||||
// Returns the size of `list`.
|
||||
TF_CAPI_EXPORT size_t
|
||||
TF_ConcreteFunctionListSize(TF_ConcreteFunctionList* list);
|
||||
TF_CAPI_EXPORT extern size_t TF_ConcreteFunctionListSize(
|
||||
TF_ConcreteFunctionList* list);
|
||||
|
||||
// Returns the `i`th TF_ConcreteFunction in the list.
|
||||
TF_CAPI_EXPORT TF_ConcreteFunction* TF_ConcreteFunctionListGet(
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_ConcreteFunctionListGet(
|
||||
TF_ConcreteFunctionList* list, int i);
|
||||
|
||||
// Deletes `list`.
|
||||
TF_CAPI_EXPORT void TF_DeleteConcreteFunctionList(
|
||||
TF_CAPI_EXPORT extern void TF_DeleteConcreteFunctionList(
|
||||
TF_ConcreteFunctionList* list);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
||||
|
@ -0,0 +1,43 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type that is acts like a list of TF_ConcreteFunction pointers.
|
||||
typedef struct TF_TensorHandleList TF_TensorHandleList;
|
||||
|
||||
// Returns the size of `list`.
|
||||
TF_CAPI_EXPORT extern size_t TF_TensorHandleListSize(
|
||||
const TF_TensorHandleList* list);
|
||||
|
||||
// Returns the `i`th TFE_TensorHandle in the list.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TF_TensorHandleListGet(
|
||||
const TF_TensorHandleList* list, int i);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
|
@ -178,7 +178,7 @@ cc_library_with_android_deps(
|
||||
name = "ops",
|
||||
srcs = ["framework/ops.cc"],
|
||||
hdrs = ["framework/ops.h"],
|
||||
android_deps = ["//tensorflow/core:android_tensorflow_lib"],
|
||||
android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
@ -197,7 +197,7 @@ cc_library_with_android_deps(
|
||||
"framework/scope_internal.h",
|
||||
],
|
||||
hdrs = ["framework/scope.h"],
|
||||
android_deps = ["//tensorflow/core:android_tensorflow_lib"],
|
||||
android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
|
||||
common_deps = [
|
||||
":ops",
|
||||
],
|
||||
@ -237,7 +237,7 @@ cc_library_with_android_deps(
|
||||
name = "client_session",
|
||||
srcs = ["client/client_session.cc"],
|
||||
hdrs = ["client/client_session.h"],
|
||||
android_deps = ["//tensorflow/core:android_tensorflow_lib"],
|
||||
android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
|
||||
common_deps = [
|
||||
":ops",
|
||||
":scope",
|
||||
@ -275,7 +275,7 @@ cc_library_with_android_deps(
|
||||
srcs = ["ops/const_op.cc"],
|
||||
hdrs = ["ops/const_op.h"],
|
||||
android_deps = [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
"//tensorflow/core:portable_tensorflow_lib",
|
||||
],
|
||||
common_deps = [
|
||||
":ops",
|
||||
@ -304,7 +304,7 @@ cc_library_with_android_deps(
|
||||
srcs = ["ops/while_loop.cc"],
|
||||
hdrs = ["ops/while_loop.h"],
|
||||
android_deps = [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
"//tensorflow/core:portable_tensorflow_lib",
|
||||
],
|
||||
common_deps = [
|
||||
":cc_ops",
|
||||
|
@ -57,7 +57,22 @@ cc_library(
|
||||
"tensor.h",
|
||||
],
|
||||
deps = [
|
||||
":status",
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/c:tf_tensor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorhandle",
|
||||
hdrs = [
|
||||
"tensorhandle.h",
|
||||
],
|
||||
deps = [
|
||||
":runtime",
|
||||
":status",
|
||||
":tensor",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
],
|
||||
)
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// Runtime represents an opaque instance of a Tensorflow runtime, with its own
|
||||
@ -40,6 +41,7 @@ class Runtime {
|
||||
private:
|
||||
friend class RuntimeBuilder;
|
||||
friend class SavedModelAPI;
|
||||
friend class TensorHandle;
|
||||
|
||||
// Wraps a TFE_Context. Takes ownership of ctx.
|
||||
explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {}
|
||||
@ -63,6 +65,7 @@ class Runtime {
|
||||
};
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
|
||||
@ -79,6 +80,7 @@ inline std::unique_ptr<Runtime> RuntimeBuilder::Build(Status* status) {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// Status is a wrapper around an error code and an optional error message.
|
||||
@ -57,6 +58,7 @@ class Status {
|
||||
friend class RuntimeBuilder;
|
||||
friend class Runtime;
|
||||
friend class SavedModelAPI;
|
||||
friend class TensorHandle;
|
||||
|
||||
// Wraps a TF_Status*, and takes ownership of it.
|
||||
explicit Status(TF_Status* status) : status_(status) {}
|
||||
@ -88,6 +90,7 @@ inline void Status::SetStatus(TF_Code code, const std::string& msg) {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
|
||||
|
@ -19,30 +19,53 @@ limitations under the License.
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// Tensor represents an n-dimensional array of values.
|
||||
class Tensor {
|
||||
public:
|
||||
// TODO(bmzhao): Add a factory function that constructs a Tensor from a char
|
||||
// buffer, with an options struct (to specify the buffer's layout, device?,
|
||||
// whether to create a TFRT or TF tensor, whether we should take ownership of
|
||||
// the memory, etc). This requires extending TF_NewTensor with an options
|
||||
// struct:
|
||||
// https://github.com/tensorflow/tensorflow/blob/3c520614a3c056d56afdc79b59979b9b0087f8b9/tensorflow/c/tf_tensor.h#L77-L80
|
||||
using DeleterCallback = std::function<void(void*, size_t)>;
|
||||
|
||||
// Constructs a Tensor from user provided buffer.
|
||||
//
|
||||
// Params:
|
||||
// dtype - The dtype of the tensor's data.
|
||||
// shape - A shape vector, where each element corresponds to the size of
|
||||
// the tensor's corresponding dimension.
|
||||
// data - Pointer to a buffer of memory to construct a Tensor out of.
|
||||
// len - The length (in bytes) of `data`
|
||||
// deleter - A std::function to be called when the Tensor no longer needs the
|
||||
// memory in `data`. This can be used to free `data`, or
|
||||
// perhaps decrement a refcount associated with `data`, etc.
|
||||
// status - Set to OK on success and an error on failure.
|
||||
// Returns:
|
||||
// If an error occurred, status->ok() will be false, and the returned
|
||||
// Tensor must not be used.
|
||||
// TODO(bmzhao): Add Runtime as an argument to this function so we can swap to
|
||||
// a TFRT backed tensor.
|
||||
// TODO(bmzhao): Add benchmarks on overhead for this function; we can
|
||||
// consider using int64_t* + length rather than vector.
|
||||
static Tensor FromBuffer(TF_DataType dtype, const std::vector<int64_t>& shape,
|
||||
void* data, size_t len, DeleterCallback deleter,
|
||||
Status* status);
|
||||
|
||||
// TODO(bmzhao): In the case we construct a tensor from non-owned memory,
|
||||
// we should offer a way to deep copy the tensor into a new tensor, which
|
||||
// owns the underlying memory. This could be a .deepcopy()/clone() method.
|
||||
|
||||
// TODO(bmzhao): In the future, we want to relax the non-copyability
|
||||
// constraint. To do so, we can add a C API function that acts like CopyFrom:
|
||||
// constraint. To do so, we can add a C API function that acts like
|
||||
// CopyFrom:
|
||||
// https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/core/framework/tensor.h#L301-L311
|
||||
|
||||
// Tensor is movable, but not copyable
|
||||
@ -85,6 +108,16 @@ class Tensor {
|
||||
// This object retains ownership of the pointer.
|
||||
TF_Tensor* GetTFTensor() const { return tensor_.get(); }
|
||||
|
||||
struct DeleterStruct {
|
||||
std::function<void(void*, size_t)> deleter;
|
||||
};
|
||||
|
||||
static void DeleterFunction(void* memory, size_t len, void* deleter_struct) {
|
||||
DeleterStruct* deleter = reinterpret_cast<DeleterStruct*>(deleter_struct);
|
||||
deleter->deleter(memory, len);
|
||||
delete deleter;
|
||||
}
|
||||
|
||||
struct TFTensorDeleter {
|
||||
void operator()(TF_Tensor* p) const { TF_DeleteTensor(p); }
|
||||
};
|
||||
@ -111,7 +144,32 @@ inline size_t Tensor::num_bytes() const {
|
||||
return TF_TensorByteSize(tensor_.get());
|
||||
}
|
||||
|
||||
inline Tensor Tensor::FromBuffer(TF_DataType dtype,
|
||||
const std::vector<int64_t>& shape, void* data,
|
||||
size_t len, DeleterCallback deleter,
|
||||
Status* status) {
|
||||
// Credit to apassos@ for this technique:
|
||||
// Despite the fact that our API takes a std::function deleter, we are able
|
||||
// to maintain ABI stability because:
|
||||
// 1. Only a function pointer is sent across the C API (&DeleterFunction)
|
||||
// 2. DeleterFunction is defined in the same build artifact that constructed
|
||||
// the std::function (so there isn't confusion about std::function ABI).
|
||||
// Note that 2. is satisifed by the fact that this is a header-only API, where
|
||||
// the function implementations are inline.
|
||||
|
||||
DeleterStruct* deleter_struct = new DeleterStruct{deleter};
|
||||
TF_Tensor* tensor = TF_NewTensor(dtype, shape.data(), shape.size(), data, len,
|
||||
&DeleterFunction, deleter_struct);
|
||||
if (tensor == nullptr) {
|
||||
status->SetStatus(TF_INVALID_ARGUMENT,
|
||||
"Failed to create tensor for input buffer");
|
||||
return Tensor(nullptr);
|
||||
}
|
||||
return Tensor(tensor);
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
|
||||
|
98
tensorflow/cc/experimental/base/public/tensorhandle.h
Normal file
98
tensorflow/cc/experimental/base/public/tensorhandle.h
Normal file
@ -0,0 +1,98 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime.h"
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// An opaque representation of a tensor computed/managed by the Tensorflow
|
||||
// runtime (tensorflow:cc::Runtime). Unlike a tensor, a Tensorhandle may refer
|
||||
// to tensors placed in memory of different devices or remote address spaces.
|
||||
// Note that tensorflow::cc::Runtime MUST outlive all TensorHandles created
|
||||
// from it.
|
||||
class TensorHandle {
|
||||
public:
|
||||
// Unwraps a Tensor from the given TensorHandle. If an error occurred,
|
||||
// status->ok() will be false, and the returned Tensor must not be used.
|
||||
Tensor Resolve(Status* status);
|
||||
|
||||
// Constructs a TensorHandle from a Tensor. If an error occurred,
|
||||
// status->ok() will be false, and the returned TensorHandle must not be used.
|
||||
static TensorHandle FromTensor(const Tensor& tensor, const Runtime& runtime,
|
||||
Status* status);
|
||||
|
||||
// TensorHandle is movable, and not copyable
|
||||
TensorHandle(TensorHandle&&) = default;
|
||||
TensorHandle& operator=(TensorHandle&&) = default;
|
||||
|
||||
private:
|
||||
// Wraps a TFE_TensorHandle. Takes ownership of handle.
|
||||
explicit TensorHandle(TFE_TensorHandle* handle) : handle_(handle) {}
|
||||
|
||||
// TensorHandle is not copyable
|
||||
TensorHandle(const TensorHandle&) = delete;
|
||||
TensorHandle& operator=(const TensorHandle&) = delete;
|
||||
|
||||
// Returns the underlying TFE_TensorHandle that this object wraps.
|
||||
// This object retains ownership of the pointer.
|
||||
TFE_TensorHandle* GetTFETensorHandle() const { return handle_.get(); }
|
||||
|
||||
// Deletes the currently wrapped TFE_TensorHandle, and swaps it with handle,
|
||||
// and takes ownership of handle.
|
||||
void Reset(TFE_TensorHandle* handle) { handle_.reset(handle); }
|
||||
|
||||
struct TFETensorHandleDeleter {
|
||||
void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); }
|
||||
};
|
||||
std::unique_ptr<TFE_TensorHandle, TFETensorHandleDeleter> handle_;
|
||||
};
|
||||
|
||||
inline Tensor TensorHandle::Resolve(Status* status) {
|
||||
TF_Tensor* tensor =
|
||||
TFE_TensorHandleResolve(handle_.get(), status->GetTFStatus());
|
||||
if (!status->ok()) {
|
||||
return Tensor(nullptr);
|
||||
}
|
||||
return Tensor(tensor);
|
||||
}
|
||||
|
||||
inline TensorHandle TensorHandle::FromTensor(const Tensor& tensor,
|
||||
const Runtime& runtime,
|
||||
Status* status) {
|
||||
TFE_TensorHandle* tensor_handle = TFE_NewTensorHandleFromTensor(
|
||||
runtime.GetTFEContext(), tensor.GetTFTensor(), status->GetTFStatus());
|
||||
if (!status->ok()) {
|
||||
return TensorHandle(nullptr);
|
||||
}
|
||||
return TensorHandle(tensor_handle);
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
|
50
tensorflow/cc/experimental/base/tests/BUILD
Normal file
50
tensorflow/cc/experimental/base/tests/BUILD
Normal file
@ -0,0 +1,50 @@
|
||||
# Tests for the C++ header-only base types.
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_types_test_util",
|
||||
testonly = True,
|
||||
hdrs = ["tensor_types_test_util.h"],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_datatype",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tensor_test",
|
||||
srcs = [
|
||||
"tensor_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":tensor_types_test_util",
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/cc/experimental/base/public:status",
|
||||
"//tensorflow/cc/experimental/base/public:tensor",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tensorhandle_test",
|
||||
srcs = [
|
||||
"tensorhandle_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":tensor_types_test_util",
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/cc/experimental/base/public:runtime",
|
||||
"//tensorflow/cc/experimental/base/public:runtime_builder",
|
||||
"//tensorflow/cc/experimental/base/public:status",
|
||||
"//tensorflow/cc/experimental/base/public:tensor",
|
||||
"//tensorflow/cc/experimental/base/public:tensorhandle",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
163
tensorflow/cc/experimental/base/tests/tensor_test.cc
Normal file
163
tensorflow/cc/experimental/base/tests/tensor_test.cc
Normal file
@ -0,0 +1,163 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using tensorflow::experimental::cc::Status;
|
||||
using tensorflow::experimental::cc::Tensor;
|
||||
|
||||
using SimpleTypes = ::testing::Types<
|
||||
tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
|
||||
tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
|
||||
tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
|
||||
|
||||
template <typename T>
|
||||
class ConstructScalarTensorTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(ConstructScalarTensorTest, SimpleTypes);
|
||||
|
||||
// This test constructs a scalar tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
typename TypeParam::type value = 42;
|
||||
Tensor tensor = Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
|
||||
/*data=*/&value,
|
||||
/*len=*/sizeof(value),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 0);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
|
||||
EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), 1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Construct1DTensorTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(Construct1DTensorTest, SimpleTypes);
|
||||
|
||||
// This test constructs a 1D tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
// Shape is Rank 1 vector.
|
||||
std::vector<int64_t> shape;
|
||||
shape.push_back(value.size());
|
||||
|
||||
Tensor tensor = Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 1);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
EXPECT_EQ(tensor_view[2], 0);
|
||||
EXPECT_EQ(tensor_view[3], 1);
|
||||
EXPECT_EQ(tensor_view[4], 4);
|
||||
EXPECT_EQ(tensor_view[5], 29);
|
||||
|
||||
EXPECT_EQ(tensor.num_bytes(),
|
||||
value.size() * sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), value.size());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Construct2DTensorTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(Construct2DTensorTest, SimpleTypes);
|
||||
|
||||
// This test constructs a 2D tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
// Shape is Rank 2 vector with shape 2 x 3.
|
||||
std::vector<int64_t> shape({2, 3});
|
||||
|
||||
Tensor tensor = Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 2);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
EXPECT_EQ(tensor_view[2], 0);
|
||||
EXPECT_EQ(tensor_view[3], 1);
|
||||
EXPECT_EQ(tensor_view[4], 4);
|
||||
EXPECT_EQ(tensor_view[5], 29);
|
||||
|
||||
EXPECT_EQ(tensor.num_bytes(),
|
||||
value.size() * sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), value.size());
|
||||
}
|
||||
|
||||
TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
|
||||
bool done = false;
|
||||
Status status;
|
||||
std::vector<int32_t> data_vector({12, 14, 20, 18, 39, 42, 100});
|
||||
{
|
||||
// data_vector is a rank 1 tensor.
|
||||
std::vector<int64_t> shape;
|
||||
shape.push_back(data_vector.size());
|
||||
|
||||
Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
|
||||
done = true;
|
||||
};
|
||||
|
||||
Tensor tensor =
|
||||
Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
|
||||
/*data=*/data_vector.data(),
|
||||
/*len=*/data_vector.size() * sizeof(int32_t),
|
||||
/*deleter=*/callback, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
}
|
||||
// At this point, tensor has been destroyed, and the deleter callback should
|
||||
// have run.
|
||||
EXPECT_TRUE(done);
|
||||
}
|
||||
|
||||
} // namespace
|
@ -0,0 +1,76 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Each of the following struct types have two members: a kDType that
|
||||
// corresponds to a TF_Datatype enum value, and a typedef "type"
|
||||
// of its corresponding C++ type. These types allow us to write Dtype-agnostic
|
||||
// tests via GoogleTest's TypedTests:
|
||||
// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
|
||||
struct FloatType {
|
||||
using type = float;
|
||||
static constexpr TF_DataType kDType = TF_FLOAT;
|
||||
};
|
||||
|
||||
struct DoubleType {
|
||||
using type = double;
|
||||
static constexpr TF_DataType kDType = TF_DOUBLE;
|
||||
};
|
||||
|
||||
struct Int32Type {
|
||||
using type = int32_t;
|
||||
static constexpr TF_DataType kDType = TF_INT32;
|
||||
};
|
||||
|
||||
struct UINT8Type {
|
||||
using type = uint8_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT8;
|
||||
};
|
||||
|
||||
struct INT8Type {
|
||||
using type = int8_t;
|
||||
static constexpr TF_DataType kDType = TF_INT8;
|
||||
};
|
||||
|
||||
struct INT64Type {
|
||||
using type = int64_t;
|
||||
static constexpr TF_DataType kDType = TF_INT64;
|
||||
};
|
||||
|
||||
struct UINT16Type {
|
||||
using type = uint16_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT16;
|
||||
};
|
||||
|
||||
struct UINT32Type {
|
||||
using type = uint32_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT32;
|
||||
};
|
||||
|
||||
struct UINT64Type {
|
||||
using type = uint64_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT64;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
|
184
tensorflow/cc/experimental/base/tests/tensorhandle_test.cc
Normal file
184
tensorflow/cc/experimental/base/tests/tensorhandle_test.cc
Normal file
@ -0,0 +1,184 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/experimental/base/public/tensorhandle.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
|
||||
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
||||
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
using tensorflow::experimental::cc::Runtime;
|
||||
using tensorflow::experimental::cc::RuntimeBuilder;
|
||||
using tensorflow::experimental::cc::Status;
|
||||
using tensorflow::experimental::cc::Tensor;
|
||||
using tensorflow::experimental::cc::TensorHandle;
|
||||
|
||||
using SimpleTypes = ::testing::Types<
|
||||
tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
|
||||
tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
|
||||
tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
|
||||
|
||||
template <typename T>
|
||||
class ConstructScalarTensorHandleTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes);
|
||||
|
||||
// This test constructs a scalar tensor for each of the types in "SimpleTypes",
|
||||
// then wraps it in a TensorHandle. We then unwrap it back into a Tensor, and
|
||||
// verify the expected dims, dtype, value, num bytes, and num elements.
|
||||
TYPED_TEST(ConstructScalarTensorHandleTest,
|
||||
ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
RuntimeBuilder runtime_builder;
|
||||
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
typename TypeParam::type value = 42;
|
||||
Tensor original_tensor =
|
||||
Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
|
||||
/*data=*/&value,
|
||||
/*len=*/sizeof(value),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TensorHandle handle =
|
||||
TensorHandle::FromTensor(original_tensor, *runtime, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
Tensor tensor = handle.Resolve(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 0);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
|
||||
EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), 1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Construct1DTensorHandleTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes);
|
||||
|
||||
// This test constructs a 1D tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct1DTensorHandleTest,
|
||||
ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
RuntimeBuilder runtime_builder;
|
||||
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
// Shape is Rank 1 vector.
|
||||
std::vector<int64_t> shape;
|
||||
shape.push_back(value.size());
|
||||
|
||||
Tensor original_tensor = Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TensorHandle handle =
|
||||
TensorHandle::FromTensor(original_tensor, *runtime, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
Tensor tensor = handle.Resolve(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 1);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
EXPECT_EQ(tensor_view[2], 0);
|
||||
EXPECT_EQ(tensor_view[3], 1);
|
||||
EXPECT_EQ(tensor_view[4], 4);
|
||||
EXPECT_EQ(tensor_view[5], 29);
|
||||
|
||||
EXPECT_EQ(tensor.num_bytes(),
|
||||
value.size() * sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), value.size());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Construct2DTensorHandleTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes);
|
||||
|
||||
// This test constructs a 2D tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct2DTensorHandleTest,
|
||||
ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
RuntimeBuilder runtime_builder;
|
||||
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
// Shape is Rank 2 vector with shape 2 x 3.
|
||||
std::vector<int64_t> shape({2, 3});
|
||||
|
||||
Tensor original_tensor = Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TensorHandle handle =
|
||||
TensorHandle::FromTensor(original_tensor, *runtime, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
Tensor tensor = handle.Resolve(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 2);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
EXPECT_EQ(tensor_view[2], 0);
|
||||
EXPECT_EQ(tensor_view[3], 1);
|
||||
EXPECT_EQ(tensor_view[4], 4);
|
||||
EXPECT_EQ(tensor_view[5], 29);
|
||||
|
||||
EXPECT_EQ(tensor.num_bytes(),
|
||||
value.size() * sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), value.size());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -4,7 +4,6 @@
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"if_android",
|
||||
"if_ios",
|
||||
"if_mobile",
|
||||
"if_not_mobile",
|
||||
"tf_cc_test",
|
||||
@ -85,7 +84,7 @@ cc_library(
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
]) + if_android([
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
"//tensorflow/core:portable_tensorflow_lib",
|
||||
]),
|
||||
)
|
||||
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/saved_model/experimental/public/function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// ConcreteFunction is an executable "function" loaded from a SavedModelAPI.
|
||||
@ -54,6 +55,7 @@ inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// ConcreteFunctionList helps convert an opaque pointer to an array of
|
||||
@ -56,6 +57,7 @@ inline std::vector<ConcreteFunction*> ConcreteFunctionList::ToVector() {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// FunctionMetadata stores additional function information, including
|
||||
@ -40,6 +41,7 @@ class FunctionMetadata final {
|
||||
};
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// SavedModelAPI offers a way to load Tensorflow Saved Models
|
||||
@ -155,6 +156,7 @@ inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
|
||||
|
@ -26,10 +26,14 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
using tensorflow::experimental::cc::Runtime;
|
||||
using tensorflow::experimental::cc::RuntimeBuilder;
|
||||
using tensorflow::experimental::cc::SavedModelAPI;
|
||||
using tensorflow::experimental::cc::Status;
|
||||
|
||||
constexpr char kTestData[] = "cc/saved_model/testdata";
|
||||
|
||||
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
|
||||
@ -43,21 +47,21 @@ std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
|
||||
class CPPSavedModelAPITest : public ::testing::TestWithParam<bool> {};
|
||||
|
||||
TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
|
||||
cc::Status status;
|
||||
cc::RuntimeBuilder builder;
|
||||
Status status;
|
||||
RuntimeBuilder builder;
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
builder.SetUseTFRT(use_tfrt);
|
||||
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
|
||||
std::unique_ptr<Runtime> runtime = builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||
std::unordered_set<std::string> tags = {"serve"};
|
||||
std::unique_ptr<cc::SavedModelAPI> model =
|
||||
cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
|
||||
std::unique_ptr<SavedModelAPI> model =
|
||||
SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
|
||||
|
||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||
@ -67,20 +71,20 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
|
||||
}
|
||||
|
||||
TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
|
||||
cc::Status status;
|
||||
cc::RuntimeBuilder builder;
|
||||
Status status;
|
||||
RuntimeBuilder builder;
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
builder.SetUseTFRT(use_tfrt);
|
||||
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
|
||||
std::unique_ptr<Runtime> runtime = builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||
std::unique_ptr<cc::SavedModelAPI> model =
|
||||
cc::SavedModelAPI::Load(model_dir, *runtime, &status);
|
||||
std::unique_ptr<SavedModelAPI> model =
|
||||
SavedModelAPI::Load(model_dir, *runtime, &status);
|
||||
|
||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||
@ -94,4 +98,3 @@ INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -131,6 +131,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
|
||||
TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type));
|
||||
std::vector<string> dim_vars;
|
||||
string dim_sizes, indices;
|
||||
int count = 1;
|
||||
if (shape.rank() == 0 ||
|
||||
(shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) {
|
||||
dim_sizes = "[1]";
|
||||
@ -140,6 +141,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
|
||||
dim_vars.push_back(absl::StrCat("size_t dim", dim));
|
||||
dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]");
|
||||
indices += absl::StrCat("[dim", dim, "]");
|
||||
count *= shape.dimensions(dim);
|
||||
}
|
||||
}
|
||||
rewrites->push_back({"{{I}}", absl::StrCat(i)});
|
||||
@ -147,6 +149,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
|
||||
rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
|
||||
rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
|
||||
rewrites->push_back({"{{INDICES}}", indices});
|
||||
rewrites->push_back({"{{COUNT}}", absl::StrCat(count)});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -199,6 +202,12 @@ Status GenArgMethods(const tf2xla::Config& config,
|
||||
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
|
||||
arg_data({{I}}))){{INDICES}};
|
||||
}
|
||||
int arg{{NAME}}_size() const {
|
||||
return {{COUNT}} * sizeof({{TYPE}});
|
||||
}
|
||||
int arg{{NAME}}_count() const {
|
||||
return {{COUNT}};
|
||||
}
|
||||
)";
|
||||
*methods += RewriteWithName(absl::StrCat(i), code, rewrites);
|
||||
if (!config.feed(i).name().empty()) {
|
||||
@ -246,6 +255,12 @@ Status GenResultMethods(const tf2xla::Config& config,
|
||||
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
|
||||
result_data({{I}}))){{INDICES}};
|
||||
}
|
||||
int result{{NAME}}_size() const {
|
||||
return {{COUNT}} * sizeof({{TYPE}});
|
||||
}
|
||||
int result{{NAME}}_count() const {
|
||||
return {{COUNT}};
|
||||
}
|
||||
)";
|
||||
*methods += RewriteWithName(absl::StrCat(i), code, rewrites);
|
||||
if (!config.fetch(i).name().empty()) {
|
||||
@ -281,6 +296,12 @@ Status GenVariableMethods(const tf2xla::Config& config,
|
||||
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
|
||||
arg_data({{I}}))){{INDICES}};
|
||||
}
|
||||
int var_{{NAME}}_size() const {
|
||||
return {{COUNT}} * sizeof({{TYPE}});
|
||||
}
|
||||
int var_{{NAME}}_count() const {
|
||||
return {{COUNT}};
|
||||
}
|
||||
)";
|
||||
const tf2xla::Variable& var = config.variable(i - config.feed_size());
|
||||
rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");
|
||||
|
@ -138,6 +138,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const float(*)[1][2]>(
|
||||
arg_data(0)))[dim0][dim1];
|
||||
}
|
||||
int arg0_size() const {
|
||||
return 2 * sizeof(float);
|
||||
}
|
||||
int arg0_count() const {
|
||||
return 2;
|
||||
}
|
||||
|
||||
void set_arg_myfeed_data(const void* data) {
|
||||
set_arg_data(0, data);
|
||||
@ -156,6 +162,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const float(*)[1][2]>(
|
||||
arg_data(0)))[dim0][dim1];
|
||||
}
|
||||
int arg_myfeed_size() const {
|
||||
return 2 * sizeof(float);
|
||||
}
|
||||
int arg_myfeed_count() const {
|
||||
return 2;
|
||||
}
|
||||
|
||||
void set_arg1_data(const void* data) {
|
||||
set_arg_data(1, data);
|
||||
@ -174,6 +186,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const tensorflow::int64(*)[3][4]>(
|
||||
arg_data(1)))[dim0][dim1];
|
||||
}
|
||||
int arg1_size() const {
|
||||
return 12 * sizeof(tensorflow::int64);
|
||||
}
|
||||
int arg1_count() const {
|
||||
return 12;
|
||||
}
|
||||
|
||||
// Result methods for managing output buffers. Buffers are in row-major order.
|
||||
// Must only be called after a successful Run call. There is a set of methods
|
||||
@ -204,6 +222,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
|
||||
result_data(0)))[dim0][dim1];
|
||||
}
|
||||
int result0_size() const {
|
||||
return 30 * sizeof(tensorflow::uint32);
|
||||
}
|
||||
int result0_count() const {
|
||||
return 30;
|
||||
}
|
||||
|
||||
tensorflow::uint32* result_myfetch_data() {
|
||||
return static_cast<tensorflow::uint32*>(result_data(0));
|
||||
@ -219,6 +243,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
|
||||
result_data(0)))[dim0][dim1];
|
||||
}
|
||||
int result_myfetch_size() const {
|
||||
return 30 * sizeof(tensorflow::uint32);
|
||||
}
|
||||
int result_myfetch_count() const {
|
||||
return 30;
|
||||
}
|
||||
|
||||
// Methods for managing variable buffers. Buffers are in row-major order.
|
||||
//
|
||||
@ -261,6 +291,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const float(*)[1]>(
|
||||
arg_data(2)))[0];
|
||||
}
|
||||
int var_myvar_readonly_size() const {
|
||||
return 1 * sizeof(float);
|
||||
}
|
||||
int var_myvar_readonly_count() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
void set_var_myvar_data(float* data) {
|
||||
set_arg_data(3, data);
|
||||
@ -279,6 +315,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const float(*)[1]>(
|
||||
arg_data(3)))[0];
|
||||
}
|
||||
int var_myvar_size() const {
|
||||
return 1 * sizeof(float);
|
||||
}
|
||||
int var_myvar_count() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
void set_var_myvar2_data(tensorflow::int32* data) {
|
||||
set_arg_data(4, data);
|
||||
@ -297,6 +339,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const tensorflow::int32(*)[5]>(
|
||||
arg_data(4)))[dim0];
|
||||
}
|
||||
int var_myvar2_size() const {
|
||||
return 5 * sizeof(tensorflow::int32);
|
||||
}
|
||||
int var_myvar2_count() const {
|
||||
return 5;
|
||||
}
|
||||
|
||||
private:
|
||||
// Number of buffers for the compiled computation.
|
||||
|
@ -20,7 +20,7 @@ load(
|
||||
"tf_cc_test",
|
||||
"tf_copts",
|
||||
)
|
||||
load("//tensorflow:tensorflow.bzl", "tfcompile_extra_flags")
|
||||
load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu")
|
||||
|
||||
def tf_library(
|
||||
name,
|
||||
@ -42,7 +42,8 @@ def tf_library(
|
||||
mlir_components = "None",
|
||||
deps = None,
|
||||
tags = []):
|
||||
"""Runs tfcompile to compile a TensorFlow graph into executable code.
|
||||
"""Runs tfcompile to compile a TensorFlow graph into executable code with fast
|
||||
math enabled on cpu.
|
||||
|
||||
Given an invocation of tf_library(name="foo", ...), generates the following
|
||||
build targets:
|
||||
@ -187,7 +188,9 @@ def tf_library(
|
||||
# `find` on such an object.
|
||||
need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1
|
||||
|
||||
flags = tfcompile_extra_flags() + flags
|
||||
target_cpu = tfcompile_target_cpu()
|
||||
extra_flags = "--target_cpu=" + target_cpu + " " if target_cpu else " "
|
||||
flags = extra_flags + flags
|
||||
|
||||
if enable_xla_hlo_profiling:
|
||||
profiling_flag = "--xla_hlo_profile"
|
||||
@ -207,6 +210,15 @@ def tf_library(
|
||||
srcs.append(debug_info)
|
||||
debug_info_flag = " --debug_info=$(location " + debug_info + ")"
|
||||
|
||||
default_fast_math_xla_flags = ("XLA_FLAGS='" +
|
||||
"--xla_cpu_enable_fast_math=true " +
|
||||
"--xla_cpu_fast_math_honor_nans=false " +
|
||||
"--xla_cpu_fast_math_honor_infs=false " +
|
||||
"--xla_cpu_fast_math_honor_functions=false " +
|
||||
"--xla_cpu_fast_math_honor_division=false " +
|
||||
"--xla_cpu_enable_fast_min_max=true " +
|
||||
"$${XLA_FLAGS:-}' ")
|
||||
|
||||
native.genrule(
|
||||
name = ("gen_" + name),
|
||||
srcs = srcs,
|
||||
@ -216,6 +228,7 @@ def tf_library(
|
||||
function_object_file,
|
||||
],
|
||||
cmd = (
|
||||
default_fast_math_xla_flags +
|
||||
"CUDA_VISIBLE_DEVICES='' " +
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
@ -256,6 +269,7 @@ def tf_library(
|
||||
session_module_pb,
|
||||
],
|
||||
cmd = (
|
||||
default_fast_math_xla_flags +
|
||||
"CUDA_VISIBLE_DEVICES='' " +
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
|
@ -67,6 +67,8 @@ int main(int argc, char** argv) {
|
||||
flags.entry_point = "entry";
|
||||
flags.debug_info_path_begin_marker = "";
|
||||
|
||||
// Note that tfcompile.bzl's tf_library macro sets fast math flags as that is
|
||||
// generally the preferred case.
|
||||
std::vector<tensorflow::Flag> flag_list;
|
||||
AppendMainFlags(&flag_list, &flags);
|
||||
xla::AppendDebugOptionsFlags(&flag_list);
|
||||
|
@ -251,7 +251,7 @@ cc_library(
|
||||
visibility = [":friends"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
"//tensorflow/core:portable_tensorflow_lib",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:graph",
|
||||
@ -505,6 +505,7 @@ cc_library(
|
||||
name = "shape_inference",
|
||||
srcs = ["shape_inference.cc"],
|
||||
hdrs = ["shape_inference.h"],
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
":shape_inference_helpers",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
|
@ -2034,6 +2034,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"TensorArraySplitV3",
|
||||
"TensorArrayV3",
|
||||
"TensorArrayWriteV3",
|
||||
"TensorListConcatV2",
|
||||
"TensorListElementShape",
|
||||
"TensorListFromTensor",
|
||||
"TensorListGather",
|
||||
@ -2043,6 +2044,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"TensorListPushBack",
|
||||
"TensorListReserve",
|
||||
"TensorListSetItem",
|
||||
"TensorListSplit",
|
||||
"TensorListStack",
|
||||
"TensorScatterAdd",
|
||||
"TensorScatterSub",
|
||||
|
@ -395,12 +395,11 @@ static void ShowXlaDeviceDeprecationWarning(
|
||||
if (absl::StrContains(compilation_device_name, "CPU") ||
|
||||
absl::StrContains(compilation_device_name, "GPU")) {
|
||||
absl::call_once(once, [] {
|
||||
LOG(WARNING)
|
||||
<< "XLA_GPU and XLA_CPU devices are deprecated and will be "
|
||||
"removed in subsequent releases. Instead, use either "
|
||||
"@tf.function(experimental_compile=True) for must-compile "
|
||||
"semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
|
||||
"for auto-clustering best-effort compilation.";
|
||||
LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be "
|
||||
"removed in subsequent releases. Instead, use either "
|
||||
"@tf.function(experimental_compile=True) for must-compile "
|
||||
"semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
|
||||
"for auto-clustering best-effort compilation.";
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -180,12 +180,10 @@ class XlaAssignVariableOp : public OpKernel {
|
||||
data::MakeIteratorOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \
|
||||
data::AnonymousIteratorHandleOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("AnonymousIteratorV2").Device(DEVICE).HostMemory("deleter"), \
|
||||
data::AnonymousIteratorHandleOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("DeleteIterator").Device(DEVICE).HostMemory("deleter"), \
|
||||
data::DeleteIteratorOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("AnonymousIteratorV2").Device(DEVICE), \
|
||||
data::AnonymousIteratorHandleOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE), \
|
||||
data::DeleteIteratorOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \
|
||||
data::IteratorGetNextOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \
|
||||
|
@ -91,7 +91,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||
}
|
||||
string message = absl::StrCat(
|
||||
"Function invoked by the following node is not compilable: ",
|
||||
SummarizeNodeDef(node_def), ".\n");
|
||||
SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n");
|
||||
absl::StrAppend(&message, "Uncompilable nodes:");
|
||||
for (const auto& node_info : uncompilable_node_info) {
|
||||
string node_message =
|
||||
|
@ -201,9 +201,7 @@ void XlaComputationLaunchContext::PopulateInputs(
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
// Build ShapedBuffers that point directly to the Tensor buffers.
|
||||
arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1);
|
||||
arg_buffers_.resize(kernel->xla_input_shapes.size());
|
||||
arg_ptrs_ = std::vector<ShapedBuffer*>(arg_buffers_.size());
|
||||
arg_ptrs_ = std::vector<ShapedBuffer*>(kernel->xla_input_shapes.size());
|
||||
|
||||
// Pass remaining parameters.
|
||||
const Tensor* t;
|
||||
@ -239,11 +237,11 @@ void XlaComputationLaunchContext::PopulateInputs(
|
||||
<< " not the same as on-host shape "
|
||||
<< xla::ShapeUtil::HumanStringWithLayout(shape);
|
||||
se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
|
||||
arg_buffers_[i] = absl::make_unique<ShapedBuffer>(
|
||||
arg_buffers_.emplace_back(
|
||||
/*on_host_shape=*/shape, /*on_device_shape=*/shape,
|
||||
client_->platform(), client_->default_device_ordinal());
|
||||
arg_buffers_[i]->set_buffer(dmem, /*index=*/{});
|
||||
arg_ptrs_[i] = arg_buffers_[i].get();
|
||||
arg_buffers_.back().set_buffer(dmem, /*index=*/{});
|
||||
arg_ptrs_[i] = &arg_buffers_.back();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -165,7 +165,7 @@ class XlaComputationLaunchContext {
|
||||
se::DeviceMemoryAllocator* xla_allocator_;
|
||||
bool allocate_xla_tensors_;
|
||||
bool use_multiple_streams_;
|
||||
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_;
|
||||
std::deque<xla::ShapedBuffer> arg_buffers_;
|
||||
std::vector<xla::ShapedBuffer*> arg_ptrs_;
|
||||
};
|
||||
|
||||
|
@ -77,10 +77,6 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
|
||||
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes",
|
||||
"//tensorflow/compiler/mlir/tfrt:lower_tf_to_tfd_alwayslink",
|
||||
"//tensorflow/compiler/mlir/tfrt:runtime_fallback_opdefs_alwayslink",
|
||||
"//tensorflow/compiler/mlir/tfrt:tf_legalize_to_tfrt",
|
||||
"//tensorflow/compiler/mlir/tfrt:tf_to_corert",
|
||||
],
|
||||
)
|
||||
|
||||
@ -108,6 +104,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Shape",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -152,7 +149,6 @@ tf_cc_binary(
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op",
|
||||
"//tensorflow/compiler/mlir/tfrt:compatibility_analysis",
|
||||
"//tensorflow/compiler/mlir/xla:xla_mlir_translate",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -31,7 +31,7 @@ filegroup(
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
|
||||
],
|
||||
)
|
||||
|
||||
@ -216,13 +216,13 @@ cc_library(
|
||||
"ir/tfl_ops.h",
|
||||
"transforms/passes.h",
|
||||
"utils/attribute_utils.h",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
|
||||
"@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
|
||||
],
|
||||
deps = [
|
||||
":tensorflow_lite_ops_inc_gen",
|
||||
":validators",
|
||||
"//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@llvm-project//llvm:support",
|
||||
@ -260,6 +260,25 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tftext_utils",
|
||||
srcs = [
|
||||
"utils/tftext_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"utils/tftext_utils.h",
|
||||
],
|
||||
copts = ["-std=c++14"],
|
||||
deps = [
|
||||
":tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "stateful_ops_utils",
|
||||
srcs = [
|
||||
@ -320,6 +339,7 @@ cc_library(
|
||||
":lstm_utils",
|
||||
":stateful_ops_utils",
|
||||
":tensorflow_lite",
|
||||
":tftext_utils",
|
||||
":validators",
|
||||
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
@ -695,9 +715,9 @@ cc_library(
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:LoopOpsTransforms",
|
||||
"@llvm-project//mlir:MlirTranslateMain",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:SCFTransforms",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Translation",
|
||||
|
@ -799,11 +799,6 @@ Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
|
||||
|
||||
Optional<CustomOptionsOffset> Translator::CreateCustomOpCustomOptions(
|
||||
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
|
||||
std::string node_def_str;
|
||||
if (!node_def.SerializeToString(&node_def_str)) {
|
||||
return emitError(loc, "failed to serialize tensorflow node_def"),
|
||||
llvm::None;
|
||||
}
|
||||
auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc);
|
||||
return builder_.CreateVector(flex_builder->GetBuffer());
|
||||
}
|
||||
@ -813,9 +808,13 @@ Translator::CreateFlexBuilderWithNodeAttrs(
|
||||
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
|
||||
auto flex_builder = absl::make_unique<flexbuffers::Builder>();
|
||||
size_t map_start = flex_builder->StartMap();
|
||||
for (const auto& pair : node_def.attr()) {
|
||||
using Item = std::pair<std::string, ::tensorflow::AttrValue>;
|
||||
std::vector<Item> attrs(node_def.attr().begin(), node_def.attr().end());
|
||||
std::sort(attrs.begin(), attrs.end(),
|
||||
[](Item& p1, Item& p2) -> bool { return p1.first < p2.first; });
|
||||
for (const Item& pair : attrs) {
|
||||
const char* key = pair.first.c_str();
|
||||
const auto& attr = pair.second;
|
||||
const ::tensorflow::AttrValue& attr = pair.second;
|
||||
switch (attr.value_case()) {
|
||||
case ::tensorflow::AttrValue::kS:
|
||||
flex_builder->String(key, attr.s());
|
||||
@ -1020,7 +1019,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
if (!inst->getMutableAttrDict().getAttrs().empty()) {
|
||||
os << " {";
|
||||
bool first = true;
|
||||
for (auto& named_attr : inst->getMutableAttrDict().getDictionary()) {
|
||||
for (auto& named_attr : inst->getAttrDictionary()) {
|
||||
os << (!first ? ", " : "");
|
||||
first = false;
|
||||
named_attr.first.print(os);
|
||||
|
@ -1966,9 +1966,9 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
|
||||
}
|
||||
|
||||
static LogicalResult Verify(TransposeOp op) {
|
||||
auto input_type = op.x().getType().cast<ShapedType>();
|
||||
auto input_type = op.input().getType().cast<ShapedType>();
|
||||
auto perm_type = op.perm().getType().cast<ShapedType>();
|
||||
auto output_type = op.y().getType().cast<ShapedType>();
|
||||
auto output_type = op.output().getType().cast<ShapedType>();
|
||||
if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
|
||||
if (perm_type.getNumElements() != input_type.getRank()) {
|
||||
return op.emitOpError(
|
||||
|
@ -27,7 +27,7 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -55,8 +55,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
std::vector<string> node_names;
|
||||
std::vector<string> node_dtypes;
|
||||
std::vector<std::vector<int>> node_shapes;
|
||||
std::vector<double> node_mins;
|
||||
std::vector<double> node_maxs;
|
||||
std::vector<llvm::Optional<double>> node_mins;
|
||||
std::vector<llvm::Optional<double>> node_maxs;
|
||||
|
||||
// Populate quantization specs.
|
||||
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
|
||||
|
@ -125,8 +125,8 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||
std::vector<string> node_names;
|
||||
std::vector<string> node_dtypes;
|
||||
std::vector<std::vector<int>> node_shapes;
|
||||
std::vector<double> node_mins;
|
||||
std::vector<double> node_maxs;
|
||||
std::vector<llvm::Optional<double>> node_mins;
|
||||
std::vector<llvm::Optional<double>> node_maxs;
|
||||
|
||||
// Populate quantization specs.
|
||||
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
|
||||
@ -146,6 +146,10 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||
saved_model_exported_names.begin(), saved_model_exported_names.end());
|
||||
absl::Span<std::string> exported_names(exported_names_in_vector);
|
||||
|
||||
if (exported_names.size() != 1) {
|
||||
return errors::Unimplemented("Only support a single exported name.");
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto module,
|
||||
ImportSavedModel(model_flags.saved_model_dir(),
|
||||
model_flags.saved_model_version(), tags,
|
||||
|
@ -121,6 +121,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
|
||||
return DT_STRING;
|
||||
case toco::IODataType::BOOL:
|
||||
return DT_BOOL;
|
||||
case toco::IODataType::COMPLEX64:
|
||||
return DT_COMPLEX64;
|
||||
default:
|
||||
return DT_INVALID;
|
||||
}
|
||||
@ -175,14 +177,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
|
||||
return RegisterCustomBuiltinOps(extra_tf_opdefs);
|
||||
}
|
||||
|
||||
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
|
||||
const toco::TocoFlags& toco_flags,
|
||||
mlir::TFL::QuantizationSpecs* quant_specs,
|
||||
std::vector<string>* node_names,
|
||||
std::vector<string>* node_dtypes,
|
||||
std::vector<std::vector<int>>* node_shapes,
|
||||
std::vector<double>* node_mins,
|
||||
std::vector<double>* node_maxs) {
|
||||
Status PopulateQuantizationSpecs(
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
|
||||
std::vector<string>* node_dtypes,
|
||||
std::vector<std::vector<int>>* node_shapes,
|
||||
std::vector<llvm::Optional<double>>* node_mins,
|
||||
std::vector<llvm::Optional<double>>* node_maxs) {
|
||||
quant_specs->inference_input_type =
|
||||
ConvertIODataTypeToDataType(toco_flags.inference_input_type());
|
||||
tensorflow::DataType inference_type =
|
||||
@ -209,11 +210,16 @@ Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
|
||||
flag.shape().dims().end()));
|
||||
// Currently, only UINT8 and INT8 require inputs stats
|
||||
if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(),
|
||||
inference_type));
|
||||
node_mins->push_back(min_max.first);
|
||||
node_maxs->push_back(min_max.second);
|
||||
if (flag.has_mean_value() && flag.has_std_value()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto min_max, InputStatsToMinMax(flag.mean_value(),
|
||||
flag.std_value(), inference_type));
|
||||
node_mins->push_back(min_max.first);
|
||||
node_maxs->push_back(min_max.second);
|
||||
} else {
|
||||
node_mins->push_back(llvm::None);
|
||||
node_maxs->push_back(llvm::None);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -252,7 +258,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
|
||||
std::string error_message;
|
||||
auto output = mlir::openOutputFile(filename, &error_message);
|
||||
if (!error_message.empty()) {
|
||||
return errors::InvalidArgument("Failed to open file in %s.", filename);
|
||||
return errors::InvalidArgument("Failed to open file in ", filename);
|
||||
}
|
||||
mlir::PassManager pm(module.getContext());
|
||||
pm.addPass(mlir::createPrintOpGraphPass(output->os()));
|
||||
|
@ -34,14 +34,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags);
|
||||
|
||||
// Populate quantization specs (or not) given user specified ranges for each
|
||||
// input arrays.
|
||||
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
|
||||
const toco::TocoFlags& toco_flags,
|
||||
mlir::TFL::QuantizationSpecs* quant_specs,
|
||||
std::vector<string>* node_names,
|
||||
std::vector<string>* node_dtypes,
|
||||
std::vector<std::vector<int>>* node_shapes,
|
||||
std::vector<double>* node_mins,
|
||||
std::vector<double>* node_maxs);
|
||||
Status PopulateQuantizationSpecs(
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
|
||||
std::vector<string>* node_dtypes,
|
||||
std::vector<std::vector<int>>* node_shapes,
|
||||
std::vector<llvm::Optional<double>>* node_mins,
|
||||
std::vector<llvm::Optional<double>>* node_maxs);
|
||||
|
||||
// Convert imported MLIR file to TfLite flatbuffer.
|
||||
// This will also run relevant passes as well.
|
||||
|
@ -3,6 +3,10 @@ load(
|
||||
"//tensorflow/core/platform:build_config.bzl",
|
||||
"tf_proto_library",
|
||||
)
|
||||
load(
|
||||
"//third_party/mlir:tblgen.bzl",
|
||||
"gentbl",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -23,6 +27,7 @@ package_group(
|
||||
exports_files([
|
||||
"quantization_traits.h",
|
||||
"quantization_config.h",
|
||||
"quantization_utils.h",
|
||||
])
|
||||
|
||||
filegroup(
|
||||
@ -34,6 +39,25 @@ filegroup(
|
||||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "quantization_interfaces_inc_gen",
|
||||
tbl_outs = [
|
||||
(
|
||||
"-gen-op-interface-decls",
|
||||
"quantization_interface.h.inc",
|
||||
),
|
||||
(
|
||||
"-gen-op-interface-defs",
|
||||
"quantization_interface.cc.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "quantization.td",
|
||||
td_srcs = [
|
||||
":quantization_td_files",
|
||||
],
|
||||
)
|
||||
|
||||
tf_proto_library(
|
||||
name = "quantization_info_proto",
|
||||
srcs = [
|
||||
@ -71,9 +95,11 @@ cc_library(
|
||||
name = "quantization_lib",
|
||||
srcs = [
|
||||
"quantization_driver.cc",
|
||||
"quantization_interface.cc.inc",
|
||||
"quantization_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"quantization_interface.h.inc",
|
||||
"quantization_traits.h",
|
||||
"quantization_utils.h",
|
||||
],
|
||||
|
@ -49,14 +49,16 @@ cc_library(
|
||||
],
|
||||
hdrs = [
|
||||
"tfl_to_std.h",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_utils.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
||||
"@com_google_absl//absl/strings",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace lite {
|
||||
@ -38,6 +39,7 @@ namespace lite {
|
||||
TfLiteStatus QuantizeModel(
|
||||
const tflite::ModelT& input_model, const tflite::TensorType& input_type,
|
||||
const tflite::TensorType& output_type,
|
||||
const tflite::TensorType& inference_type,
|
||||
const std::unordered_set<std::string>& operator_names,
|
||||
bool disable_per_channel, bool fully_quantize,
|
||||
flatbuffers::FlatBufferBuilder* builder,
|
||||
@ -73,7 +75,7 @@ TfLiteStatus QuantizeModel(
|
||||
// Apply quantization passes
|
||||
PassManager pm(module->getContext());
|
||||
TFL::QuantizationSpecs quant_specs;
|
||||
quant_specs.inference_type = tensorflow::DT_QINT8;
|
||||
quant_specs.inference_type = tflite::TflTypeToTfType(inference_type);
|
||||
quant_specs.post_training_quantization = true;
|
||||
quant_specs.disable_per_channel = disable_per_channel;
|
||||
|
||||
@ -81,8 +83,10 @@ TfLiteStatus QuantizeModel(
|
||||
auto input_tf_type = tflite::TflTypeToTfType(input_type);
|
||||
if (input_tf_type == tensorflow::DT_FLOAT) {
|
||||
emit_adaptor = true;
|
||||
} else if (input_tf_type == tensorflow::DT_UINT8) {
|
||||
quant_specs.inference_type = tensorflow::DT_QUINT8;
|
||||
} else if (input_tf_type == tensorflow::DT_UINT8 ||
|
||||
input_tf_type == tensorflow::DT_INT8 ||
|
||||
input_tf_type == tensorflow::DT_INT16) {
|
||||
quant_specs.inference_type = input_tf_type;
|
||||
}
|
||||
|
||||
pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs));
|
||||
|
@ -26,11 +26,13 @@ namespace mlir {
|
||||
namespace lite {
|
||||
|
||||
// Quantize the `input_model` and write the result to a flatbuffer `builder`.
|
||||
// The `input_type` and `output_type` can be float32/qint8/int8.
|
||||
// The `input_type`, `output_type` and `inference_type` can be
|
||||
// float32/qint8/int8/int16.
|
||||
// Return partially quantized model if `fully_quantize` is false.
|
||||
TfLiteStatus QuantizeModel(
|
||||
const tflite::ModelT& input_model, const tflite::TensorType& input_type,
|
||||
const tflite::TensorType& output_type,
|
||||
const tflite::TensorType& inference_type,
|
||||
const std::unordered_set<std::string>& operator_names,
|
||||
bool disable_per_channel, bool fully_quantize,
|
||||
flatbuffers::FlatBufferBuilder* builder,
|
||||
|
@ -46,7 +46,8 @@ TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer,
|
||||
|
||||
tflite::StderrReporter error_reporter;
|
||||
return mlir::lite::QuantizeModel(
|
||||
*model, tflite::TensorType_INT8, tflite::TensorType_INT8, {},
|
||||
*model, tflite::TensorType_INT8, tflite::TensorType_INT8,
|
||||
tflite::TensorType_INT8, {},
|
||||
/*disable_per_channel=*/false,
|
||||
/*fully_quantize=*/true, builder, &error_reporter);
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
@ -47,12 +48,18 @@ void ConvertMlirQuantOpsToTFLQuantOps(FuncOp func) {
|
||||
auto dcast = b.create<DequantizeOp>(dq.getLoc(), dq.getResult().getType(),
|
||||
dq.arg());
|
||||
dq.getResult().replaceAllUsesWith(dcast);
|
||||
if (auto extra_attr = op->getAttr(mlir::quant::kVolatileOpAttrName)) {
|
||||
dcast.setAttr(mlir::quant::kVolatileOpAttrName, extra_attr);
|
||||
}
|
||||
dq.erase();
|
||||
} else if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(op)) {
|
||||
auto out_type = q.getResult().getType();
|
||||
auto qcast = b.create<QuantizeOp>(q.getLoc(), out_type, q.arg(),
|
||||
TypeAttr::get(out_type));
|
||||
q.getResult().replaceAllUsesWith(qcast);
|
||||
if (auto extra_attr = op->getAttr(mlir::quant::kVolatileOpAttrName)) {
|
||||
qcast.setAttr(mlir::quant::kVolatileOpAttrName, extra_attr);
|
||||
}
|
||||
q.erase();
|
||||
}
|
||||
});
|
||||
|
@ -63,6 +63,22 @@ def QI32 : QuantizedType<"Uniform", [32], 1>;
|
||||
// https://www.tensorflow.org/lite/performance/quantization_spec
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO(b/157870442): replace all FixedResultScale trait
|
||||
def FixedOutputRangeInterface : OpInterface<
|
||||
"FixedOutputRangeInterface"> {
|
||||
let description = [{
|
||||
Interface for defining the fixed output range.
|
||||
}];
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<
|
||||
[{Returns the fixed output range.}],
|
||||
"UniformQuantizedType", "GetFixedOutputRange",
|
||||
(ins "bool":$sign, "int":$bit_width)
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
||||
// Specify this trait if the op has a fixed output value range.
|
||||
class FixedResultScale<QuantizedType qt> : NativeOpTrait<!strconcat(
|
||||
"quant::FixedResult", qt.name, "Scale<", qt.asTraitArgsStr, ">::Impl")>;
|
||||
|
@ -45,7 +45,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
|
||||
absl::string_view inference_type,
|
||||
QuantizationSpecs* quant_specs) {
|
||||
std::vector<std::string> input_nodes = absl::StrSplit(node_names, ',');
|
||||
std::vector<double> node_mins;
|
||||
std::vector<llvm::Optional<double>> node_mins;
|
||||
if (!min_values.empty()) {
|
||||
std::vector<std::string> node_mins_str = absl::StrSplit(min_values, ',');
|
||||
for (int i = 0; i < node_mins_str.size(); i++) {
|
||||
@ -57,7 +57,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<double> node_maxs;
|
||||
std::vector<llvm::Optional<double>> node_maxs;
|
||||
if (!max_values.empty()) {
|
||||
std::vector<std::string> node_maxs_str = absl::StrSplit(max_values, ',');
|
||||
for (int i = 0; i < node_maxs_str.size(); i++) {
|
||||
@ -79,11 +79,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
|
||||
quant_specs);
|
||||
}
|
||||
|
||||
bool GetInputNodeQuantSpecs(const std::vector<std::string>& node_names,
|
||||
const std::vector<double>& node_mins,
|
||||
const std::vector<double>& node_maxs,
|
||||
tensorflow::DataType inference_type,
|
||||
QuantizationSpecs* quant_specs) {
|
||||
bool GetInputNodeQuantSpecs(
|
||||
const std::vector<std::string>& node_names,
|
||||
const std::vector<llvm::Optional<double>>& node_mins,
|
||||
const std::vector<llvm::Optional<double>>& node_maxs,
|
||||
tensorflow::DataType inference_type, QuantizationSpecs* quant_specs) {
|
||||
quant_specs->inference_type = inference_type;
|
||||
|
||||
// If min/max are not specified, just return;
|
||||
|
@ -69,7 +69,8 @@ struct QuantizationSpecs {
|
||||
// arguments. They are only used when `weight_quantization` is set to false,
|
||||
// and the model is required to have quantization parameters, either from
|
||||
// quantization aware training or calibration, for the remaining tensors.
|
||||
std::vector<std::pair<double, double>> input_ranges;
|
||||
std::vector<std::pair<llvm::Optional<double>, llvm::Optional<double>>>
|
||||
input_ranges;
|
||||
|
||||
// The default ranges can be used when a tensor doesn't have quantization
|
||||
// parameters and couldn't be quantized. Used only for latency tests.
|
||||
@ -90,7 +91,7 @@ struct QuantizationSpecs {
|
||||
bool RunWeightQuantization() const { return weight_quantization; }
|
||||
|
||||
// Whether this inference type represents a signed storage type.
|
||||
bool IsSignedInferenceType() {
|
||||
bool IsSignedInferenceType() const {
|
||||
switch (inference_type) {
|
||||
case tensorflow::DT_QUINT8:
|
||||
case tensorflow::DT_QUINT16:
|
||||
@ -102,7 +103,7 @@ struct QuantizationSpecs {
|
||||
|
||||
// Gets the width of this quantization type. Returns 0 if it isn't a
|
||||
// quantization type.
|
||||
int64_t GetQuantizationTypeWidth() {
|
||||
int64_t GetQuantizationTypeWidth() const {
|
||||
switch (inference_type) {
|
||||
case tensorflow::DT_QINT8:
|
||||
case tensorflow::DT_QUINT8:
|
||||
@ -130,11 +131,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
|
||||
// Gets the quantization specification for input arrays. The array names are not
|
||||
// stored in the spec, and will be matched by position. The min/max will be
|
||||
// ignored if the inference_type isn't a quantized type. Returns true if failed.
|
||||
bool GetInputNodeQuantSpecs(const std::vector<std::string>& node_names,
|
||||
const std::vector<double>& node_mins,
|
||||
const std::vector<double>& node_maxs,
|
||||
tensorflow::DataType inference_type,
|
||||
QuantizationSpecs* quant_specs);
|
||||
bool GetInputNodeQuantSpecs(
|
||||
const std::vector<std::string>& node_names,
|
||||
const std::vector<llvm::Optional<double>>& node_mins,
|
||||
const std::vector<llvm::Optional<double>>& node_maxs,
|
||||
tensorflow::DataType inference_type, QuantizationSpecs* quant_specs);
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
||||
|
@ -494,6 +494,13 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
|
||||
auto quantize = builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
|
||||
auto dequantize = builder_.create<quant::DequantizeCastOp>(
|
||||
loc, expressed_type, quantize.getResult());
|
||||
|
||||
// This attribute is set to distinguish the quantize ops being added by the
|
||||
// quantization pass. These ops can be removed without losing original
|
||||
// program accuracy.
|
||||
// TODO(fengliuai): make the attribute being part of op definition.
|
||||
quantize.setAttr(kVolatileOpAttrName, builder_.getUnitAttr());
|
||||
|
||||
// `original_result` has a use to `quantize`, so this will replace that use
|
||||
// by the result of `dequantize`. Remember to reset that use afterwards
|
||||
value.replaceAllUsesWith(dequantize);
|
||||
|
@ -21,13 +21,18 @@ limitations under the License.
|
||||
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace OpTrait {
|
||||
namespace quant {
|
||||
|
||||
using QuantizedType = mlir::quant::QuantizedType;
|
||||
using UniformQuantizedType = mlir::quant::UniformQuantizedType;
|
||||
|
||||
namespace mlir {
|
||||
|
||||
// This includes the interface class definition. It couldn't be in a namespace
|
||||
// because the table gen doesn't emit the namespace when it is used.
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_interface.h.inc"
|
||||
|
||||
namespace OpTrait {
|
||||
namespace quant {
|
||||
|
||||
// The base class that all the quantization related OpTrait implements.
|
||||
template <typename ConcreteType, template <typename> class TraitType>
|
||||
struct QuantizationSpecTraitBase : public TraitBase<ConcreteType, TraitType> {
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
|
||||
@ -436,6 +437,16 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
|
||||
llvm::SmallVector<quant::StatisticsOp, 16> all_stats_ops;
|
||||
llvm::DenseSet<Operation*> redundant_stats_ops;
|
||||
|
||||
// Step 0: remove the quant::StatisticsOp which are used by the tfl.quantize
|
||||
// op in case it overrides the information from training FakeQuant ops.
|
||||
func.walk([&](quant::QuantizeCastOp q) {
|
||||
auto input_op = q.arg().getDefiningOp();
|
||||
if (auto stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(input_op)) {
|
||||
q.setOperand(stats.arg());
|
||||
if (stats.use_empty()) stats.erase();
|
||||
}
|
||||
});
|
||||
|
||||
// Step 1: forward pass: propagate any value scales which are not produces
|
||||
// by `SameOperandsAndResultsScale`. Additionally, remove the value scales
|
||||
// which are produced by the `restricted_output_params`.
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user