Merge branch 'master' into master

commit 2eb90433f9
Author: ThisIsIsaac (committed by GitHub)
Date: 2019-05-16 14:45:51 +09:00
3239 changed files with 167100 additions and 114468 deletions


@@ -18,10 +18,11 @@ about: Use this template for reporting a bug or a performance issue.
 - CUDA/cuDNN version:
 - GPU model and memory:

 You can collect some of this information using our environment capture
 [script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
-You can also obtain the TensorFlow version with
-python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"
+You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
+tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
+"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`

 **Describe the current behavior**
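The two version commands quoted in the template differ only in where the constants live (top level in TF 1.x, under `tf.version` in TF 2.x). A hedged, illustrative snippet that prints the same information for either major version (it assumes nothing beyond an importable `tensorflow` package):

```python
import tensorflow as tf

# TF 2.x exposes tf.version.VERSION / tf.version.GIT_VERSION; older TF 1.x
# releases expose tf.VERSION / tf.GIT_VERSION at the top level.
v = getattr(tf, "version", tf)
print(getattr(v, "GIT_VERSION", "unknown"), getattr(v, "VERSION", "unknown"))
```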


@@ -1,17 +1,55 @@
 ---
 name: Documentation Issue
 about: Use this template for documentation related issues
 labels: 'type:docs'
 ---

-<em>Please make sure that this is a documentation issue. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:doc_template</em>
-
-**System information**
-- TensorFlow version:
-- Doc Link:
-
-**Describe the documentation issue**
-
-**We welcome contributions by users. Will you be able to update submit a PR (use the [doc style guide](https://www.tensorflow.org/community/documentation)) to fix the doc Issue?**
+Thank you for submitting a TensorFlow documentation issue. Per our GitHub
+policy, we only address code/doc bugs, performance issues, feature requests, and
+build/installation issues on GitHub.
+
+The TensorFlow docs are open source! To get involved, read the documentation
+contributor guide: https://www.tensorflow.org/community/contribute/docs
+
+## URL(s) with the issue:
+
+Please provide a link to the documentation entry, for example:
+https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/MyMethod
+
+## Description of issue (what needs changing):
+
+### Clear description
+
+For example, why should someone use this method? How is it useful?
+
+### Correct links
+
+Is the link to the source code correct?
+
+### Parameters defined
+
+Are all parameters defined and formatted correctly?
+
+### Returns defined
+
+Are return values defined?
+
+### Raises listed and defined
+
+Are the errors defined? For example,
+https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/feature_column/categorical_column_with_vocabulary_file#raises
+
+### Usage example
+
+Is there a usage example?
+
+### Request visuals, if applicable
+
+Are there currently visuals? If not, will it clarify the content?
+
+### Submit a pull request?
+
+Are you planning to also submit a pull request to fix the issue? See the docs
+contributor guide: https://www.tensorflow.org/community/contribute/docs and the
+docs style guide: https://www.tensorflow.org/community/contribute/docs_style

.gitignore

@@ -20,18 +20,8 @@ tensorflow/contrib/cmake/_build/
 [Bb]uild/
 /tensorflow/core/util/version_info.cc
 /tensorflow/python/framework/fast_tensor_util.cpp
-Pods
-Podfile.lock
-*.pbxproj
-*.xcworkspacedata
-/*.podspec
-/tensorflow/lite/experimental/objc/BUILD
-/tensorflow/lite/experimental/swift/BUILD
-/tensorflow/lite/examples/ios/simple/data/*.txt
-/tensorflow/lite/examples/ios/simple/data/*.tflite
 /tensorflow/lite/gen/**
 /tensorflow/lite/tools/make/downloads/**
-xcuserdata/**
 /api_init_files_list.txt
 /estimator_api_init_files_list.txt
 *.whl
@@ -42,3 +32,14 @@ xcuserdata/**
 *.iml
 local.properties
 gradleBuild
+
+# iOS
+*.pbxproj
+*.xcworkspace
+/*.podspec
+/tensorflow/lite/**/[ios|objc|swift]*/BUILD
+/tensorflow/lite/examples/ios/simple/data/*.tflite
+/tensorflow/lite/examples/ios/simple/data/*.txt
+Podfile.lock
+Pods
+xcuserdata


@@ -85,7 +85,7 @@ guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's
 uphold this code.**

 **We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for
-tracking requests and bugs, so please see
+tracking requests and bugs, please see
 [TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss)
 for general questions and discussion, and please direct specific questions to
 [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).**
@@ -114,15 +114,16 @@ The TensorFlow project strives to abide by generally accepted best practices in

 ### Community Supported Builds

 Build Type | Status | Artifacts
 ---------- | ------ | ---------
 **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA
 **Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
 **Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/)
 **Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
 **Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/)
 **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
-**Linux CPU with Intel® MKL-DNN** <br> **Supports Python 2.7, 3.4, 3.5 and 3.6** | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.13.1 pypi](https://pypi.org/project/intel-tensorflow/)
+**Linux CPU with Intel® MKL-DNN** <br> **Supports Python 2.7, 3.4, 3.5, and 3.6** | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.13.1 pypi](https://pypi.org/project/intel-tensorflow/)
+**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br> Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 pypi](https://tensorflow.pypi.thoth-station.ninja/index/)

 ## For more information


@@ -1,3 +1,10 @@
+# Release 1.12.2
+
+## Bug Fixes and Other Changes
+
+* Fixes a potential security vulnerability where carefully crafted GIF images can produce a null pointer dereference during decoding.
+
 # Release 1.13.0

 ## Major Features and Improvements
@@ -14,98 +21,185 @@
## Bug Fixes and Other Changes
* Documentation
    * Update the doc with the details about the rounding mode used in quantize_and_dequantize_v2.
    * Clarify that tensorflow::port::InitMain() _should_ be called before using the TensorFlow library. Programs failing to do this are not portable to all platforms.
* Deprecations and Symbol renames.
    * Removing deprecations for the following endpoints: `tf.acos`, `tf.acosh`, `tf.add`, `tf.as_string`, `tf.asin`, `tf.asinh`, `tf.atan`, `tf.atan2`, `tf.atanh`, `tf.cos`, `tf.cosh`, `tf.equal`, `tf.exp`, `tf.floor`, `tf.greater`, `tf.greater_equal`, `tf.less`, `tf.less_equal`, `tf.log`, `tf.log1p`, `tf.logical_and`, `tf.logical_not`, `tf.logical_or`, `tf.maximum`, `tf.minimum`, `tf.not_equal`, `tf.sin`, `tf.sinh`, `tf.tan`
    * Deprecate `tf.data.Dataset.shard`.
    * Deprecate `saved_model.loader.load` which is replaced by `saved_model.load` and `saved_model.main_op`, which will be replaced by `saved_model.main_op` in V2.
    * Deprecate `tf.QUANTIZED_DTYPES`. The official new symbol is `tf.dtypes.QUANTIZED_DTYPES`.
    * Update sklearn imports for deprecated packages.
    * Deprecate `Variable.count_up_to` and `tf.count_up_to` in favor of `Dataset.range`.
    * Export `confusion_matrix` op as `tf.math.confusion_matrix` instead of `tf.train.confusion_matrix`.
    * Add `tf.dtypes.` endpoint for every constant in dtypes.py; moving endpoints in versions.py to corresponding endpoints in `tf.sysconfig.` and `tf.version.`; moving all constants under `tf.saved_model` submodules to the `tf.saved_model` module. New endpoints are added in V1 and V2 but existing endpoint removals are only applied in V2.
    * Deprecates behavior where device assignment overrides collocation constraints inside a collocation context manager.
* Keras & Python API
    * Add to Keras functionality analogous to `tf.register_tensor_conversion_function`.
    * Subclassed Keras models can now be saved through `tf.contrib.saved_model.save_keras_model`.
    * `LinearOperator.matmul` now returns a new `LinearOperator`.
* New ops and improved op functionality
    * Add a Nearest Neighbor Resize op.
    * Add an `ignore_unknown` argument to `parse_values` which suppresses ValueError for unknown hyperparameter types. Such hyperparameters are ignored.
    * Add `tf.linalg.matvec` convenience function.
    * `tf.einsum()` raises `ValueError` for unsupported equations like `"ii->"`.
    * Add DCT-I and IDCT-I in `tf.signal.dct` and `tf.signal.idct`.
    * Add LU decomposition op.
    * Add quantile loss to gradient boosted trees in estimator.
    * Add `round_mode` to `QuantizeAndDequantizeV2` op to select rounding algorithm.
    * Add `unicode_encode`, `unicode_decode`, `unicode_decode_with_offsets`, `unicode_split`, `unicode_split_with_offset`, and `unicode_transcode` ops. Amongst other things, this adds the ability to encode, decode, and transcode a variety of input text encoding formats into the main Unicode encodings (UTF-8, UTF-16-BE, UTF-32-BE).
    * Add "unit" attribute to the substr op, which allows obtaining the substring of a string containing unicode characters.
    * Broadcasting support for Ragged Tensors.
    * `SpaceToDepth` supports uint8 data type.
    * Support multi-label quantile regression in estimator.
    * We now use "div" as the default partition_strategy in `tf.nn.safe_embedding_lookup_sparse`, `tf.nn.sampled_softmax` and `tf.nn.nce_loss`.
* Performance
    * Improve performance of GPU cumsum/cumprod by up to 300x.
    * Added support for weight decay in most TPU embedding optimizers, including AdamW and MomentumW.
* TensorFlow 2.0 Development
    * Add a command line tool to convert to TF2.0, tf_upgrade_v2.
    * Merge `tf.spectral` into `tf.signal` for TensorFlow 2.0.
    * Change the default recurrent activation function for LSTM from 'hard_sigmoid' to 'sigmoid' in 2.0. Historically the recurrent activation was 'hard_sigmoid' since it is faster than 'sigmoid'. With the new unified backend between CPU and GPU mode, and since the cuDNN kernel uses sigmoid, the default for CPU mode is changed to sigmoid as well. With that, the default LSTM is compatible with both the CPU and GPU kernels, so users with a GPU get the cuDNN kernel by default and a 10x performance boost in training. Note that this is a checkpoint-breaking change: to use a 1.x pre-trained checkpoint, construct the layer with LSTM(recurrent_activation='hard_sigmoid') to fall back to the 1.x behavior (see the sketch after this list).
* TensorFlow Lite
    * Move from `tensorflow/contrib/lite` to `tensorflow/lite`.
    * Add experimental Java API for injecting TensorFlow Lite delegates.
    * Add support for strings in TensorFlow Lite Java API.
* `tf.contrib`:
    * Add Apache Ignite Filesystem plugin to support accessing Apache IGFS.
    * Dropout now takes a `rate` argument; `keep_prob` is deprecated.
    * Estimator occurrences referencing `tf.contrib.estimator` were changed to `tf.estimator`:
        * `tf.contrib.estimator.BaselineEstimator` with `tf.estimator.BaselineEstimator`
        * `tf.contrib.estimator.DNNLinearCombinedEstimator` with `tf.estimator.DNNLinearCombinedEstimator`
        * `tf.contrib.estimator.DNNEstimator` with `tf.estimator.DNNEstimator`
        * `tf.contrib.estimator.LinearEstimator` with `tf.estimator.LinearEstimator`
        * `tf.contrib.estimator.InMemoryEvaluatorHook` with `tf.estimator.experimental.InMemoryEvaluatorHook`
        * `tf.contrib.estimator.make_stop_at_checkpoint_step_hook` with `tf.estimator.experimental.make_stop_at_checkpoint_step_hook`
    * Expose `tf.distribute.Strategy` as the new name for `tf.contrib.distribute.DistributionStrategy`.
    * Migrate linear optimizer from contrib to core.
    * Move `tf.contrib.signal` to `tf.signal` (preserving aliases in `tf.contrib.signal`).
    * Users of `tf.contrib.estimator.export_all_saved_models` and related should switch to `tf.estimator.Estimator.experimental_export_all_saved_models`.
* tf.data:
    * Add `tf.data.experimental.StatsOptions()` to configure options for collecting statistics from a `tf.data.Dataset` pipeline using `StatsAggregator`. Add the nested option `experimental_stats` (which takes a `tf.data.experimental.StatsOptions` object) to `tf.data.Options`. Deprecates `tf.data.experimental.set_stats_aggregator`.
    * Performance optimizations:
        * Add `tf.data.experimental.OptimizationOptions()` to configure options that enable `tf.data` performance optimizations. Add the nested option `experimental_optimization` (which takes a `tf.data.experimental.OptimizationOptions` object) to `tf.data.Options`. Remove performance optimization options from `tf.data.Options` and add them under `tf.data.experimental.OptimizationOptions` instead.
        * Enable `map_and_batch_fusion` and `noop_elimination` optimizations by default. They can be disabled by configuring `tf.data.experimental.OptimizationOptions` to set `map_and_batch = False` or `noop_elimination = False` respectively. To disable all default optimizations, set `apply_default_optimizations = False`.
        * Support parallel map in `map_and_filter_fusion`.
        * Disable static optimizations for input pipelines that use non-resource `tf.Variable`s.
    * Add NUMA-aware MapAndBatch dataset.
    * Deprecate `tf.data.Dataset.make_one_shot_iterator()` in V1, remove it from V2, and add `tf.compat.v1.data.make_one_shot_iterator()`.
    * Deprecate `tf.data.Dataset.make_initializable_iterator()` in V1, remove it from V2, and add `tf.compat.v1.data.make_initializable_iterator()`.
    * Enable nested dataset support in core `tf.data` transformations.
    * For `tf.data.Dataset` implementers: Added the `tf.data.Dataset._element_structure` property to replace `Dataset.output_{types,shapes,classes}`.
    * Make `num_parallel_calls` of `tf.data.Dataset.interleave` and `tf.data.Dataset.map` work in Eager mode.
* Toolchains
    * Fixed OpenSSL compatibility by avoiding `EVP_MD_CTX_destroy`.
    * Added bounds checking to printing deprecation warnings.
    * Upgraded CUDA dependency to 10.0.
    * To build with Android NDK r14b, add "#include <linux/compiler.h>" to android-ndk-r14b/platforms/android-14/arch-*/usr/include/linux/futex.h
    * Removed `:android_tensorflow_lib_selective_registration*` targets; use `:android_tensorflow_lib_lite*` targets instead.
* XLA
    * Move `RoundToEven` function to xla/client/lib/math.h.
    * A new environment variable `TF_XLA_DEBUG_OPTIONS_PASSTHROUGH`, set to "1" or "true", allows the debug options passed within an XRTCompile op to be passed directly to the XLA compilation backend. If the variable is not set (service side), only a restricted set is passed through.
    * Allow the XRTCompile op to return the ProgramShape resulting from the XLA compilation as a second return argument.
    * XLA HLO graphs can now be rendered as SVG/HTML.
* Estimator
    * Replace all occurrences of `tf.contrib.estimator.BaselineEstimator` with `tf.estimator.BaselineEstimator`.
    * Replace all occurrences of `tf.contrib.estimator.DNNLinearCombinedEstimator` with `tf.estimator.DNNLinearCombinedEstimator`.
    * Replace all occurrences of `tf.contrib.estimator.DNNEstimator` with `tf.estimator.DNNEstimator`.
    * Replace all occurrences of `tf.contrib.estimator.LinearEstimator` with `tf.estimator.LinearEstimator`.
    * Users of `tf.contrib.estimator.export_all_saved_models` and related should switch to `tf.estimator.Estimator.experimental_export_all_saved_models`.
    * Update `regression_head` to the new Head API for Canned Estimator V2.
    * Switch `multi_class_head` to Head API for Canned Estimator V2.
    * Replace all occurrences of `tf.contrib.estimator.InMemoryEvaluatorHook` and `tf.contrib.estimator.make_stop_at_checkpoint_step_hook` with `tf.estimator.experimental.InMemoryEvaluatorHook` and `tf.estimator.experimental.make_stop_at_checkpoint_step_hook`.
    * Migrate linear optimizer from contrib to core.
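The checkpoint note in the LSTM bullet above can be made concrete with a short, hedged sketch (illustrative only; it simply uses the public `tf.keras.layers.LSTM` constructor named in the release notes):

```python
import tensorflow as tf

# 2.0 default: recurrent_activation='sigmoid', matching the cuDNN kernel, so a
# GPU build can take the fast cuDNN path automatically.
lstm_default = tf.keras.layers.LSTM(64)

# Fallback described above for loading 1.x pre-trained checkpoints: explicitly
# request the historical default recurrent activation.
lstm_legacy = tf.keras.layers.LSTM(64, recurrent_activation='hard_sigmoid')
```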
## Thanks to our Contributors


@@ -43,8 +43,8 @@ remote_config_workspace()
 # Apple and Swift rules.
 http_archive(
     name = "build_bazel_rules_apple",
-    sha256 = "8f32e2839fba28d549e1670dbed83606dd339a9f7489118e481814d61738270f",
-    urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.14.0/rules_apple.0.14.0.tar.gz"],
+    sha256 = "23792cd999f97fc97284d1c44cb1324bfdd0bc54aa68ad513fa3705aca3b1f9e",
+    urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.15.0/rules_apple.0.15.0.tar.gz"],
 )  # https://github.com/bazelbuild/rules_apple/releases
 http_archive(
     name = "build_bazel_apple_support",
@@ -58,14 +58,14 @@ http_archive(
 )  # https://github.com/bazelbuild/bazel-skylib/releases
 http_archive(
     name = "build_bazel_rules_swift",
-    sha256 = "31aad005a9c4e56b256125844ad05eb27c88303502d74138186f9083479f93a6",
-    urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.8.0/rules_swift.0.8.0.tar.gz"],
+    sha256 = "9efe9699e9765e6b4a5e063e4a08f6b163cccaf0443f775d935baf5c3cd6ed0e",
+    urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.9.0/rules_swift.0.9.0.tar.gz"],
 )  # https://github.com/bazelbuild/rules_swift/releases
 http_archive(
     name = "com_github_apple_swift_swift_protobuf",
     type = "zip",
-    strip_prefix = "swift-protobuf-1.4.0/",
-    urls = ["https://github.com/apple/swift-protobuf/archive/1.4.0.zip"],
+    strip_prefix = "swift-protobuf-1.5.0/",
+    urls = ["https://github.com/apple/swift-protobuf/archive/1.5.0.zip"],
 )  # https://github.com/apple/swift-protobuf/releases
 http_file(
     name = "xctestrunner",


@@ -293,9 +293,9 @@ def get_var(environ_cp,
   Args:
     environ_cp: copy of the os.environ.
-    var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
-    query_item: string for feature related to the variable, e.g. "Hadoop File
-      System".
+    var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
+    query_item: string for feature related to the variable, e.g. "CUDA for
+      Nvidia GPUs".
     enabled_by_default: boolean for default behavior.
     question: optional string for how to ask for user input.
     yes_reply: optional string for reply when feature is enabled.
@@ -376,9 +376,9 @@ def set_build_var(environ_cp,
   Args:
     environ_cp: copy of the os.environ.
-    var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
-    query_item: string for feature related to the variable, e.g. "Hadoop File
-      System".
+    var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
+    query_item: string for feature related to the variable, e.g. "CUDA for
+      Nvidia GPUs".
     option_name: string for option to define in .bazelrc.
     enabled_by_default: boolean for default behavior.
     bazel_config_name: Name for Bazel --config argument to enable build feature.
@@ -411,9 +411,9 @@ def set_action_env_var(environ_cp,
   Args:
     environ_cp: copy of the os.environ.
-    var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
-    query_item: string for feature related to the variable, e.g. "Hadoop File
-      System".
+    var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
+    query_item: string for feature related to the variable, e.g. "CUDA for
+      Nvidia GPUs".
     enabled_by_default: boolean for default behavior.
     question: optional string for how to ask for user input.
     yes_reply: optional string for reply when feature is enabled.
@@ -456,8 +456,8 @@ def check_bazel_version(min_version, max_version):
   """Check installed bazel version is between min_version and max_version.

   Args:
-    min_version: string for minimum bazel version.
-    max_version: string for maximum bazel version.
+    min_version: string for minimum bazel version (must exist!).
+    max_version: string for maximum bazel version (must exist!).

   Returns:
     The bazel version detected.
@@ -570,7 +570,7 @@ def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var,
   Args:
     environ_cp: copy of the os.environ.
-    var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
+    var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
     ask_for_var: string for how to ask for user input.
     var_default: default value string.
@@ -1039,10 +1039,8 @@ def set_other_cuda_vars(environ_cp):
   # If CUDA is enabled, always use GPU during build and test.
   if environ_cp.get('TF_CUDA_CLANG') == '1':
     write_to_bazelrc('build --config=cuda_clang')
-    write_to_bazelrc('test --config=cuda_clang')
   else:
     write_to_bazelrc('build --config=cuda')
-    write_to_bazelrc('test --config=cuda')


 def set_host_cxx_compiler(environ_cp):
@@ -1261,7 +1259,8 @@ def set_windows_build_flags(environ_cp):
   write_to_bazelrc('build --copt=-w --host_copt=-w')
   # Fix winsock2.h conflicts
   write_to_bazelrc(
-      'build --copt=-DWIN32_LEAN_AND_MEAN --host_copt=-DWIN32_LEAN_AND_MEAN')
+      'build --copt=-DWIN32_LEAN_AND_MEAN --host_copt=-DWIN32_LEAN_AND_MEAN '
+      '--copt=-DNOGDI --host_copt=-DNOGDI')
   # Output more verbose information when something goes wrong
   write_to_bazelrc('build --verbose_failures')
   # The host and target platforms are the same in Windows build. So we don't
@@ -1324,9 +1323,9 @@ def validate_cuda_config(environ_cp):
   cuda_libraries = ['cuda', 'cudnn']
   if is_linux():
-    if 'TF_TENSORRT_VERSION' in environ_cp:  # if env variable exists
+    if int(environ_cp.get('TF_NEED_TENSORRT', False)):
       cuda_libraries.append('tensorrt')
-    if environ_cp.get('TF_NCCL_VERSION', None):  # if env variable not empty
+    if environ_cp.get('TF_NCCL_VERSION', None):
       cuda_libraries.append('nccl')

   proc = subprocess.Popen(
@@ -1387,7 +1386,7 @@ def main():
   # environment variables.
   environ_cp = dict(os.environ)

-  current_bazel_version = check_bazel_version('0.24.1', '0.25.0')
+  current_bazel_version = check_bazel_version('0.24.1', '0.25.2')
   _TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)

   reset_tf_configure_bazelrc()
@@ -1453,8 +1452,12 @@ def main():
     cuda_env_names = [
         'TF_CUDA_VERSION', 'TF_CUBLAS_VERSION', 'TF_CUDNN_VERSION',
         'TF_TENSORRT_VERSION', 'TF_NCCL_VERSION', 'TF_CUDA_PATHS',
-        'CUDA_TOOLKIT_PATH'
+        # Items below are for backwards compatibility when not using
+        # TF_CUDA_PATHS.
+        'CUDA_TOOLKIT_PATH', 'CUDNN_INSTALL_PATH', 'NCCL_INSTALL_PATH',
+        'NCCL_HDR_PATH', 'TENSORRT_INSTALL_PATH'
     ]
+    # Note: set_action_env_var above already writes to bazelrc.
     for name in cuda_env_names:
       if name in environ_cp:
         write_action_env_to_bazelrc(name, environ_cp[name])
@@ -1493,7 +1496,6 @@ def main():
       else:
         # Use downloaded LLD for linking.
         write_to_bazelrc('build:cuda_clang --config=download_clang_use_lld')
-        write_to_bazelrc('test:cuda_clang --config=download_clang_use_lld')
     else:
       # Set up which gcc nvcc should use as the host compiler
       # No need to set this on Windows
@@ -1506,7 +1508,6 @@ def main():
     set_tf_download_clang(environ_cp)
     if environ_cp.get('TF_DOWNLOAD_CLANG') == '1':
       write_to_bazelrc('build --config=download_clang')
-      write_to_bazelrc('test --config=download_clang')

   # SYCL / ROCm / CUDA are mutually exclusive.
   # At most 1 GPU platform can be configured.
@@ -1546,12 +1547,6 @@ def main():
   set_action_env_var(environ_cp, 'TF_CONFIGURE_IOS', 'iOS', False)
   if environ_cp.get('TF_CONFIGURE_IOS') == '1':
     configure_ios()
-  else:
-    # TODO(pcloudy): Remove BAZEL_USE_CPP_ONLY_TOOLCHAIN after Bazel is upgraded
-    # to 0.24.0.
-    # For working around https://github.com/bazelbuild/bazel/issues/7607
-    if is_macos():
-      write_to_bazelrc('build --action_env=BAZEL_USE_CPP_ONLY_TOOLCHAIN=1')

   print('Preconfigured Bazel build configs. You can use any of the below by '
         'adding "--config=<>" to your build command. See .bazelrc for more '


@@ -184,6 +184,12 @@ config_setting(
     visibility = ["//visibility:public"],
 )

+config_setting(
+    name = "linux_aarch64",
+    values = {"cpu": "aarch64"},
+    visibility = ["//visibility:public"],
+)
+
 config_setting(
     name = "linux_x86_64",
     values = {"cpu": "k8"},
@@ -420,6 +426,9 @@ config_setting(
     values = {"cpu": "x64_windows"},
 )

+# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
+# Instead, please use public APIs or public build rules TF provides.
+# If you need functionality that is not exposed, we will work with you to expand our public APIs.
 package_group(
     name = "internal",
     packages = [


@@ -32,10 +32,13 @@ from __future__ import print_function as _print_function

 import distutils as _distutils
 import inspect as _inspect
+import logging as _logging
 import os as _os
 import site as _site
 import sys as _sys

+from tensorflow.python.tools import module_util as _module_util
+
 # API IMPORTS PLACEHOLDER

 # Make sure directory containing top level submodules is in
@@ -49,25 +52,29 @@ if not hasattr(_current_module, '__path__'):
 elif _tf_api_dir not in __path__:
   __path__.append(_tf_api_dir)

-# pylint: disable=g-bad-import-order
-from tensorflow.python.tools import component_api_helper as _component_api_helper
-_component_api_helper.package_hook(
-    parent_package_str=__name__,
-    child_package_str=('tensorboard.summary._tf.summary'),
-    error_msg="Limited tf.summary API due to missing TensorBoard installation")
-_component_api_helper.package_hook(
-    parent_package_str=__name__,
-    child_package_str=(
-        'tensorflow_estimator.python.estimator.api._v2.estimator'))
-if not hasattr(_current_module, 'estimator'):
-  _component_api_helper.package_hook(
-      parent_package_str=__name__,
-      child_package_str=(
-          'tensorflow_estimator.python.estimator.api.estimator'))
-_component_api_helper.package_hook(
-    parent_package_str=__name__,
-    child_package_str=('tensorflow.python.keras.api._v2.keras'))
+# Hook external TensorFlow modules.
+try:
+  from tensorboard.summary._tf import summary
+  _current_module.__path__ = (
+      [_module_util.get_parent_dir(summary)] + _current_module.__path__)
+except ImportError:
+  _logging.warning(
+      "Limited tf.summary API due to missing TensorBoard installation.")
+
+try:
+  from tensorflow_estimator.python.estimator.api._v2 import estimator
+  _current_module.__path__ = (
+      [_module_util.get_parent_dir(estimator)] + _current_module.__path__)
+except ImportError:
+  pass
+
+try:
+  from tensorflow.python.keras.api._v2 import keras
+  _current_module.__path__ = (
+      [_module_util.get_parent_dir(keras)] + _current_module.__path__)
+except ImportError:
+  pass

 # Enable TF2 behaviors
 from tensorflow.python.compat import v2_compat as _compat  # pylint: disable=g-import-not-at-top


@@ -26,24 +26,37 @@ import sys as _sys

 # pylint: disable=g-bad-import-order
 from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
+from tensorflow.python.tools import module_util as _module_util

 # API IMPORTS PLACEHOLDER

-from tensorflow.python.tools import component_api_helper as _component_api_helper
-_component_api_helper.package_hook(
-    parent_package_str=__name__,
-    child_package_str=(
-        'tensorflow_estimator.python.estimator.api._v1.estimator'))
+# Make sure directory containing top level submodules is in
+# the __path__ so that "from tensorflow.foo import bar" works.
+# We're using bitwise, but there's nothing special about that.
+_API_MODULE = bitwise  # pylint: disable=undefined-variable
 _current_module = _sys.modules[__name__]
-if not hasattr(_current_module, 'estimator'):
-  _component_api_helper.package_hook(
-      parent_package_str=__name__,
-      child_package_str=(
-          'tensorflow_estimator.python.estimator.api.estimator'))
-_component_api_helper.package_hook(
-    parent_package_str=__name__,
-    child_package_str=('tensorflow.python.keras.api._v1.keras'))
+_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
+if not hasattr(_current_module, '__path__'):
+  __path__ = [_tf_api_dir]
+elif _tf_api_dir not in __path__:
+  __path__.append(_tf_api_dir)
+
+# Hook external TensorFlow modules.
+try:
+  from tensorflow_estimator.python.estimator.api._v1 import estimator
+  _current_module.__path__ = (
+      [_module_util.get_parent_dir(estimator)] + _current_module.__path__)
+except ImportError:
+  pass
+
+try:
+  from tensorflow.python.keras.api._v1 import keras
+  _current_module.__path__ = (
+      [_module_util.get_parent_dir(keras)] + _current_module.__path__)
+except ImportError:
+  pass

 from tensorflow.python.util.lazy_loader import LazyLoader  # pylint: disable=g-import-not-at-top
 _CONTRIB_WARNING = """
 WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.
@@ -66,17 +79,6 @@ from tensorflow.python.platform import flags  # pylint: disable=g-import-not-at-top
 # The 'app' module will be imported as part of the placeholder section above.
 app.flags = flags  # pylint: disable=undefined-variable

-# Also use 'app' module (choice is arbitrary) to derive the API directory below.
-_API_MODULE = app  # pylint: disable=undefined-variable
-
-# Make sure directory containing top level submodules is in
-# the __path__ so that "from tensorflow.foo import bar" works.
-_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
-if not hasattr(_current_module, '__path__'):
-  __path__ = [_tf_api_dir]
-elif _tf_api_dir not in __path__:
-  __path__.append(_tf_api_dir)
-
 # Load all plugin libraries from site-packages/tensorflow-plugins if we are
 # running under pip.
 # TODO(gunan): Enable setting an environment variable to define arbitrary plugin


@@ -104,6 +104,7 @@ tf_cuda_library(
             "//tensorflow/core:android_tensorflow_lib_lite",
         ],
         "//conditions:default": [
+            "@com_google_absl//absl/strings",
             "//tensorflow/cc/saved_model:loader_lite",
            "//tensorflow/cc:gradients",
             "//tensorflow/cc:ops",
@@ -145,6 +146,7 @@ tf_cuda_library(
         "//tensorflow/core:lib_platform",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/common_runtime/eager:attr_builder",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -243,6 +245,28 @@ tf_cuda_library(
     }),
 )

+tf_cuda_library(
+    name = "ops",
+    srcs = [
+        "ops.cc",
+    ],
+    hdrs = [
+        "ops.h",
+    ],
+    copts = tf_copts(),
+    visibility = ["//visibility:public"],
+    deps = [
+        ":tf_status_helper",
+    ] + select({
+        "//tensorflow:android": [
+            "//tensorflow/core:android_tensorflow_lib_lite",
+        ],
+        "//conditions:default": [
+            "//tensorflow/core:framework",
+        ],
+    }) + [":c_api_internal"],
+)
+
 # -----------------------------------------------------------------------------
 # Tests

@@ -291,7 +315,6 @@ tf_cuda_cc_test(
         "//conditions:default": [],
     }),
     tags = [
-        "no_oss",  # http://b/119522529
         "noasan",
     ],
     # We must ensure that the dependencies can be dynamically linked since
@@ -445,6 +468,27 @@ tf_cuda_cc_test(
     ],
 )

+tf_cc_test(
+    name = "ops_test",
+    size = "small",
+    srcs = ["ops_test.cc"],
+    linkopts = select({
+        "//conditions:default": [],
+    }),
+    tags = ["noasan"],
+    # We must ensure that the dependencies can be dynamically linked since
+    # the shared library must be able to use core:framework.
+    # linkstatic = tf_kernel_tests_linkstatic(),
+    deps = [
+        ":c_api",
+        ":ops",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 # -----------------------------------------------------------------------------
 # Python API target


@@ -30,8 +30,8 @@ limitations under the License.
 #include "tensorflow/cc/ops/while_loop.h"
 #include "tensorflow/cc/saved_model/loader.h"
 #include "tensorflow/core/distributed_runtime/server_lib.h"
+#include "tensorflow/core/framework/logging.h"
 #include "tensorflow/core/framework/op_gen_lib.h"
-#include "tensorflow/core/kernels/logging_ops.h"
 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)

 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"


@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/node_builder.h"
@@ -66,6 +67,24 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
   }
 }

+unsigned char TF_SetXlaEnableLazyCompilation(unsigned char enable) {
+  tensorflow::BuildXlaOpsPassFlags* flags =
+      tensorflow::GetBuildXlaOpsPassFlags();
+  bool original = flags->tf_xla_enable_lazy_compilation;
+  flags->tf_xla_enable_lazy_compilation = enable;
+  return original;
+}
+
+void TF_SetXLaAutoJitMode(const char* mode) {
+  tensorflow::SetXlaAutoJitFlagFromFlagString(mode);
+}
+
+void TF_SetXlaMinClusterSize(int size) {
+  tensorflow::MarkForCompilationPassFlags* flags =
+      tensorflow::GetMarkForCompilationPassFlags();
+  flags->tf_xla_min_cluster_size = size;
+}
+
 TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
                            unsigned char gpu_memory_allow_growth,
                            unsigned int num_cpu_devices) {
@@ -676,7 +695,7 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
   LOG_AND_RETURN_IF_ERROR(grpc_server->Start());

-  LOG_AND_RETURN_IF_ERROR(ctx->context.StoreCollectiveOpsServer(
+  LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
       std::move(server), grpc_server->worker_env()->device_mgr,
       grpc_server->worker_env()->collective_executor_mgr));


@@ -62,6 +62,20 @@ extern "C" {
 TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options,
                                                    unsigned char enable);

+// Set XLA's internal BuildXlaOpsPassFlags.tf_xla_enable_lazy_compilation to the
+// value of 'enabled'. Also returns the original value of that flag.
+//
+// Use in tests to allow XLA to fallback to TF classic. This has global effect.
+TF_CAPI_EXPORT unsigned char TF_SetXlaEnableLazyCompilation(
+    unsigned char enable);
+
+// Sets XLA's auto jit mode according to the specified string, which is parsed
+// as if passed in XLA_FLAGS. This has global effect.
+TF_CAPI_EXPORT void TF_SetXLaAutoJitMode(const char* mode);
+
+// Sets XLA's minimum cluster size. This has global effect.
+TF_CAPI_EXPORT void TF_SetXlaMinClusterSize(int size);
+
 // Create a serialized tensorflow.ConfigProto proto, where:
 //
 // a) ConfigProto.optimizer_options.global_jit_level is set to to ON_1 if


@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include "tensorflow/c/c_api_internal.h"
-
 #include <algorithm>
 #include <unordered_map>
 #include <unordered_set>

+#include "absl/strings/match.h"
+#include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
@@ -295,7 +295,8 @@ Status FillFunctionBody(
 }

 // Graph to FunctionDef conversion. This code is closely modeled on the Python
-// code in tensorflow/python/framework/function.py.
+// function graph_to_function_def(), which is located in
+// tensorflow/python/framework/graph_to_function_def.py.
 Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
                           bool append_hash_to_fn_name,
                           const std::vector<const Node*>& body_nodes,
@@ -352,6 +353,16 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
     argdef->set_type(node->output_type(idx));
     const string& input_name = node_names.GetInputName(node->name());
     argdef->set_name(input_name);
+    auto& arg_attrs = (*fdef->mutable_arg_attr())[i];
+    for (const auto& attr : node->attrs()) {
+      // Only copy internal attributes. These attributes will be applied to
+      // _Arg/Placeholder nodes when this FunctionDef is converted to graph, and
+      // normal attributes for nodes cannot be applied to those _Arg/Placeholder
+      // nodes.
+      if (absl::StartsWith(attr.first, "_")) {
+        arg_attrs.mutable_attr()->insert(attr);
+      }
+    }
     tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
   }


@@ -1278,6 +1278,46 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
   EXPECT_EQ(func_->fdef.signature().attr(1).type(), "int");
 }

+void NodeWithAttrHelper(TF_Graph* graph, TF_Status* s, const char* name,
+                        const char* attr_name, const char* attr_value,
+                        TF_Operation** op) {
+  TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
+  TF_SetAttrType(desc, "dtype", TF_INT32);
+  TF_SetAttrString(desc, attr_name, attr_value, strlen(attr_value));
+  *op = TF_FinishOperation(desc, s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  ASSERT_NE(*op, nullptr);
+}
+
+TEST_F(CApiFunctionTest, GraphToFunctionDefWithArgAttr) {
+  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
+      TF_NewGraph(), TF_DeleteGraph);
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
+                                                           TF_DeleteStatus);
+
+  TF_Operation* node;
+  NodeWithAttrHelper(func_graph.get(), s.get(), "node", "_test_attr", "value",
+                     &node);
+
+  TF_Output inputs[] = {{node, 0}};
+  TF_Output outputs[] = {};
+  func_ = TF_GraphToFunction(
+      func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
+      /*opers=*/nullptr, 1, inputs, 0, outputs,
+      /*output_names=*/nullptr,
+      /*opts=*/nullptr, /*description=*/nullptr, s.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
+  ASSERT_NE(func_, nullptr);
+
+  // Verify that FunctionDef ArgDef has attributes.
+  ASSERT_EQ(func_->fdef.arg_attr_size(), 1);
+  auto arg_attrs = func_->fdef.arg_attr().find(0);
+  ASSERT_NE(arg_attrs, func_->fdef.arg_attr().end());
+  auto iter = arg_attrs->second.attr().find("_test_attr");
+  ASSERT_NE(iter, arg_attrs->second.attr().end());
+  EXPECT_EQ(iter->second.s(), "value");
+}
+
 TEST_F(CApiFunctionTest, SetGradientAndRun) {
   // Define the function and its grad
   DefineFunction(func_name_, &func_);

View File

@ -24,8 +24,10 @@ limitations under the License.
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
// clang-format off
// Required for IS_MOBILE_PLATFORM // Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h" // NO_LINT #include "tensorflow/core/platform/platform.h"
// clang-format on
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/framework/op_gen_lib.h"

View File

@ -29,8 +29,7 @@ namespace checkpoint {
class TensorSliceReader; class TensorSliceReader;
CheckpointReader::CheckpointReader(const string& filename, CheckpointReader::CheckpointReader(const string& filename, TF_Status* status)
TF_Status* out_status)
: reader_(nullptr), : reader_(nullptr),
v2_reader_(nullptr), v2_reader_(nullptr),
var_to_shape_map_(nullptr), var_to_shape_map_(nullptr),
@ -43,7 +42,7 @@ CheckpointReader::CheckpointReader(const string& filename,
v2_reader_.reset( v2_reader_.reset(
new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */)); new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */));
if (!v2_reader_->status().ok()) { if (!v2_reader_->status().ok()) {
Set_TF_Status_from_Status(out_status, v2_reader_->status()); Set_TF_Status_from_Status(status, v2_reader_->status());
return; return;
} }
auto result = BuildV2VarMaps(); auto result = BuildV2VarMaps();
@ -52,7 +51,7 @@ CheckpointReader::CheckpointReader(const string& filename,
} else { } else {
reader_.reset(new TensorSliceReader(filename)); reader_.reset(new TensorSliceReader(filename));
if (!reader_->status().ok()) { if (!reader_->status().ok()) {
Set_TF_Status_from_Status(out_status, reader_->status()); Set_TF_Status_from_Status(status, reader_->status());
return; return;
} }
var_to_shape_map_.reset( var_to_shape_map_.reset(

View File

@ -39,7 +39,7 @@ class TensorSliceReader;
// variables. // variables.
class CheckpointReader { class CheckpointReader {
public: public:
CheckpointReader(const string& filepattern, TF_Status* out_status); CheckpointReader(const string& filename, TF_Status* status);
bool HasTensor(const string& name) const; bool HasTensor(const string& name) const;
const string DebugString() const; const string DebugString() const;
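Hedged usage sketch for the renamed parameter; the header path, namespace, and the "global_step" tensor name are assumptions for illustration:

#include <cstdio>
#include <string>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/checkpoint_reader.h"

// Prints the variable listing of a checkpoint if it can be opened.
void InspectCheckpoint(const std::string& prefix) {
  TF_Status* status = TF_NewStatus();
  tensorflow::checkpoint::CheckpointReader reader(prefix, status);
  if (TF_GetCode(status) == TF_OK) {
    // DebugString() lists every variable name, dtype and shape.
    std::printf("%s\n", reader.DebugString().c_str());
    // HasTensor() checks for a single entry, e.g. the (assumed) global step.
    if (reader.HasTensor("global_step")) std::printf("has global_step\n");
  }
  TF_DeleteStatus(status);
}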

View File

@ -1,4 +1,5 @@
# Experimental extensions to the C API for eager execution of kernels. # Experimental extensions to the C API for eager execution of kernels.
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
load( load(
@ -258,3 +259,22 @@ filegroup(
srcs = ["c_api.h"], srcs = ["c_api.h"],
visibility = ["//tensorflow:__subpackages__"], visibility = ["//tensorflow:__subpackages__"],
) )
# TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime
# right now, remove this public rule when no longer needed (it should be
# replaced by TF Lite)
filegroup(
name = "srcs",
srcs = glob(
[
"*.cc",
"*.h",
],
exclude = [
"c_api_experimental.cc",
"c_api_experimental.h",
"*test*",
],
),
visibility = ["//visibility:public"],
)

75
tensorflow/c/eager/c_api.cc Executable file → Normal file
View File

@ -21,6 +21,11 @@ limitations under the License.
#include <string> #include <string>
#include <vector> #include <vector>
// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_api_internal.h"
@ -38,11 +43,15 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/distributed_runtime/worker_env.h"
#endif // !IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h"
@ -88,6 +97,7 @@ string DeviceName(const tensorflow::Device* d) {
return (d == nullptr) ? "cpu:0" : d->name(); return (d == nullptr) ? "cpu:0" : d->name();
} }
#if !defined(IS_MOBILE_PLATFORM)
tensorflow::Status GetAllRemoteDevices( tensorflow::Status GetAllRemoteDevices(
const std::vector<string>& remote_workers, const std::vector<string>& remote_workers,
tensorflow::WorkerCacheInterface* worker_cache, tensorflow::WorkerCacheInterface* worker_cache,
@ -220,7 +230,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts; tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
remote_workers, rendezvous_id, keep_alive_secs, server_def, remote_workers, rendezvous_id, keep_alive_secs, server_def,
remote_eager_workers.get(), ctx->context.Async(), &remote_contexts)); remote_eager_workers.get(), ctx->context->Async(), &remote_contexts));
tensorflow::RemoteRendezvous* r = tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id); grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
@ -239,12 +249,13 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
auto* device_mgr = grpc_server->worker_env()->device_mgr; auto* device_mgr = grpc_server->worker_env()->device_mgr;
return ctx->context.InitializeRemote( return ctx->context->InitializeRemote(
std::move(server), std::move(remote_eager_workers), std::move(server), std::move(remote_eager_workers),
std::move(remote_device_mgr), remote_contexts, r, device_mgr, std::move(remote_device_mgr), remote_contexts, r, device_mgr,
keep_alive_secs); keep_alive_secs);
#undef LOG_AND_RETURN_IF_ERROR #undef LOG_AND_RETURN_IF_ERROR
} }
#endif // !IS_MOBILE_PLATFORM
tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op, tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,
TFE_TensorHandle* input) { TFE_TensorHandle* input) {
@ -341,7 +352,7 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
unsigned char enable, unsigned char enable,
TF_Status* status) { TF_Status* status) {
status->status = ctx->context.SetAsyncForThread(enable); status->status = ctx->context->SetAsyncForThread(enable);
} }
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
@ -381,16 +392,14 @@ void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* list = new TF_DeviceList; TF_DeviceList* list = new TF_DeviceList;
ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response); ctx->context->local_device_mgr()->ListDeviceAttributes(&list->response);
if (ctx->context.remote_device_mgr()) { if (ctx->context->remote_device_mgr()) {
ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response); ctx->context->remote_device_mgr()->ListDeviceAttributes(&list->response);
} }
return list; return list;
} }
void TFE_ContextClearCaches(TFE_Context* ctx, TF_Status* status) { void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context->ClearCaches(); }
status->status = ctx->context.ClearCaches();
}
// Set server_def on the context, possibly updating it. // Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
@ -398,6 +407,10 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
const void* proto, const void* proto,
size_t proto_len, size_t proto_len,
TF_Status* status) { TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
status->status = tensorflow::errors::Unimplemented(
"TFE_ContextSetServerDef not supported on mobile");
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::ServerDef server_def; tensorflow::ServerDef server_def;
if (!server_def.ParseFromArray(proto, proto_len)) { if (!server_def.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument( status->status = tensorflow::errors::InvalidArgument(
@ -406,11 +419,12 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
} }
status->status = status->status =
UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx); UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx);
#endif // !IS_MOBILE_PLATFORM
} }
void TFE_ContextSetThreadLocalDevicePlacementPolicy( void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
ctx->context.SetThreadLocalDevicePlacementPolicy( ctx->context->SetThreadLocalDevicePlacementPolicy(
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy)); static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
} }
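A hedged caller-side sketch of the new mobile behaviour; the serialized ServerDef bytes are assumed to be produced elsewhere:

#include <stddef.h>

#include "tensorflow/c/eager/c_api.h"

// Returns true if remote execution was enabled; on mobile builds the call now
// reports TF_UNIMPLEMENTED instead of pulling in the distributed runtime.
bool MaybeEnableRemoteExecution(TFE_Context* ctx, const void* server_def_proto,
                                size_t proto_len) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextSetServerDef(ctx, /*keep_alive_secs=*/0, server_def_proto,
                          proto_len, status);
  const bool ok = TF_GetCode(status) == TF_OK;
  TF_DeleteStatus(status);
  return ok;
}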
@ -420,19 +434,19 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
TFE_Context* ctx) { TFE_Context* ctx) {
return static_cast<TFE_ContextDevicePlacementPolicy>( return static_cast<TFE_ContextDevicePlacementPolicy>(
ctx->context.GetDevicePlacementPolicy()); ctx->context->GetDevicePlacementPolicy());
} }
void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) { void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) {
status->status = ctx->context.AsyncWait(); status->status = ctx->context->AsyncWait();
} }
void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) { void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) {
status->status = ctx->context.GetStatus(); status->status = ctx->context->GetStatus();
} }
void TFE_ContextAsyncClearError(TFE_Context* ctx) { void TFE_ContextAsyncClearError(TFE_Context* ctx) {
ctx->context.ClearAsyncError(); ctx->context->ClearAsyncError();
} }
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
@ -592,7 +606,7 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
return new TFE_Op(ctx, name, false, types, return new TFE_Op(ctx, name, false, types,
new TFE_OpInferenceContext(op_def)); new TFE_OpInferenceContext(op_def));
} }
if (!ctx->context.FindFunctionByName(name)) { if (!ctx->context->FindFunctionByName(name)) {
status->status = tensorflow::errors::NotFound( status->status = tensorflow::errors::NotFound(
"'", name, "'", name,
"' is neither a type of a primitive operation nor a name " "' is neither a type of a primitive operation nor a name "
@ -890,7 +904,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
const char* device_name, const char* device_name,
TF_Status* status) { TF_Status* status) {
tensorflow::TensorHandle* handle; tensorflow::TensorHandle* handle;
status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context, status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
device_name, &handle); device_name, &handle);
if (status->status.ok()) { if (status->status.ok()) {
return new TFE_TensorHandle(handle); return new TFE_TensorHandle(handle);
@ -907,26 +921,31 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto"); tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
return; return;
} }
status->status = ctx->context.AddFunctionDef(function_def); status->status = ctx->context->AddFunctionDef(function_def);
} }
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) { TF_Status* status) {
status->status = ctx->context.AddFunctionDef(function->fdef); status->status = ctx->context->AddFunctionDef(function->fdef);
}
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
TF_Status* status) {
status->status = ctx->context->RemoveFunction(name);
} }
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
return ctx->context.FindFunctionDef(name) != nullptr; return ctx->context->FindFunctionDef(name) != nullptr;
} }
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
ctx->context.SetShouldStoreGraphs(true); ctx->context->SetShouldStoreGraphs(true);
ctx->context.SetShouldStoreStepStats(true); ctx->context->SetShouldStoreStepStats(true);
} }
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
ctx->context.SetShouldStoreGraphs(false); ctx->context->SetShouldStoreGraphs(false);
ctx->context.SetShouldStoreStepStats(false); ctx->context->SetShouldStoreStepStats(false);
} }
} // extern "C" } // extern "C"
@ -955,9 +974,9 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) { TF_Status* status) {
TFE_ContextAsyncWait(ctx, status); TFE_ContextAsyncWait(ctx, status);
if (!status->status.ok()) return; if (!status->status.ok()) return;
tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); tensorflow::mutex_lock ml(*ctx->context->MetadataMu());
status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf); status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf);
ctx->context.ClearRunMetadata(); ctx->context->ClearRunMetadata();
} }
namespace { namespace {
@ -973,9 +992,9 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
} }
} // namespace } // namespace
void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); } void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); } void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
namespace tensorflow { namespace tensorflow {
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,

View File

@ -98,8 +98,7 @@ TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
// Clears the internal caches in the TFE context. Useful when reseeding random // Clears the internal caches in the TFE context. Useful when reseeding random
// ops. // ops.
TF_CAPI_EXPORT extern void TFE_ContextClearCaches(TFE_Context* ctx, TF_CAPI_EXPORT extern void TFE_ContextClearCaches(TFE_Context* ctx);
TF_Status* status);
// Sets a thread-local device placement policy. After this call, other calls to // Sets a thread-local device placement policy. After this call, other calls to
// TFE_Execute in the same thread will use the device policy specified here // TFE_Execute in the same thread will use the device policy specified here
@ -411,6 +410,13 @@ TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx,
TF_Function* function, TF_Function* function,
TF_Status* status); TF_Status* status);
// Removes a function from the context. Once removed, the function can no
// longer be executed with TFE_Execute, either directly or through any TFE_Op
// or other function that references it as an attribute.
TF_CAPI_EXPORT extern void TFE_ContextRemoveFunction(TFE_Context* ctx,
const char* name,
TF_Status* status);
// Checks whether a function is registered under `name`. // Checks whether a function is registered under `name`.
TF_CAPI_EXPORT unsigned char TFE_ContextHasFunction(TFE_Context* ctx, TF_CAPI_EXPORT unsigned char TFE_ContextHasFunction(TFE_Context* ctx,
const char* name); const char* name);
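Hedged sketch of the add/has/remove lifecycle these declarations describe (building the TF_Function itself is elided):

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"

// Registers `fn` under `fn_name`, uses it, then removes it from the context.
void RegisterUseAndDrop(TFE_Context* ctx, TF_Function* fn,
                        const char* fn_name) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextAddFunction(ctx, fn, status);
  if (TF_GetCode(status) == TF_OK && TFE_ContextHasFunction(ctx, fn_name)) {
    // ... TFE_Execute ops that call `fn_name` here ...
    // Once nothing still needs it, the function can be removed again.
    TFE_ContextRemoveFunction(ctx, fn_name, status);
  }
  TF_DeleteStatus(status);
}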

View File

@ -63,7 +63,7 @@ TFE_ProfilerContext* TFE_NewProfilerContext() {
void TFE_ProfilerContextSetEagerContext(TFE_ProfilerContext* profiler_context, void TFE_ProfilerContextSetEagerContext(TFE_ProfilerContext* profiler_context,
TFE_Context* eager_context) { TFE_Context* eager_context) {
profiler_context->profiler_context.eager_context = &eager_context->context; profiler_context->profiler_context.eager_context = eager_context->context;
} }
void TFE_DeleteProfilerContext(TFE_ProfilerContext* profiler_context) { void TFE_DeleteProfilerContext(TFE_ProfilerContext* profiler_context) {
@ -77,11 +77,11 @@ void TFE_StartProfilerServer(TFE_ProfilerContext* context, int port) {
} }
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
ctx->context.SetShouldStoreGraphs(true); ctx->context->SetShouldStoreGraphs(true);
} }
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
ctx->context.SetShouldStoreGraphs(false); ctx->context->SetShouldStoreGraphs(false);
} }
bool TFE_ProfilerClientStartTracing(const char* service_addr, bool TFE_ProfilerClientStartTracing(const char* service_addr,
@ -99,59 +99,6 @@ bool TFE_ProfilerClientStartTracing(const char* service_addr,
return s.ok(); return s.ok();
} }
static tensorflow::mutex gauges_map_lock(tensorflow::LINKER_INITIALIZED);
static std::unordered_map<string,
tensorflow::monitoring::Gauge<tensorflow::int64, 1>*>*
get_gauges_map() EXCLUSIVE_LOCKS_REQUIRED(gauges_map_lock) {
static std::unordered_map<
string, tensorflow::monitoring::Gauge<tensorflow::int64, 1>*>*
gauges_map = new std::unordered_map<
string, tensorflow::monitoring::Gauge<tensorflow::int64, 1>*>;
return gauges_map;
}
static tensorflow::mutex samplers_map_lock(tensorflow::LINKER_INITIALIZED);
static std::unordered_map<string, tensorflow::monitoring::Sampler<1>*>*
get_samplers_map() EXCLUSIVE_LOCKS_REQUIRED(samplers_map_lock) {
static std::unordered_map<string, tensorflow::monitoring::Sampler<1>*>*
samplers_map =
new std::unordered_map<string, tensorflow::monitoring::Sampler<1>*>;
return samplers_map;
}
void TFE_MonitoringSetGauge(const char* name, const char* label,
int64_t value) {
tensorflow::mutex_lock l(gauges_map_lock);
auto gauges_map = get_gauges_map();
if (gauges_map->find(name) == gauges_map->end()) {
gauges_map->emplace(
name, tensorflow::monitoring::Gauge<tensorflow::int64, 1>::New(
name,
tensorflow::strings::StrCat(
name, " :Gauge metric collected from Python API."),
"metric_descriptor"));
}
gauges_map->at(name)->GetCell(label)->Set(value);
}
void TFE_MonitoringAddSampler(const char* name, const char* label,
double value) {
tensorflow::mutex_lock l(samplers_map_lock);
auto samplers_map = get_samplers_map();
if (samplers_map->find(name) == samplers_map->end()) {
samplers_map->emplace(
name, tensorflow::monitoring::Sampler<1>::New(
{name,
tensorflow::strings::StrCat(
name, " :Counter metric collected from Python API."),
"metric_descriptor"},
{tensorflow::monitoring::Buckets::Exponential(1, 2, 30)}));
}
samplers_map->at(name)->GetCell(label)->Add(value);
}
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell, void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
int64_t value) { int64_t value) {
cell->cell.IncrementBy(value); cell->cell.IncrementBy(value);
@ -166,6 +113,10 @@ TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(const char* name,
const char* description) { const char* description) {
auto* result = new TFE_MonitoringCounter0({name, description}); auto* result = new TFE_MonitoringCounter0({name, description});
Set_TF_Status_from_Status(status, result->counter->GetStatus()); Set_TF_Status_from_Status(status, result->counter->GetStatus());
if (!result->counter->GetStatus().ok()) {
delete result;
return nullptr;
}
return result; return result;
} }
@ -185,6 +136,10 @@ TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(const char* name,
const char* label1) { const char* label1) {
auto* result = new TFE_MonitoringCounter1({name, description, label1}); auto* result = new TFE_MonitoringCounter1({name, description, label1});
Set_TF_Status_from_Status(status, result->counter->GetStatus()); Set_TF_Status_from_Status(status, result->counter->GetStatus());
if (!result->counter->GetStatus().ok()) {
delete result;
return nullptr;
}
return result; return result;
} }
@ -206,6 +161,10 @@ TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(const char* name,
auto* result = auto* result =
new TFE_MonitoringCounter2({name, description, label1, label2}); new TFE_MonitoringCounter2({name, description, label1, label2});
Set_TF_Status_from_Status(status, result->counter->GetStatus()); Set_TF_Status_from_Status(status, result->counter->GetStatus());
if (!result->counter->GetStatus().ok()) {
delete result;
return nullptr;
}
return result; return result;
} }
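With this change the counter constructors hand back nullptr when registration fails (for example, a name collision), so callers are expected to check the result; a hedged sketch with a made-up metric name:

#include "tensorflow/c/eager/c_api_experimental.h"

// Bumps an (illustrative) counter, tolerating failed registration.
void CountExample() {
  TF_Status* status = TF_NewStatus();
  TFE_MonitoringCounter0* counter = TFE_MonitoringNewCounter0(
      "/example/events", status, "Counts example events.");
  if (counter != nullptr) {  // nullptr when GetStatus() was not ok.
    TFE_MonitoringCounterCellIncrementBy(
        TFE_MonitoringGetCellCounter0(counter), 1);
    TFE_MonitoringDeleteCounter0(counter);
  }
  TF_DeleteStatus(status);
}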
@ -218,3 +177,344 @@ TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
return static_cast<TFE_MonitoringCounterCell*>( return static_cast<TFE_MonitoringCounterCell*>(
static_cast<void*>(counter->counter->GetCell(label1, label2))); static_cast<void*>(counter->counter->GetCell(label1, label2)));
} }
void TFE_MonitoringIntGaugeCellSet(TFE_MonitoringIntGaugeCell* cell,
int64_t value) {
cell->cell.Set(value);
}
int64_t TFE_MonitoringIntGaugeCellValue(TFE_MonitoringIntGaugeCell* cell) {
return cell->cell.value();
}
TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(const char* name,
TF_Status* status,
const char* description) {
auto* result = new TFE_MonitoringIntGauge0({name, description});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteIntGauge0(TFE_MonitoringIntGauge0* gauge) {
delete gauge;
}
TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge0(
TFE_MonitoringIntGauge0* gauge) {
return static_cast<TFE_MonitoringIntGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell()));
}
TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(const char* name,
TF_Status* status,
const char* description,
const char* label1) {
auto* result = new TFE_MonitoringIntGauge1({name, description, label1});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteIntGauge1(TFE_MonitoringIntGauge1* gauge) {
delete gauge;
}
TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge1(
TFE_MonitoringIntGauge1* gauge, const char* label1) {
return static_cast<TFE_MonitoringIntGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1)));
}
TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(const char* name,
TF_Status* status,
const char* description,
const char* label1,
const char* label2) {
auto* result =
new TFE_MonitoringIntGauge2({name, description, label1, label2});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteIntGauge2(TFE_MonitoringIntGauge2* gauge) {
delete gauge;
}
TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge2(
TFE_MonitoringIntGauge2* gauge, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringIntGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
}
void TFE_MonitoringStringGaugeCellSet(TFE_MonitoringStringGaugeCell* cell,
const char* value) {
cell->cell.Set({value});
}
const void TFE_MonitoringStringGaugeCellValue(
TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf) {
tensorflow::string value = cell->cell.value();
void* data = tensorflow::port::Malloc(value.length());
value.copy(static_cast<char*>(data), value.length(), 0);
buf->data = data;
buf->length = value.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
}
TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0(
const char* name, TF_Status* status, const char* description) {
auto* result = new TFE_MonitoringStringGauge0({name, description});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteStringGauge0(TFE_MonitoringStringGauge0* gauge) {
delete gauge;
}
TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge0(
TFE_MonitoringStringGauge0* gauge) {
return static_cast<TFE_MonitoringStringGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell()));
}
TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1(
const char* name, TF_Status* status, const char* description,
const char* label1) {
auto* result = new TFE_MonitoringStringGauge1({name, description, label1});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteStringGauge1(TFE_MonitoringStringGauge1* gauge) {
delete gauge;
}
TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge1(
TFE_MonitoringStringGauge1* gauge, const char* label1) {
return static_cast<TFE_MonitoringStringGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1)));
}
TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2(
const char* name, TF_Status* status, const char* description,
const char* label1, const char* label2) {
auto* result =
new TFE_MonitoringStringGauge2({name, description, label1, label2});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteStringGauge2(TFE_MonitoringStringGauge2* gauge) {
delete gauge;
}
TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge2(
TFE_MonitoringStringGauge2* gauge, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringStringGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
}
void TFE_MonitoringBoolGaugeCellSet(TFE_MonitoringBoolGaugeCell* cell,
bool value) {
cell->cell.Set(value);
}
bool TFE_MonitoringBoolGaugeCellValue(TFE_MonitoringBoolGaugeCell* cell) {
return cell->cell.value();
}
TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(const char* name,
TF_Status* status,
const char* description) {
auto* result = new TFE_MonitoringBoolGauge0({name, description});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteBoolGauge0(TFE_MonitoringBoolGauge0* gauge) {
delete gauge;
}
TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge0(
TFE_MonitoringBoolGauge0* gauge) {
return static_cast<TFE_MonitoringBoolGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell()));
}
TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(const char* name,
TF_Status* status,
const char* description,
const char* label1) {
auto* result = new TFE_MonitoringBoolGauge1({name, description, label1});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteBoolGauge1(TFE_MonitoringBoolGauge1* gauge) {
delete gauge;
}
TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge1(
TFE_MonitoringBoolGauge1* gauge, const char* label1) {
return static_cast<TFE_MonitoringBoolGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1)));
}
TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(const char* name,
TF_Status* status,
const char* description,
const char* label1,
const char* label2) {
auto* result =
new TFE_MonitoringBoolGauge2({name, description, label1, label2});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteBoolGauge2(TFE_MonitoringBoolGauge2* gauge) {
delete gauge;
}
TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge2(
TFE_MonitoringBoolGauge2* gauge, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringBoolGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
}
void TFE_MonitoringSamplerCellAdd(TFE_MonitoringSamplerCell* cell,
double value) {
cell->cell.Add(value);
}
void TFE_MonitoringSamplerCellValue(TFE_MonitoringSamplerCell* cell,
TF_Buffer* buf) {
string content;
cell->cell.value().SerializeToString(&content);
void* data = tensorflow::port::Malloc(content.length());
content.copy(static_cast<char*>(data), content.length(), 0);
buf->data = data;
buf->length = content.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
}
TFE_MonitoringBuckets* TFE_MonitoringNewExponentialBuckets(double scale,
double growth_factor,
int bucket_count) {
return new TFE_MonitoringBuckets([scale, growth_factor, bucket_count]() {
return tensorflow::monitoring::Buckets::Exponential(scale, growth_factor,
bucket_count);
});
}
void TFE_MonitoringDeleteBuckets(TFE_MonitoringBuckets* buckets) {
delete buckets;
}
TFE_MonitoringSampler0* TFE_MonitoringNewSampler0(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
const char* description) {
auto* result = new TFE_MonitoringSampler0(
{name, buckets->create_buckets(), description});
Set_TF_Status_from_Status(status, result->sampler->GetStatus());
if (!result->sampler->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteSampler0(TFE_MonitoringSampler0* sampler) {
delete sampler;
}
TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0(
TFE_MonitoringSampler0* sampler) {
return static_cast<TFE_MonitoringSamplerCell*>(
static_cast<void*>(sampler->sampler->GetCell()));
}
TFE_MonitoringSampler1* TFE_MonitoringNewSampler1(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
const char* description, const char* label1) {
auto* result = new TFE_MonitoringSampler1(
{name, buckets->create_buckets(), description, label1});
Set_TF_Status_from_Status(status, result->sampler->GetStatus());
if (!result->sampler->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteSampler1(TFE_MonitoringSampler1* sampler) {
delete sampler;
}
TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1(
TFE_MonitoringSampler1* sampler, const char* label1) {
return static_cast<TFE_MonitoringSamplerCell*>(
static_cast<void*>(sampler->sampler->GetCell(label1)));
}
TFE_MonitoringSampler2* TFE_MonitoringNewSampler2(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
const char* description, const char* label1, const char* label2) {
auto* result = new TFE_MonitoringSampler2(
{name, buckets->create_buckets(), description, label1, label2});
Set_TF_Status_from_Status(status, result->sampler->GetStatus());
if (!result->sampler->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteSampler2(TFE_MonitoringSampler2* sampler) {
delete sampler;
}
TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
TFE_MonitoringSampler2* sampler, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringSamplerCell*>(
static_cast<void*>(sampler->sampler->GetCell(label1, label2)));
}
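Hedged caller-side sketch of the TF_Buffer handoff used by the string gauge and sampler value getters above:

#include <string>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

// Copies the current value of a string gauge cell into a std::string.
std::string ReadStringGauge(TFE_MonitoringStringGaugeCell* cell) {
  TF_Buffer* buf = TF_NewBuffer();
  TFE_MonitoringStringGaugeCellValue(cell, buf);
  // The value was copied into buf->data; buf->data_deallocator releases it
  // when the buffer is deleted.
  std::string value(static_cast<const char*>(buf->data), buf->length);
  TF_DeleteBuffer(buf);
  return value;
}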

View File

@ -87,19 +87,7 @@ TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
const char* service_addr, const char* logdir, const char* worker_list, const char* service_addr, const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms, int num_tracing_attempts); bool include_dataset_ops, int duration_ms, int num_tracing_attempts);
// Set the value of a Gauge metric. If the metric with given name does not // TODO(fishx): Move these monitoring APIs into a separate file.
// exist, it will create a new Gauge metric. Right now it only supports type
// int64, consider to add more type supports if needed.
TF_CAPI_EXPORT extern void TFE_MonitoringSetGauge(const char* name,
const char* label,
int64_t value);
// Add the given value to a Sampler metric. If the metric with given name
// does not exist, it will create a new Sampler metric.
TF_CAPI_EXPORT extern void TFE_MonitoringAddSampler(const char* name,
const char* label,
double value);
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
// Monitoring Counter APIs. // Monitoring Counter APIs.
// These APIs de-templated monitoring Counter for swig. // These APIs de-templated monitoring Counter for swig.
@ -149,6 +137,179 @@ TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter2(
TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2( TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
TFE_MonitoringCounter2* counter, const char* label1, const char* label2); TFE_MonitoringCounter2* counter, const char* label1, const char* label2);
// -----------------------------------------------------------------------------
// Monitoring Gauge APIs.
// These APIs de-templated monitoring Gauge for swig.
typedef struct TFE_MonitoringIntGaugeCell TFE_MonitoringIntGaugeCell;
// Atomically sets the value of the cell.
TF_CAPI_EXPORT extern void TFE_MonitoringIntGaugeCellSet(
TFE_MonitoringIntGaugeCell* cell, int64_t value);
// Retrieves the current value of the cell.
TF_CAPI_EXPORT extern int64_t TFE_MonitoringIntGaugeCellValue(
TFE_MonitoringIntGaugeCell* cell);
// APIs for Int Gauge without label.
typedef struct TFE_MonitoringIntGauge0 TFE_MonitoringIntGauge0;
TF_CAPI_EXPORT extern TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(
const char* name, TF_Status* out_status, const char* description);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge0(
TFE_MonitoringIntGauge0* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
TFE_MonitoringGetCellIntGauge0(TFE_MonitoringIntGauge0* gauge);
// APIs for Int Gauge with 1 label.
typedef struct TFE_MonitoringIntGauge1 TFE_MonitoringIntGauge1;
TF_CAPI_EXPORT extern TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(
const char* name, TF_Status* out_status, const char* description,
const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge1(
TFE_MonitoringIntGauge1* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
TFE_MonitoringGetCellIntGauge1(TFE_MonitoringIntGauge1* gauge,
const char* label1);
// APIs for Int Gauge with 2 labels.
typedef struct TFE_MonitoringIntGauge2 TFE_MonitoringIntGauge2;
TF_CAPI_EXPORT extern TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(
const char* name, TF_Status* out_status, const char* description,
const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge2(
TFE_MonitoringIntGauge2* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
TFE_MonitoringGetCellIntGauge2(TFE_MonitoringIntGauge2* gauge,
const char* label1, const char* label2);
typedef struct TFE_MonitoringStringGaugeCell TFE_MonitoringStringGaugeCell;
TF_CAPI_EXPORT extern void TFE_MonitoringStringGaugeCellSet(
TFE_MonitoringStringGaugeCell* cell, const char* value);
// Retrieves the string value and saves it in buffer.
TF_CAPI_EXPORT extern const void TFE_MonitoringStringGaugeCellValue(
TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf);
// APIs for String Gauge without label.
typedef struct TFE_MonitoringStringGauge0 TFE_MonitoringStringGauge0;
TF_CAPI_EXPORT extern TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0(
const char* name, TF_Status* out_status, const char* description);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge0(
TFE_MonitoringStringGauge0* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
TFE_MonitoringGetCellStringGauge0(TFE_MonitoringStringGauge0* gauge);
// APIs for String Gauge with 1 label.
typedef struct TFE_MonitoringStringGauge1 TFE_MonitoringStringGauge1;
TF_CAPI_EXPORT extern TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1(
const char* name, TF_Status* out_status, const char* description,
const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge1(
TFE_MonitoringStringGauge1* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
TFE_MonitoringGetCellStringGauge1(TFE_MonitoringStringGauge1* gauge,
const char* label1);
// APIs for String Gauge with 2 labels.
typedef struct TFE_MonitoringStringGauge2 TFE_MonitoringStringGauge2;
TF_CAPI_EXPORT extern TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2(
const char* name, TF_Status* out_status, const char* description,
const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge2(
TFE_MonitoringStringGauge2* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
TFE_MonitoringGetCellStringGauge2(TFE_MonitoringStringGauge2* gauge,
const char* label1, const char* label2);
typedef struct TFE_MonitoringBoolGaugeCell TFE_MonitoringBoolGaugeCell;
TF_CAPI_EXPORT extern void TFE_MonitoringBoolGaugeCellSet(
TFE_MonitoringBoolGaugeCell* cell, bool value);
TF_CAPI_EXPORT extern bool TFE_MonitoringBoolGaugeCellValue(
TFE_MonitoringBoolGaugeCell* cell);
// APIs for Bool Gauge without label.
typedef struct TFE_MonitoringBoolGauge0 TFE_MonitoringBoolGauge0;
TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(
const char* name, TF_Status* out_status, const char* description);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge0(
TFE_MonitoringBoolGauge0* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
TFE_MonitoringGetCellBoolGauge0(TFE_MonitoringBoolGauge0* gauge);
// APIs for Bool Gauge with 1 label.
typedef struct TFE_MonitoringBoolGauge1 TFE_MonitoringBoolGauge1;
TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(
const char* name, TF_Status* out_status, const char* description,
const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge1(
TFE_MonitoringBoolGauge1* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
TFE_MonitoringGetCellBoolGauge1(TFE_MonitoringBoolGauge1* gauge,
const char* label1);
// APIs for Bool Gauge with 2 labels.
typedef struct TFE_MonitoringBoolGauge2 TFE_MonitoringBoolGauge2;
TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(
const char* name, TF_Status* out_status, const char* description,
const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge2(
TFE_MonitoringBoolGauge2* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
TFE_MonitoringGetCellBoolGauge2(TFE_MonitoringBoolGauge2* gauge,
const char* label1, const char* label2);
// -----------------------------------------------------------------------------
// Monitoring Sampler APIs.
// These APIs de-templated monitoring Sampler for swig.
typedef struct TFE_MonitoringSamplerCell TFE_MonitoringSamplerCell;
// Atomically adds the given value to the cell.
TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellAdd(
TFE_MonitoringSamplerCell* cell, double value);
// Retrieves the current value of the cell as a serialized HistogramProto
// written into `buf`.
TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellValue(
TFE_MonitoringSamplerCell* cell, TF_Buffer* buf);
// APIs for sampler buckets
typedef struct TFE_MonitoringBuckets TFE_MonitoringBuckets;
TF_CAPI_EXPORT extern TFE_MonitoringBuckets*
TFE_MonitoringNewExponentialBuckets(double scale, double growth_factor,
int bucket_count);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBuckets(
TFE_MonitoringBuckets* buckets);
// APIs for Sampler without label.
typedef struct TFE_MonitoringSampler0 TFE_MonitoringSampler0;
TF_CAPI_EXPORT extern TFE_MonitoringSampler0* TFE_MonitoringNewSampler0(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
const char* description);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler0(
TFE_MonitoringSampler0* sampler);
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0(
TFE_MonitoringSampler0* sampler);
// APIs for Sampler with 1 label.
typedef struct TFE_MonitoringSampler1 TFE_MonitoringSampler1;
TF_CAPI_EXPORT extern TFE_MonitoringSampler1* TFE_MonitoringNewSampler1(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
const char* description, const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler1(
TFE_MonitoringSampler1* sampler);
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1(
TFE_MonitoringSampler1* sampler, const char* label1);
// APIs for Sampler with 2 labels.
typedef struct TFE_MonitoringSampler2 TFE_MonitoringSampler2;
TF_CAPI_EXPORT extern TFE_MonitoringSampler2* TFE_MonitoringNewSampler2(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
const char* description, const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2(
TFE_MonitoringSampler2* sampler);
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
TFE_MonitoringSampler2* sampler, const char* label1, const char* label2);
#ifdef __cplusplus #ifdef __cplusplus
} /* end extern "C" */ } /* end extern "C" */
#endif #endif
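For reference, a hedged end-to-end sketch of the sampler declarations above (the metric name, description and label values are illustrative; in real code these objects usually live for the whole process):

#include "tensorflow/c/eager/c_api_experimental.h"

// Records one latency sample under an (illustrative) per-op label.
void RecordLatency(double seconds) {
  TF_Status* status = TF_NewStatus();
  // Buckets 0.001, 0.002, 0.004, ... seconds (scale, growth factor, count).
  TFE_MonitoringBuckets* buckets =
      TFE_MonitoringNewExponentialBuckets(0.001, 2.0, 20);
  TFE_MonitoringSampler1* sampler = TFE_MonitoringNewSampler1(
      "/example/latency", buckets, status, "Step latency in seconds.", "op");
  if (sampler != nullptr) {
    TFE_MonitoringSamplerCellAdd(
        TFE_MonitoringGetCellSampler1(sampler, "matmul"), seconds);
    TFE_MonitoringDeleteSampler1(sampler);
  }
  TFE_MonitoringDeleteBuckets(buckets);
  TF_DeleteStatus(status);
}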

View File

@ -131,23 +131,6 @@ TEST(CAPI, MultipleProfilerSession) {
TFE_DeleteProfilerContext(profiler_context); TFE_DeleteProfilerContext(profiler_context);
} }
TEST(CAPI, MonitoringSetGauge) {
TFE_MonitoringSetGauge("test/gauge", "label", 1);
auto* collection_registry = monitoring::CollectionRegistry::Default();
monitoring::CollectionRegistry::CollectMetricsOptions options;
std::unique_ptr<monitoring::CollectedMetrics> metrics =
collection_registry->CollectMetrics(options);
EXPECT_EQ("test/gauge", metrics->point_set_map.at("test/gauge")->metric_name);
EXPECT_EQ(1,
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
TFE_MonitoringSetGauge("test/gauge", "label", 5);
metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(5,
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
}
TEST(CAPI, MonitoringCounter0) { TEST(CAPI, MonitoringCounter0) {
TF_Status* status = TF_NewStatus(); TF_Status* status = TF_NewStatus();
auto* counter = auto* counter =
@ -200,8 +183,59 @@ TEST(CAPI, MonitoringCounterMultiple) {
TFE_MonitoringDeleteCounter2(counter2); TFE_MonitoringDeleteCounter2(counter2);
} }
TEST(CAPI, MonitoringAddSampler) { TEST(CAPI, MonitoringGauge0) {
TFE_MonitoringAddSampler("test/sampler", "label", 1.0); TF_Status* status = TF_NewStatus();
auto* gauge = TFE_MonitoringNewIntGauge0("test/gauge", status, "test");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell = TFE_MonitoringGetCellIntGauge0(gauge);
TFE_MonitoringIntGaugeCellSet(cell, 1);
EXPECT_EQ(TFE_MonitoringIntGaugeCellValue(cell), 1);
auto* collection_registry = monitoring::CollectionRegistry::Default();
monitoring::CollectionRegistry::CollectMetricsOptions options;
std::unique_ptr<monitoring::CollectedMetrics> metrics =
collection_registry->CollectMetrics(options);
EXPECT_EQ("test/gauge", metrics->point_set_map.at("test/gauge")->metric_name);
EXPECT_EQ(1,
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
TFE_MonitoringIntGaugeCellSet(cell, 5);
metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(5,
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
TF_DeleteStatus(status);
}
TEST(CAPI, MonitoringMultipleGauge) {
TF_Status* status = TF_NewStatus();
auto* gauge1 =
TFE_MonitoringNewBoolGauge1("test/gauge1", status, "test", "label1");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell1 = TFE_MonitoringGetCellBoolGauge1(gauge1, "foo");
TFE_MonitoringBoolGaugeCellSet(cell1, true);
EXPECT_TRUE(TFE_MonitoringBoolGaugeCellValue(cell1));
auto* gauge2 = TFE_MonitoringNewStringGauge2("test/gauge2", status, "test",
"label1", "label2");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell2 = TFE_MonitoringGetCellStringGauge2(gauge2, "foo", "bar");
TFE_MonitoringStringGaugeCellSet(cell2, "str");
auto* buf = new TF_Buffer;
TFE_MonitoringStringGaugeCellValue(cell2, buf);
string data(static_cast<const char*>(buf->data), buf->length);
delete buf;
EXPECT_EQ(data, "str");
TF_DeleteStatus(status);
}
TEST(CAPI, MonitoringSampler0) {
TF_Status* status = TF_NewStatus();
auto* buckets = TFE_MonitoringNewExponentialBuckets(1.0, 2.0, 2);
auto* sampler =
TFE_MonitoringNewSampler0("test/sampler", buckets, status, "test");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell = TFE_MonitoringGetCellSampler0(sampler);
TFE_MonitoringSamplerCellAdd(cell, 1.0);
auto* collection_registry = monitoring::CollectionRegistry::Default(); auto* collection_registry = monitoring::CollectionRegistry::Default();
monitoring::CollectionRegistry::CollectMetricsOptions options; monitoring::CollectionRegistry::CollectMetricsOptions options;
std::unique_ptr<monitoring::CollectedMetrics> metrics = std::unique_ptr<monitoring::CollectedMetrics> metrics =
@ -213,11 +247,48 @@ TEST(CAPI, MonitoringAddSampler) {
->points.at(0) ->points.at(0)
->histogram_value.sum()); ->histogram_value.sum());
TFE_MonitoringAddSampler("test/sampler", "label", 5.0); TFE_MonitoringSamplerCellAdd(cell, 5.0);
metrics = collection_registry->CollectMetrics(options); metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(6.0, metrics->point_set_map.at("test/sampler") EXPECT_EQ(6.0, metrics->point_set_map.at("test/sampler")
->points.at(0) ->points.at(0)
->histogram_value.sum()); ->histogram_value.sum());
TFE_MonitoringDeleteBuckets(buckets);
TF_DeleteStatus(status);
}
TEST(CAPI, MonitoringMultipleSampler) {
TF_Status* status = TF_NewStatus();
auto* buckets = TFE_MonitoringNewExponentialBuckets(1.0, 2.0, 2);
auto* sampler1 = TFE_MonitoringNewSampler1("test/sampler1", buckets, status,
"test", "label1");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell1 = TFE_MonitoringGetCellSampler1(sampler1, "foo");
TFE_MonitoringSamplerCellAdd(cell1, 1.0);
TFE_MonitoringSamplerCellAdd(cell1, 2.0);
TF_Buffer* result1 = TF_NewBuffer();
TFE_MonitoringSamplerCellValue(cell1, result1);
tensorflow::HistogramProto histogram1;
EXPECT_TRUE(histogram1.ParseFromString(
{reinterpret_cast<const char*>(result1->data), result1->length}));
EXPECT_EQ(histogram1.sum(), 3.0);
delete result1;
auto* sampler2 = TFE_MonitoringNewSampler2("test/sampler2", buckets, status,
"test", "label1", "label2");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell2 = TFE_MonitoringGetCellSampler2(sampler2, "foo", "bar");
TFE_MonitoringSamplerCellAdd(cell2, 2.0);
TFE_MonitoringSamplerCellAdd(cell2, 3.0);
TF_Buffer* result2 = TF_NewBuffer();
TFE_MonitoringSamplerCellValue(cell2, result2);
tensorflow::HistogramProto histogram2;
EXPECT_TRUE(histogram2.ParseFromString(
{reinterpret_cast<const char*>(result2->data), result2->length}));
EXPECT_EQ(histogram2.sum(), 5.0);
delete result2;
TFE_MonitoringDeleteBuckets(buckets);
TF_DeleteStatus(status);
} }
} // namespace } // namespace

View File

@ -36,20 +36,14 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/profiler/lib/profiler_session.h" #include "tensorflow/core/profiler/lib/profiler_session.h"
@ -68,13 +62,16 @@ struct TFE_Context {
const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned, const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
tensorflow::Rendezvous* rendezvous, tensorflow::Rendezvous* rendezvous,
const tensorflow::CustomKernelCreator* custom_kernel_creator) const tensorflow::CustomKernelCreator* custom_kernel_creator)
: context(opts, : context(new tensorflow::EagerContext(
static_cast<tensorflow::ContextDevicePlacementPolicy>( opts,
default_policy), static_cast<tensorflow::ContextDevicePlacementPolicy>(
async, device_mgr, device_mgr_owned, rendezvous, default_policy),
custom_kernel_creator) {} async, device_mgr, device_mgr_owned, rendezvous,
custom_kernel_creator)) {}
tensorflow::EagerContext context; ~TFE_Context() { context->Unref(); }
tensorflow::EagerContext* context;
}; };
struct TFE_TensorHandle { struct TFE_TensorHandle {
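The move from an embedded EagerContext to a pointer released with Unref() implies shared ownership; a hedged, self-contained illustration of that pattern (RefCounted and Context here are stand-ins, not the TensorFlow classes):

#include <atomic>

// Minimal stand-in for an intrusively reference-counted object.
class RefCounted {
 public:
  RefCounted() : count_(1) {}
  void Ref() { count_.fetch_add(1); }
  void Unref() {
    if (count_.fetch_sub(1) == 1) delete this;
  }

 protected:
  virtual ~RefCounted() = default;

 private:
  std::atomic<int> count_;
};

class Context : public RefCounted {};

// Mirrors the new TFE_Context: the wrapper drops its reference on destruction,
// but the Context stays alive while other holders (e.g. tensor handles) keep
// their own references, which is roughly why the
// RemoteExecuteDeleteTensorAfterContext test further down can delete a tensor
// handle after TFE_DeleteContext.
struct ContextWrapper {
  explicit ContextWrapper(Context* ctx) : context(ctx) {}
  ~ContextWrapper() { context->Unref(); }
  Context* context;
};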
@ -114,7 +111,7 @@ struct TFE_Op {
TFE_Op(TFE_Context* ctx, const char* op, bool is_function, TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
const tensorflow::AttrTypeMap* t, const tensorflow::AttrTypeMap* t,
TFE_OpInferenceContext* inference_ctx) TFE_OpInferenceContext* inference_ctx)
: operation(&ctx->context, op, is_function, t), : operation(ctx->context, op, is_function, t),
inference_ctx(inference_ctx) {} inference_ctx(inference_ctx) {}
tensorflow::EagerOperation operation; tensorflow::EagerOperation operation;
@ -159,6 +156,98 @@ struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> {
using TFE_MonitoringCounter::TFE_MonitoringCounter; using TFE_MonitoringCounter::TFE_MonitoringCounter;
}; };
struct TFE_MonitoringIntGaugeCell {
tensorflow::monitoring::GaugeCell<tensorflow::int64> cell;
};
struct TFE_MonitoringStringGaugeCell {
tensorflow::monitoring::GaugeCell<tensorflow::string> cell;
};
struct TFE_MonitoringBoolGaugeCell {
tensorflow::monitoring::GaugeCell<bool> cell;
};
template <typename ValueType, int NumLabels>
struct TFE_MonitoringGauge {
template <typename... LabelDesc>
TFE_MonitoringGauge(const char* name, const char* description,
LabelDesc&&... label) {
gauge = absl::WrapUnique(
tensorflow::monitoring::Gauge<ValueType, NumLabels>::New(
name, description, label...));
}
std::unique_ptr<tensorflow::monitoring::Gauge<ValueType, NumLabels>> gauge;
};
struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge<tensorflow::int64, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge<tensorflow::int64, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge<tensorflow::int64, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge<tensorflow::string, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge<tensorflow::string, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge<tensorflow::string, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge<bool, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge<bool, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBuckets {
TFE_MonitoringBuckets(
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
fn) {
create_buckets = fn;
}
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
create_buckets;
};
struct TFE_MonitoringSamplerCell {
tensorflow::monitoring::SamplerCell cell;
};
template <int NumLabels>
struct TFE_MonitoringSampler {
template <typename... LabelDesc>
TFE_MonitoringSampler(
const char* name,
std::unique_ptr<tensorflow::monitoring::Buckets> buckets,
const char* description, LabelDesc&&... label) {
sampler = absl::WrapUnique(tensorflow::monitoring::Sampler<NumLabels>::New(
{name, description, label...}, std::move(buckets)));
}
std::unique_ptr<tensorflow::monitoring::Sampler<NumLabels>> sampler;
};
struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
namespace tensorflow { namespace tensorflow {
// Set an AttrValue on the op. Doesn't handle the list types. // Set an AttrValue on the op. Doesn't handle the list types.
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,

View File

@ -14,10 +14,11 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include <string.h> #include <string.h>
#include "absl/strings/match.h" #include "absl/strings/match.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/function.pb.h"
@ -297,6 +298,61 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(true);
}
void TestRemoteExecuteDeleteTensorAfterContext(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
auto* h0_task1 =
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(h0_task0);
TFE_ContextAsyncWait(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContext(ctx);
// Delete tensors after context is deleted.
TFE_DeleteTensorHandle(h0_task1);
TF_DeleteStatus(status);
// TODO(nareshmodi): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteDeleteTensorAfterContext) {
TestRemoteExecuteDeleteTensorAfterContext(false);
}
TEST(CAPI, RemoteExecuteDeleteTensorAfterContextAsync) {
TestRemoteExecuteDeleteTensorAfterContext(true);
}
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
const std::vector<float>& expected_values) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
@ -1225,6 +1281,8 @@ TEST(CAPI, Function_ident_CPU) {
TF_DeleteTensor(r);
TFE_DeleteTensorHandle(result[0]);
}
TFE_ContextRemoveFunction(ctx, "ident", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContext(ctx);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteStatus(status);
@ -1295,6 +1353,8 @@ TEST(CAPI, Function_ident_XLA_CPU) {
TF_DeleteTensor(r);
TFE_DeleteTensorHandle(result[0]);
}
TFE_ContextRemoveFunction(ctx, "ident", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContext(ctx);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteStatus(status);
@ -1371,6 +1431,8 @@ void FunctionDefAndExecute(bool async) {
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContext(ctx);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
@ -1412,6 +1474,8 @@ void BM_ExecuteFunction(int iters, int async) {
tensorflow::testing::StopTiming();
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(retval[0]);
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContext(ctx);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);

@ -0,0 +1,122 @@
# Description:
# Experimental C APIs for TensorFlow.
licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
"tf_copts",
"tf_cuda_library",
)
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
tf_cuda_library(
name = "rendezvous_internal",
srcs = [
"rendezvous.cc",
],
hdrs = [
"rendezvous.h",
"rendezvous_internal.h",
],
copts = tf_copts(),
visibility = ["//tensorflow/c:__subpackages__"],
deps = [
"//tensorflow/c:c_api_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
],
)
tf_cuda_library(
name = "rendezvous",
hdrs = [
"rendezvous.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":rendezvous_internal",
"//tensorflow/c:c_api",
],
)
tf_cuda_library(
name = "network_internal",
srcs = [
"network.cc",
],
hdrs = [
"network.h",
"network_internal.h",
],
copts = tf_copts(),
visibility = ["//tensorflow/c:__subpackages__"],
deps = [
":rendezvous_internal",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
],
)
tf_cuda_library(
name = "network",
hdrs = [
"network.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":network_internal",
":rendezvous",
"//tensorflow/c:c_api",
],
)
# -----------------------------------------------------------------------------
# Tests
tf_cuda_cc_test(
name = "network_test",
size = "medium",
srcs = ["network_test.cc"],
tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":network",
":network_internal",
":rendezvous",
":rendezvous_internal",
"//tensorflow/c:c_api",
"//tensorflow/c:env",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:session_mgr",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime:worker_session",
"//tensorflow/core/distributed_runtime/rpc:async_service_interface",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
],
)

@ -0,0 +1,166 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/network.h"
#include <memory>
#include <string>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/experimental/network_internal.h"
#include "tensorflow/c/experimental/rendezvous_internal.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
using tensorflow::ServerFactory;
namespace tensorflow {
/* static */ Status CGrpcServer::Create(
const ServerDef& server_def,
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder,
std::unique_ptr<ServerInterface>* out_server) {
auto* grpc_server = new CGrpcServer(server_def, start_function, stop_function,
join_function, delete_function);
GrpcServerOptions options;
options.rendezvous_mgr_func = [rendezvous_builder](const WorkerEnv* env) {
return new CRendezvousMgr(env, rendezvous_builder);
};
TF_RETURN_IF_ERROR(grpc_server->Init(options));
TF_Status* tf_status = TF_NewStatus();
grpc_server->SetContext(init_function(
reinterpret_cast<const TF_GrpcServer*>(grpc_server), tf_status));
TF_RETURN_IF_ERROR(tf_status->status);
TF_DeleteStatus(tf_status);
out_server->reset(grpc_server);
return Status::OK();
}
Status CGrpcServer::Start() {
Status status = GrpcServer::Start();
TF_Status* tf_status = TF_NewStatus();
(*start_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
tf_status);
status.Update(tf_status->status);
TF_DeleteStatus(tf_status);
return status;
}
Status CGrpcServer::Stop() {
Status status = GrpcServer::Stop();
TF_Status* tf_status = TF_NewStatus();
(*stop_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
tf_status);
status.Update(tf_status->status);
TF_DeleteStatus(tf_status);
return status;
}
Status CGrpcServer::Join() {
Status status = GrpcServer::Join();
TF_Status* tf_status = TF_NewStatus();
(*join_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
tf_status);
status.Update(tf_status->status);
TF_DeleteStatus(tf_status);
return status;
}
namespace {
// Factory that creates CGrpcServer instances.
class CServerFactory : public ServerFactory {
public:
CServerFactory(bool (*accept_function)(const char*),
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*,
TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder)
: accept_function_(accept_function),
init_function_(init_function),
start_function_(start_function),
stop_function_(stop_function),
join_function_(join_function),
delete_function_(delete_function),
rendezvous_builder_(rendezvous_builder) {}
Status NewServer(const ServerDef& server_def,
std::unique_ptr<ServerInterface>* out_server) override {
TF_RETURN_IF_ERROR(CGrpcServer::Create(
server_def, init_function_, start_function_, stop_function_,
join_function_, delete_function_, rendezvous_builder_, out_server));
return Status::OK();
}
// Returns true if and only if this factory can create a server
// based on the given `server_def`.
bool AcceptsOptions(const ServerDef& server_def) override {
return (*accept_function_)(server_def.protocol().c_str());
}
private:
bool (*accept_function_)(const char* protocol);
void* (*init_function_)(const TF_GrpcServer*, TF_Status*);
void (*start_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*stop_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*join_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*delete_function_)(void*);
TF_RemoteRendezvousBuilder* rendezvous_builder_;
};
} // namespace
} // namespace tensorflow
// Server factory representation to use in C API.
// Holds CServerFactory pointer.
struct TF_GrpcServerFactory {
::tensorflow::CServerFactory* factory;
};
TF_GrpcServerFactory* TF_NewGrpcServerFactory(
bool (*accept_function)(const char*),
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder) {
TF_GrpcServerFactory* server_factory = new TF_GrpcServerFactory;
server_factory->factory = new ::tensorflow::CServerFactory(
accept_function, init_function, start_function, stop_function,
join_function, delete_function, rendezvous_builder);
return server_factory;
}
void TF_DeleteGrpcServerFactory(TF_GrpcServerFactory* server_factory) {
DCHECK_NE(server_factory, nullptr);
delete server_factory;
}
void TF_RegisterGrpcServerFactory(const char* server_type,
TF_GrpcServerFactory* server_factory) {
ServerFactory::Register(server_type, server_factory->factory);
}

@ -0,0 +1,97 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_
#define TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/rendezvous.h"
#ifdef __cplusplus
extern "C" {
#endif
// --------------------------------------------------------------------------
// C API for TensorFlow Networking.
// NOTE: This API is unstable and almost certainly will change in the near
// future.
//
// Users wishing to register a custom GrpcServer should call
// TF_NewGrpcServerFactory and then TF_RegisterGrpcServerFactory.
//
// Example:
// ```c++
// auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder(
// rendezvous_init_function,
// receive_from_remote_async_function,
// rendezvous_delete_function);
//
// TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory(
// accept_function,
// init_function,
// start_function,
// stop_function,
// join_function,
// delete_function,
// rendezvous_builder);
// TF_RegisterGrpcServerFactory("customfactory", factory);
// ...
// TF_DeleteGrpcServerFactory(factory);
// ```
typedef struct TF_GrpcServerFactory TF_GrpcServerFactory;
typedef struct TF_GrpcServerOptions TF_GrpcServerOptions;
typedef struct TF_GrpcServer TF_GrpcServer;
typedef struct TF_ServerContext {
TF_GrpcServer* const server;
void* context;
} TF_ServerContext;
// Creates a new TF_GrpcServerFactory instance. The caller takes ownership
// of the TF_GrpcServerFactory instance and should deallocate it by calling
// TF_DeleteGrpcServerFactory.
// accept_function should return true if this ServerFactory can create
// server instances for the given protocol name (e.g. grpc+verbs).
// GRPC servers created by this factory will call provided
// init_function, start_function, stop_function, join_function and
// delete_function.
//
// Note that clean shutdown is currently not implemented for GrpcServer, so
// stop_function is never called today; it may be once a stop mechanism is
// supported. A sketch of these callbacks follows the declaration below.
TF_CAPI_EXPORT extern TF_GrpcServerFactory* TF_NewGrpcServerFactory(
bool (*accept_function)(const char*),
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder);
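A minimal sketch (not part of this header) of the callbacks a caller might hand to TF_NewGrpcServerFactory; the my_* names, the MyServerState struct, and the "grpc+custom" protocol string are illustrative only.

#include <cstring>

// Illustrative per-server state threaded through the callbacks as `context`.
struct MyServerState {
  bool started = false;
};

// Accept only our custom protocol name.
bool my_accept(const char* protocol) {
  return strcmp(protocol, "grpc+custom") == 0;
}

// Allocate the state; it is handed back to the other callbacks as `context`.
void* my_init(const TF_GrpcServer* server, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  return new MyServerState();
}

void my_start(const TF_GrpcServer* server, void* context, TF_Status* status) {
  static_cast<MyServerState*>(context)->started = true;
  TF_SetStatus(status, TF_OK, "");
}

void my_stop(const TF_GrpcServer* server, void* context, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");  // currently never invoked, see note above
}

void my_join(const TF_GrpcServer* server, void* context, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
}

void my_delete(void* context) { delete static_cast<MyServerState*>(context); }

These, together with a TF_RemoteRendezvousBuilder, would be passed to TF_NewGrpcServerFactory and registered with TF_RegisterGrpcServerFactory as in the example at the top of this header.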
// Deletes TF_GrpcServerFactory instances.
// Note that this function only deletes the TF_GrpcServerFactory wrapper.
// The underlying server factory is not deleted and remains registered.
TF_CAPI_EXPORT extern void TF_DeleteGrpcServerFactory(
TF_GrpcServerFactory* server_factory);
// Registers provided server_factory for the given server_type.
// server_type must be unique to the server factory.
TF_CAPI_EXPORT extern void TF_RegisterGrpcServerFactory(
const char* server_type, TF_GrpcServerFactory* server_factory);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_

@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_
#define TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/network.h"
#include "tensorflow/c/experimental/rendezvous.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace tensorflow {
// GrpcServer implementation that forwards calls to callbacks.
class CGrpcServer : public GrpcServer {
protected:
CGrpcServer(const ServerDef& server_def,
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*))
: GrpcServer(server_def, ::tensorflow::Env::Default()),
start_function_(start_function),
stop_function_(stop_function),
join_function_(join_function),
delete_function_(delete_function),
context_(nullptr) {}
public:
static Status Create(
const ServerDef& server_def,
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder,
std::unique_ptr<ServerInterface>* out_server);
Status Start() override;
Status Stop() override;
Status Join() override;
~CGrpcServer() override { delete_function_(context_); }
protected:
void SetContext(void* context) { context_ = context; }
private:
void (*start_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*stop_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*join_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*delete_function_)(void*);
void* context_;
friend class NetworksTest;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_

@ -0,0 +1,256 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/network.h"
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <memory>
#include <string>
#include "absl/synchronization/notification.h"
#include "absl/time/time.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/network_internal.h"
#include "tensorflow/c/experimental/rendezvous.h"
#include "tensorflow/c/experimental/rendezvous_internal.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace tensorflow {
bool accept_functionA(const char* protocol_name) {
return strcmp(protocol_name, "grpc+A") == 0;
}
bool accept_functionB(const char* protocol_name) {
return strcmp(protocol_name, "grpc+B") == 0;
}
struct SomeServerData {
bool server_started = false;
};
struct SomeRendezvousData {
int test = 0;
};
void* init_function(const TF_GrpcServer* server, TF_Status* status) {
SomeServerData* server_data = new SomeServerData();
TF_SetStatus(status, TF_OK, "");
return server_data;
}
void start_function(const TF_GrpcServer* server, void* context,
TF_Status* status) {
auto* server_data = static_cast<SomeServerData*>(context);
server_data->server_started = true;
TF_SetStatus(status, TF_OK, "");
}
void stop_function(const TF_GrpcServer* server, void* context,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
}
void join_function(const TF_GrpcServer* server, void* context,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
}
void delete_function(void* context) {
auto* server_data = static_cast<SomeServerData*>(context);
delete server_data;
}
void* rendezvous_init_function(void* server_context) {
return new SomeRendezvousData();
}
void Deallocator(void* data, size_t, void* arg) {
tensorflow::cpu_allocator()->DeallocateRaw(data);
*reinterpret_cast<bool*>(arg) = true;
}
void receive_from_remote_async_function(TF_ParsedKey* key,
TF_RendezvousArgs* args,
TF_RendezvousDoneCallback* callback,
void* context) {
// Create dummy tensor
const int num_bytes = 6 * sizeof(float);
float* values =
reinterpret_cast<float*>(tensorflow::cpu_allocator()->AllocateRaw(
EIGEN_MAX_ALIGN_BYTES, num_bytes));
int64_t dims[] = {2, 3};
bool deallocator_called = false;
auto* tensor = TF_NewTensor(TF_FLOAT, dims, 2, values, num_bytes,
&Deallocator, &deallocator_called);
callback->tensor = tensor;
auto* tf_status = TF_NewStatus();
TF_SetStatus(tf_status, TF_OK, "");
callback->status = tf_status;
TF_RendezvousDone(callback);
TF_DeleteStatus(tf_status);
TF_DeleteTensor(tensor);
}
void rendezvous_delete_function(void* context) {
auto* rendezvous_data = static_cast<SomeRendezvousData*>(context);
delete rendezvous_data;
}
tensorflow::ServerDef GetServerDef(const string& protocol,
const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol(protocol);
server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost:", port)});
}
return server_def;
}
class NetworksTest : public ::testing::Test {
public:
~NetworksTest() override {}
SomeServerData* GetServerData(CGrpcServer* server) {
EXPECT_NE(server->context_, nullptr);
return static_cast<SomeServerData*>(server->context_);
}
};
Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
const string& receiver, const string& name) {
Rendezvous::ParsedKey result;
CHECK(
Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
name, FrameAndIter(0, 0)),
&result)
.ok());
return result;
}
void InitializeRendezvous(GrpcServer* grpc_server, ServerDef* server_def,
RemoteRendezvous* remote_rendezvous) {
int rendezvous_id = 0;
auto session_name = tensorflow::strings::StrCat("test_", rendezvous_id);
TF_EXPECT_OK(grpc_server->worker_env()->session_mgr->CreateSession(
session_name, *server_def, true));
std::shared_ptr<tensorflow::WorkerSession> worker_session;
TF_EXPECT_OK(grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
session_name, &worker_session));
TF_EXPECT_OK(remote_rendezvous->Initialize(worker_session.get()));
}
TEST_F(NetworksTest, TestStartServer) {
auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder(
rendezvous_init_function, receive_from_remote_async_function,
rendezvous_delete_function);
TF_Status* tf_status = TF_NewStatus();
TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory(
accept_functionA, init_function, start_function, stop_function,
join_function, delete_function, rendezvous_builder);
TF_RegisterGrpcServerFactory("testfactoryA", factory);
ServerDef server_def = GetServerDef("grpc+A", "localhost", 1);
std::unique_ptr<ServerInterface> server;
TF_EXPECT_OK(NewServer(server_def, &server));
auto* grpc_server = static_cast<CGrpcServer*>(server.get());
auto* server_data = GetServerData(grpc_server);
ASSERT_FALSE(server_data->server_started);
TF_EXPECT_OK(server->Start());
ASSERT_TRUE(server_data->server_started);
TF_DeleteStatus(tf_status);
TF_DeleteGrpcServerFactory(factory);
TF_DeleteRemoteRendezvousBuilder(rendezvous_builder);
// TODO(annarev): find a clean way to shutdown server.
server.release();
}
TEST_F(NetworksTest, TestReceiveData) {
auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder(
rendezvous_init_function, receive_from_remote_async_function,
rendezvous_delete_function);
TF_Status* tf_status = TF_NewStatus();
TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory(
accept_functionB, init_function, start_function, stop_function,
join_function, delete_function, rendezvous_builder);
TF_RegisterGrpcServerFactory("testfactoryB", factory);
ServerDef server_def = GetServerDef("grpc+B", "localhost", 1);
std::unique_ptr<ServerInterface> server;
TF_EXPECT_OK(NewServer(server_def, &server));
auto* grpc_server = static_cast<CGrpcServer*>(server.get());
TF_EXPECT_OK(server->Start());
auto* rendezvous_mgr = grpc_server->worker_env()->rendezvous_mgr;
auto* remote_rendezvous = rendezvous_mgr->Find(0);
auto key = Key("/job:localhost/replica:1/task:2/device:CPU:0", 1,
"/job:localhost/replica:0/task:0/device:CPU:0", "test");
Rendezvous::Args args;
bool done_callback_called = false;
auto* done_callback_called_ptr = &done_callback_called;
absl::Notification notification;
auto* notification_ptr = &notification;
InitializeRendezvous(grpc_server, &server_def, remote_rendezvous);
remote_rendezvous->RecvAsync(
key, args,
[done_callback_called_ptr, notification_ptr](
const Status&, const Rendezvous::Args&, const Rendezvous::Args&,
const Tensor&, const bool) mutable {
*done_callback_called_ptr = true;
notification_ptr->Notify();
});
notification.WaitForNotificationWithTimeout(absl::Seconds(10));
ASSERT_EQ(done_callback_called, true);
TF_DeleteStatus(tf_status);
TF_DeleteGrpcServerFactory(factory);
TF_DeleteRemoteRendezvousBuilder(rendezvous_builder);
// Server doesn't have a clean shutdown.
server.release();
}
} // namespace tensorflow

@ -0,0 +1,124 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/rendezvous.h"
#include <functional>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/experimental/rendezvous_internal.h"
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
namespace tensorflow {
CRemoteRendezvous::CRemoteRendezvous(const WorkerEnv* env, int64 step_id,
void (*receive_from_remote_async_function)(
TF_ParsedKey*, TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context),
void (*delete_function)(void* context),
void* server_context)
: BaseRemoteRendezvous(env, step_id),
receive_from_remote_async_function_(receive_from_remote_async_function),
delete_function_(delete_function),
context_(nullptr) {}
void CRemoteRendezvous::RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
DoneCallback done) {
TF_ParsedKey key;
key.src_device = parsed.src_device.data();
key.src_device_len = parsed.src_device.size();
key.dst_device = parsed.dst_device.data();
key.dst_device_len = parsed.dst_device.size();
key.full_key = parsed.FullKey().data();
key.full_key_len = parsed.FullKey().size();
TF_DeviceContext* device_context = new TF_DeviceContext();
device_context->context = args.device_context;
TF_AllocatorAttributes* alloc_attrs = new TF_AllocatorAttributes();
alloc_attrs->value = args.alloc_attrs.value;
alloc_attrs->scope_id = args.alloc_attrs.scope_id;
alloc_attrs->on_host = args.alloc_attrs.on_host();
alloc_attrs->nic_compatible = args.alloc_attrs.nic_compatible();
TF_RendezvousArgs* cargs = new TF_RendezvousArgs();
cargs->device_context = device_context;
cargs->alloc_attrs = alloc_attrs;
TF_RendezvousDoneCallback* done_callback = new TF_RendezvousDoneCallback();
done_callback->done_callback = done;
done_callback->recv_args = cargs;
receive_from_remote_async_function_(&key, cargs, done_callback, context_);
}
CRemoteRendezvous::~CRemoteRendezvous() { delete_function_(context_); }
} // namespace tensorflow
TF_RemoteRendezvousBuilder* TF_NewRemoteRendezvousBuilder(
void* (*init_function)(void* server_context),
void (*receive_from_remote_async_function)(TF_ParsedKey*,
TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context),
void (*delete_function)(void* context)) {
TF_RemoteRendezvousBuilder* builder = new TF_RemoteRendezvousBuilder();
builder->init_function = init_function;
builder->delete_function = delete_function;
builder->receive_from_remote_async_function =
receive_from_remote_async_function;
return builder;
}
void TF_DeleteRemoteRendezvousBuilder(
TF_RemoteRendezvousBuilder* rendezvous_builder) {
DCHECK_NE(rendezvous_builder, nullptr);
delete rendezvous_builder;
}
TF_CAPI_EXPORT extern void TF_RendezvousDone(
TF_RendezvousDoneCallback* callback) {
DCHECK_NE(callback, nullptr);
::tensorflow::Tensor tensor;
TF_CHECK_OK(TF_TensorToTensor(callback->tensor, &tensor));
::tensorflow::Rendezvous::Args recv_args;
recv_args.alloc_attrs.value = callback->recv_args->alloc_attrs->value;
recv_args.alloc_attrs.scope_id = callback->recv_args->alloc_attrs->scope_id;
recv_args.device_context = callback->recv_args->device_context->context;
::tensorflow::Rendezvous::Args sent_args;
callback->done_callback(callback->status->status, sent_args, recv_args,
tensor, callback->dead);
if (callback->recv_args) {
DCHECK_NE(callback->recv_args, nullptr);
DCHECK_NE(callback->recv_args->alloc_attrs, nullptr);
DCHECK_NE(callback->recv_args->device_context, nullptr);
delete callback->recv_args->alloc_attrs;
delete callback->recv_args->device_context;
delete callback->recv_args;
}
delete callback;
callback = nullptr;
}

@ -0,0 +1,67 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_
#define TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_
#include "tensorflow/c/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif
// --------------------------------------------------------------------------
// C API for Rendezvous.
// NOTE: This API is unstable and almost certainly will change in the near
// future.
//
// Custom rendezvous allows for custom implementations of Recv call.
//
// Users wishing to create custom rendezvous objects should call
// TF_NewRemoteRendezvousBuilder and pass the returned TF_RemoteRendezvousBuilder
// to TF_NewGrpcServerFactory.
typedef struct TF_RemoteRendezvousBuilder TF_RemoteRendezvousBuilder;
typedef struct TF_ParsedKey TF_ParsedKey;
typedef struct TF_RendezvousArgs TF_RendezvousArgs;
typedef struct TF_RendezvousDoneCallback TF_RendezvousDoneCallback;
// Creates a new TF_RemoteRendezvousBuilder instance.
// Rendezvous instances will forward calls to init_function,
// receive_from_remote_async_function and delete_function passed here.
//
// Note that the receive_from_remote_async_function implementation must call
// TF_RendezvousDone with the TF_RendezvousDoneCallback passed as an argument
// (a sketch follows the declaration below).
TF_CAPI_EXPORT extern TF_RemoteRendezvousBuilder* TF_NewRemoteRendezvousBuilder(
void* (*init_function)(void* server_context),
void (*receive_from_remote_async_function)(TF_ParsedKey*,
TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context),
void (*delete_function)(void* context));
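A minimal sketch (not part of this header) of a receive callback, modeled on the test added in this commit; it assumes rendezvous_internal.h is included so the TF_RendezvousDoneCallback members are visible, and the my_* / Example* names and dummy tensor contents are illustrative only.

static void ExampleDealloc(void* data, size_t, void*) {
  delete[] static_cast<float*>(data);
}

void my_receive_from_remote_async(TF_ParsedKey* key, TF_RendezvousArgs* args,
                                  TF_RendezvousDoneCallback* callback,
                                  void* context) {
  // Produce a dummy 2-element float tensor as the "received" value.
  const int64_t dims[] = {2};
  float* values = new float[2]{1.0f, 2.0f};
  TF_Tensor* tensor = TF_NewTensor(TF_FLOAT, dims, 1, values,
                                   2 * sizeof(float), &ExampleDealloc, nullptr);
  TF_Status* status = TF_NewStatus();
  TF_SetStatus(status, TF_OK, "");
  callback->tensor = tensor;
  callback->status = status;
  TF_RendezvousDone(callback);
  // TF_RendezvousDone does not free `tensor` and `status`; the caller does.
  TF_DeleteStatus(status);
  TF_DeleteTensor(tensor);
}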
// Deletes TF_RemoteRendezvousBuilder instances.
TF_CAPI_EXPORT extern void TF_DeleteRemoteRendezvousBuilder(
TF_RemoteRendezvousBuilder* rendezvous_builder);
// Invokes the stored done callback and destroys the callback instance and
// its members except `tensor` and `status`. The caller is responsible for
// deleting `tensor` and `status` after TF_RendezvousDone returns.
TF_CAPI_EXPORT extern void TF_RendezvousDone(
TF_RendezvousDoneCallback* callback);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_

@ -0,0 +1,135 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_
#define TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_
#include <stddef.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/rendezvous.h"
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/platform/macros.h"
struct TF_ParsedKey {
// char* members might not be null-terminated; see the sketch after this
// struct for converting them to strings.
const char* src_device;
size_t src_device_len;
const char* dst_device;
size_t dst_device_len;
const char* full_key;
size_t full_key_len;
};
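Because these fields are length-delimited rather than null-terminated, consumers should pair each pointer with its length; a minimal sketch (the helper names are illustrative only):

#include <string>

std::string ParsedKeySrcDevice(const TF_ParsedKey* key) {
  return std::string(key->src_device, key->src_device_len);
}

std::string ParsedKeyFullKey(const TF_ParsedKey* key) {
  return std::string(key->full_key, key->full_key_len);
}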
struct TF_AllocatorAttributes {
bool on_host;
bool nic_compatible;
// NOTE: The upper 8 bits of the value are reserved for
// device-specific uses. Implementors of a device can interpret these
// upper 8 bits in device-specific ways, and ops implemented for those
// devices are responsible for setting those 8 bits appropriately.
tensorflow::uint32 value = 0;
// EXPERIMENTAL: If this is greater than zero, then allocation is delegated to
// a named special-purpose allocator on the same device.
tensorflow::int32 scope_id = 0;
};
struct TF_DeviceContext {
::tensorflow::DeviceContext* context;
};
struct TF_RendezvousArgs {
const TF_DeviceContext* device_context;
const TF_AllocatorAttributes* alloc_attrs;
};
struct TF_RendezvousDoneCallback {
::tensorflow::Rendezvous::DoneCallback done_callback;
// TODO(annarev): figure out if we should also support sent_args.
const TF_RendezvousArgs* recv_args;
TF_Tensor* tensor = nullptr;
TF_Status* status;
bool dead;
};
struct TF_RemoteRendezvousBuilder {
void* (*init_function)(void* server_context);
void (*receive_from_remote_async_function)(TF_ParsedKey*, TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context);
void (*delete_function)(void* context);
void* server_context;
};
namespace tensorflow {
class CRemoteRendezvous : public BaseRemoteRendezvous {
public:
CRemoteRendezvous(const WorkerEnv* env, int64 step_id,
void (*receive_from_remote_async_function)(
TF_ParsedKey*, TF_RendezvousArgs*,
TF_RendezvousDoneCallback*, void* context),
void (*delete_function)(void* context),
void* server_context);
void SetContext(void* context) { context_ = context; }
protected:
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
DoneCallback done) override;
private:
~CRemoteRendezvous() override;
void (*receive_from_remote_async_function_)(TF_ParsedKey*, TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context);
void (*delete_function_)(void* context);
void* context_;
TF_DISALLOW_COPY_AND_ASSIGN(CRemoteRendezvous);
};
class CRendezvousMgr : public BaseRendezvousMgr {
public:
CRendezvousMgr(const WorkerEnv* env,
const TF_RemoteRendezvousBuilder* rendezvous_builder)
: BaseRendezvousMgr(env), rendezvous_builder_(rendezvous_builder) {}
protected:
BaseRemoteRendezvous* Create(int64 step_id,
const WorkerEnv* worker_env) override {
auto* rendezvous = new CRemoteRendezvous(
worker_env, step_id,
rendezvous_builder_->receive_from_remote_async_function,
rendezvous_builder_->delete_function,
rendezvous_builder_->server_context);
rendezvous->SetContext(rendezvous_builder_->init_function(
rendezvous_builder_->server_context));
return rendezvous;
}
private:
const TF_RemoteRendezvousBuilder* rendezvous_builder_;
TF_DISALLOW_COPY_AND_ASSIGN(CRendezvousMgr);
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_

tensorflow/c/ops.cc (new file, 326 lines)
@ -0,0 +1,326 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/ops.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/shape_inference.h"
using ::tensorflow::DataType;
using ::tensorflow::OpDef;
using ::tensorflow::OpDeprecation;
using ::tensorflow::OpShapeInferenceFn;
using ::tensorflow::Set_TF_Status_from_Status;
using ::tensorflow::Status;
using ::tensorflow::shape_inference::DimensionHandle;
using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeHandle;
typedef struct TF_OpDefinitionBuilder {
// The op definition proto representing the op.
tensorflow::OpDef op_def;
// The shape inference function, or nullptr if none is provided for this op.
OpShapeInferenceFn shape_inference_func;
} TF_OpDefinitionBuilder;
TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder(const char* op_name) {
auto* result = new TF_OpDefinitionBuilder;
result->op_def.set_name(op_name);
return result;
}
void TF_DeleteOpDefinitionBuilder(TF_OpDefinitionBuilder* builder) {
delete builder;
}
static void PopulateArg(OpDef::ArgDef* arg, const char* name,
TF_DataType type) {
arg->set_name(name);
arg->set_type(static_cast<DataType>(type));
}
void TF_OpDefinitionBuilderAddInput(TF_OpDefinitionBuilder* builder,
const char* name, TF_DataType type) {
PopulateArg(builder->op_def.add_input_arg(), name, type);
}
void TF_OpDefinitionBuilderAddOutput(TF_OpDefinitionBuilder* builder,
const char* name, TF_DataType type) {
PopulateArg(builder->op_def.add_output_arg(), name, type);
}
#define DEFINE_BUILDER_BOOL_SETTER(func_name, builder_setter_name, arg_name) \
void TF_OpDefinitionBuilder##func_name(TF_OpDefinitionBuilder* builder, \
bool arg_name) { \
builder->op_def.builder_setter_name(arg_name); \
}
DEFINE_BUILDER_BOOL_SETTER(SetIsCommutative, set_is_commutative, is_commutative)
DEFINE_BUILDER_BOOL_SETTER(SetIsAggregate, set_is_aggregate, is_aggregate)
DEFINE_BUILDER_BOOL_SETTER(SetIsStateful, set_is_stateful, is_stateful)
DEFINE_BUILDER_BOOL_SETTER(SetAllowsUninitializedInput,
set_allows_uninitialized_input,
allows_uninitialized_input)
static OpDef::AttrDef* AddAttribute(TF_OpDefinitionBuilder* builder,
const char* name, const char* type_name) {
OpDef::AttrDef* attr = builder->op_def.add_attr();
attr->set_name(name);
attr->set_type(type_name);
return attr;
}
#define DEFINE_ATTR_SETTER(attr_type, type_name, field_c_type, field_name) \
void TF_OpDefinitionBuilderAdd##attr_type##Attr( \
TF_OpDefinitionBuilder* builder, const char* name) { \
AddAttribute(builder, name, type_name); \
} \
\
void TF_OpDefinitionBuilderAdd##attr_type##AttrWithDefaultValue( \
TF_OpDefinitionBuilder* builder, const char* name, \
field_c_type field_name) { \
OpDef::AttrDef* attr = AddAttribute(builder, name, type_name); \
attr->mutable_default_value()->set_##field_name(field_name); \
} \
\
void TF_OpDefinitionBuilderAdd##attr_type##ListAttrWithDefaultValues( \
TF_OpDefinitionBuilder* builder, const char* name, \
field_c_type field_name[], size_t n) { \
OpDef::AttrDef* attr = AddAttribute(builder, name, "list(" type_name ")"); \
for (int _i = 0; _i < n; ++_i) { \
attr->mutable_default_value()->mutable_list()->add_##field_name( \
field_name[_i]); \
} \
} \
\
void TF_OpDefinitionBuilderAdd##attr_type##ListAttr( \
TF_OpDefinitionBuilder* builder, const char* name) { \
TF_OpDefinitionBuilderAdd##attr_type##ListAttrWithDefaultValues( \
builder, name, NULL, 0); \
}
DEFINE_ATTR_SETTER(String, "string", const char*, s)
DEFINE_ATTR_SETTER(Int, "int", int64_t, i)
DEFINE_ATTR_SETTER(Float, "float", float, f)
DEFINE_ATTR_SETTER(Bool, "bool", bool, b)
void TF_OpDefinitionBuilderDeprecated(TF_OpDefinitionBuilder* builder,
int version, const char* explanation) {
OpDeprecation* dep = builder->op_def.mutable_deprecation();
dep->set_version(version);
dep->set_explanation(explanation);
}
void TF_RegisterOpDefinition(TF_OpDefinitionBuilder* builder,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
::tensorflow::OpRegistry::Global()->Register(
[builder](::tensorflow::OpRegistrationData* op_reg_data) -> Status {
op_reg_data->op_def.Clear();
op_reg_data->op_def.MergeFrom(builder->op_def);
op_reg_data->shape_inference_fn = builder->shape_inference_func;
return Status::OK();
});
// Calling ProcessRegistrations ensures that the cc_builder's finalize method
// is called and that the builder can be deleted.
Set_TF_Status_from_Status(
status, ::tensorflow::OpRegistry::Global()->ProcessRegistrations());
delete builder;
}
void TF_OpDefinitionBuilderSetShapeInferenceFunction(
TF_OpDefinitionBuilder* builder,
void (*shape_inference_func)(TF_ShapeInferenceContext* ctx,
TF_Status* status)) {
builder->shape_inference_func =
[shape_inference_func](InferenceContext* ctx) -> tensorflow::Status {
TF_Status* c_status = TF_NewStatus();
auto c_ctx = reinterpret_cast<TF_ShapeInferenceContext*>(ctx);
shape_inference_func(c_ctx, c_status);
tensorflow::Status result = ::tensorflow::StatusFromTF_Status(c_status);
TF_DeleteStatus(c_status);
return result;
};
}
TF_ShapeHandle* TF_NewShapeHandle() {
return reinterpret_cast<TF_ShapeHandle*>(new ShapeHandle);
}
TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
TF_ShapeInferenceContext* ctx, size_t size) {
auto* handle = new ShapeHandle;
*handle = reinterpret_cast<InferenceContext*>(ctx)->Vector(size);
return reinterpret_cast<TF_ShapeHandle*>(handle);
}
void TF_ShapeInferenceContextConcatenateShapes(TF_ShapeInferenceContext* ctx,
TF_ShapeHandle* first,
TF_ShapeHandle* second,
TF_ShapeHandle* result,
TF_Status* status) {
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
Status s = cc_ctx->Concatenate(*reinterpret_cast<ShapeHandle*>(first),
*reinterpret_cast<ShapeHandle*>(second),
reinterpret_cast<ShapeHandle*>(result));
Set_TF_Status_from_Status(status, s);
}
TF_DimensionHandle* TF_NewDimensionHandle() {
return reinterpret_cast<TF_DimensionHandle*>(new DimensionHandle);
}
int64_t TF_ShapeInferenceContextNumInputs(TF_ShapeInferenceContext* ctx) {
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
return cc_ctx->num_inputs();
}
void TF_ShapeInferenceContextGetInput(TF_ShapeInferenceContext* ctx, int i,
TF_ShapeHandle* handle,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
if (i < 0 || i >= cc_ctx->num_inputs()) {
TF_SetStatus(status, TF_INVALID_ARGUMENT, "input index out of range");
}
if (TF_GetCode(status) == TF_OK) {
auto* cc_result = reinterpret_cast<ShapeHandle*>(handle);
*cc_result = cc_ctx->input(i);
}
}
int TF_ShapeInferenceContextRankKnown(TF_ShapeInferenceContext* ctx,
TF_ShapeHandle* handle) {
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
return cc_ctx->RankKnown(*reinterpret_cast<ShapeHandle*>(handle));
}
void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx, int i,
TF_ShapeHandle* handle,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
if (i < 0 || i >= cc_ctx->num_outputs()) {
TF_SetStatus(status, TF_INVALID_ARGUMENT, "output index out of range");
}
if (TF_GetCode(status) == TF_OK) {
cc_ctx->set_output(i, *(reinterpret_cast<ShapeHandle*>(handle)));
}
}
void TF_DeleteShapeHandle(TF_ShapeHandle* handle) {
if (handle == nullptr) {
return;
}
delete reinterpret_cast<ShapeHandle*>(handle);
}
void TF_DeleteDimensionHandle(TF_DimensionHandle* handle) {
if (handle == nullptr) {
return;
}
delete reinterpret_cast<DimensionHandle*>(handle);
}
#define DEFINE_TF_GETATTR(func, c_type, cc_type) \
void TF_ShapeInferenceContext_GetAttr##func( \
TF_ShapeInferenceContext* ctx, const char* attr_name, c_type* val, \
TF_Status* status) { \
TF_SetStatus(status, TF_OK, ""); \
cc_type v; \
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); \
Status s = cc_ctx->GetAttr(attr_name, &v); \
Set_TF_Status_from_Status(status, s); \
if (s.ok()) { \
*val = static_cast<c_type>(v); \
} \
}
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
#define DEFINE_RANK_FUNC(func_name) \
void TF_ShapeInferenceContext##func_name( \
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank, \
TF_ShapeHandle* result, TF_Status* status) { \
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); \
auto* cc_handle = reinterpret_cast<ShapeHandle*>(handle); \
auto* cc_result = reinterpret_cast<ShapeHandle*>(result); \
Status s = cc_ctx->func_name(*cc_handle, rank, cc_result); \
Set_TF_Status_from_Status(status, s); \
}
DEFINE_RANK_FUNC(WithRank)
DEFINE_RANK_FUNC(WithRankAtLeast)
DEFINE_RANK_FUNC(WithRankAtMost)
int64_t TF_ShapeInferenceContextRank(TF_ShapeInferenceContext* ctx,
TF_ShapeHandle* handle) {
return reinterpret_cast<InferenceContext*>(ctx)->Rank(
*reinterpret_cast<ShapeHandle*>(handle));
}
void TF_ShapeInferenceContextDim(TF_ShapeInferenceContext* ctx,
TF_ShapeHandle* shape_handle, int64_t i,
TF_DimensionHandle* result) {
int64_t rank = TF_ShapeInferenceContextRank(ctx, shape_handle);
auto* cc_result = reinterpret_cast<DimensionHandle*>(result);
if (i < -rank || i >= rank) {
*cc_result = DimensionHandle();
return;
}
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
auto* cc_shape_handle = reinterpret_cast<ShapeHandle*>(shape_handle);
*cc_result = cc_ctx->Dim(*cc_shape_handle, i);
}
int TF_DimensionHandleValueKnown(TF_DimensionHandle* dim_handle) {
return InferenceContext::ValueKnown(
*reinterpret_cast<DimensionHandle*>(dim_handle));
}
void TF_ShapeInferenceContextSetUnknownShape(TF_ShapeInferenceContext* ctx,
TF_Status* status) {
Status s = ::tensorflow::shape_inference::UnknownShape(
reinterpret_cast<InferenceContext*>(ctx));
Set_TF_Status_from_Status(status, s);
}
void TF_ShapeInferenceContextSubshape(TF_ShapeInferenceContext* ctx,
TF_ShapeHandle* shape_handle,
int64_t start, int64_t end,
TF_ShapeHandle* result,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
auto* cc_result = reinterpret_cast<ShapeHandle*>(result);
Status s = cc_ctx->Subshape(*reinterpret_cast<ShapeHandle*>(shape_handle),
start, end, cc_result);
Set_TF_Status_from_Status(status, s);
}
int64_t TF_DimensionHandleValue(TF_DimensionHandle* dim_handle) {
return InferenceContext::Value(
*reinterpret_cast<DimensionHandle*>(dim_handle));
}

tensorflow/c/ops.h (new file, 407 lines)
@ -0,0 +1,407 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Routines for registering new ops and for implementing op shape inference
// functions.
//
// This API is alpha software and is subject to change.
//
// REGISTRATION
// ------------
//
// In order to register a new op, create a new TF_OpDefinitionBuilder:
//
// TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("OpName");
//
// Inputs, outputs and attributes can be added to the builder with the
// corresponding functions, e.g.
//
// TF_OpDefinitionBuilderAddInput(builder, "input1", TF_INT32);
// TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_INT64);
// TF_OpDefinitionBuilderAddIntAttr(builder, "attr1");
//
// The builder may then be registered with TensorFlow using the
// TF_RegisterOpDefinition function. E.g.
//
// TF_Status* status = TF_NewStatus();
// TF_RegisterOpDefinition(builder, status);
// if (TF_GetCode(status) != TF_OK) {
// // handle error
// }
//
// SHAPE INFERENCE
// ---------------
//
// You can provide a shape inference function that TensorFlow will call when it
// wants to understand the shape of outputs that the op will produce. Use the
// TF_OpDefinitionBuilderSetShapeInferenceFunction function to register a shape
// inference function pointer with TensorFlow. The following is an example of a
// very simple shape inference function:
//
// void identity_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
// TF_ShapeHandle* input = TF_NewShapeHandle();
// TF_ShapeInferenceContextGetInput(ctx, 0, input, status);
// if (TF_GetCode(status) == TF_OK) {
// TF_ShapeInferenceContextSetOutput(ctx, 0, input, status);
// }
// TF_DeleteShapeHandle(input);
// }
//
// The following code registers the inference function with TensorFlow:
//
// TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &identity_shape_fn);
//
// For more details about shape inference, see the documentation for
// TF_OpDefinitionBuilderSetShapeInferenceFunction.
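Putting registration and shape inference together, a sketch (not part of this header) that uses only functions declared in this file and in c_api.h; the op name "ExampleIdentity" and the function names are illustrative only.

// Output shape equals input shape.
static void ExampleIdentityShapeFn(TF_ShapeInferenceContext* ctx,
                                   TF_Status* status) {
  TF_ShapeHandle* input = TF_NewShapeHandle();
  TF_ShapeInferenceContextGetInput(ctx, 0, input, status);
  if (TF_GetCode(status) == TF_OK) {
    TF_ShapeInferenceContextSetOutput(ctx, 0, input, status);
  }
  TF_DeleteShapeHandle(input);
}

static void RegisterExampleIdentityOp() {
  TF_OpDefinitionBuilder* builder =
      TF_NewOpDefinitionBuilder("ExampleIdentity");
  TF_OpDefinitionBuilderAddInput(builder, "input", TF_INT32);
  TF_OpDefinitionBuilderAddOutput(builder, "output", TF_INT32);
  TF_OpDefinitionBuilderSetShapeInferenceFunction(builder,
                                                  &ExampleIdentityShapeFn);
  TF_Status* status = TF_NewStatus();
  TF_RegisterOpDefinition(builder, status);  // frees `builder` either way
  if (TF_GetCode(status) != TF_OK) {
    // handle the error, e.g. log TF_Message(status)
  }
  TF_DeleteStatus(status);
}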
#ifndef TENSORFLOW_C_OPS_H_
#define TENSORFLOW_C_OPS_H_
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include "tensorflow/c/c_api.h"
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)
#else
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
#endif // SWIG
#ifdef __cplusplus
extern "C" {
#endif
struct TF_DimensionHandle;
struct TF_OpDefinitionBuilder;
struct TF_ShapeHandle;
struct TF_ShapeInferenceContext;
// Returns a newly allocated op definition builder for the given op name. The
// returned builder may be customized with the `TF_OpDefinitionBuilder...`
// functions and then registered with TensorFlow with TF_RegisterOpDefinition.
//
// The returned pointer is either freed by a call to TF_RegisterOpDefinition, or
// can be manually deleted by TF_DeleteOpDefinitionBuilder if it is never
// registered.
TF_CAPI_EXPORT extern TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder(
const char* op_name);
// Registers the given op builder with TensorFlow. Indicates success or
// otherwise in the given status.
//
// `builder` is freed whether the op was successfully registered or not. You
// must call either this function or TF_DeleteOpDefinitionBuilder to free the
// builder, but never both.
TF_CAPI_EXPORT extern void TF_RegisterOpDefinition(
TF_OpDefinitionBuilder* builder, TF_Status* status);
// Frees the given op definition builder. You must call either this function or
// TF_RegisterOpDefinition to free the builder, but never both.
TF_CAPI_EXPORT extern void TF_DeleteOpDefinitionBuilder(
TF_OpDefinitionBuilder* builder);
//----------------------------------------------------
// Attribute functions.
// Adds a string attribute with the given name to the builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds a string attribute with the given name and default value to the builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringAttrWithDefaultValue(
TF_OpDefinitionBuilder* builder, const char* name, const char* value);
// Adds a string list attribute with the given name and no default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringListAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds a string list attribute with the given default values to the builder.
// `values` must contain at least `n` elements.
TF_CAPI_EXPORT extern void
TF_OpDefinitionBuilderAddStringListAttrWithDefaultValues(
TF_OpDefinitionBuilder* builder, const char* name, const char* values[],
size_t n);
// Adds an integer attribute with the given name and no default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds an integer attribute with the given name and default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntAttrWithDefaultValue(
TF_OpDefinitionBuilder* builder, const char* name, int64_t value);
// Adds an integer list attribute with the given name and no default value to
// the builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntListAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds an integer list attribute with the given name and default values to the
// builder. `values` must contain at least `n` elements.
TF_CAPI_EXPORT extern void
TF_OpDefinitionBuilderAddIntListAttrWithDefaultValues(
TF_OpDefinitionBuilder* builder, const char* name, int64_t values[],
size_t n);
// Adds a float attribute with the given name and no default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds a float attribute with the given name and default value to the builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatAttrWithDefaultValue(
TF_OpDefinitionBuilder* builder, const char* name, float value);
// Adds a float list attribute with the given name and no default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatListAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds a float list attribute with the given name and default values to the
// builder. `values` must contain at least `n` elements.
TF_CAPI_EXPORT extern void
TF_OpDefinitionBuilderAddFloatListAttrWithDefaultValues(
TF_OpDefinitionBuilder* builder, const char* name, float values[],
size_t n);
// Adds a boolean attribute with the given name and no default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds a boolean attribute with the given name and default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolAttrWithDefaultValue(
TF_OpDefinitionBuilder* builder, const char* name, bool value);
// Adds a boolean list attribute with the given name and no default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolListAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds a boolean list attribute with the given name and default values to the
// builder. `values` must contain at least `n` elements.
TF_CAPI_EXPORT extern void
TF_OpDefinitionBuilderAddBoolListAttrWithDefaultValues(
TF_OpDefinitionBuilder* builder, const char* name, bool values[], size_t n);
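// Example attribute setup (sketch; the attribute names "T", "axis", and "modes"
// and their default values are illustrative assumptions, not fixed by this API):
//
//   TF_OpDefinitionBuilderAddStringAttr(builder, "T");
//   TF_OpDefinitionBuilderAddIntAttrWithDefaultValue(builder, "axis", 0);
//   const char* modes[] = {"fast", "exact"};
//   TF_OpDefinitionBuilderAddStringListAttrWithDefaultValues(builder, "modes",
//                                                            modes, 2);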
// Adds the input with the given name and type to the op.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddInput(
TF_OpDefinitionBuilder* builder, const char* name, TF_DataType type);
// Adds the output with the given name and type to the op.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddOutput(
TF_OpDefinitionBuilder* builder, const char* output, TF_DataType type);
// Sets the commutative property for the op built by the given builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsCommutative(
TF_OpDefinitionBuilder* builder, bool is_commutative);
// Sets the is_aggregate property of the builder to the given value.
//
// If is_aggregate is true, then the operation produced by this builder accepts
// N >= 2 inputs and produces 1 output, all of the same type. The operation
// should be associative and commutative, and should produce output with the
// same shape as the input. The optimizer may replace an aggregate op taking
// input from multiple
// devices with a tree of aggregate ops that aggregate locally within each
// device (and possibly within groups of nearby devices) before communicating.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsAggregate(
TF_OpDefinitionBuilder* builder, bool is_aggregate);
// Sets the is_stateful property of the builder to the given value.
//
// The op built by this builder is stateful if its behavior depends on some
// state beyond its input tensors (e.g. variable reading op) or if it has a
// side-effect (e.g. printing or asserting ops). Equivalently, stateless ops
// must always produce the same output for the same input and have no
// side-effects.
//
// By default Ops may be moved between devices. Stateful ops should either not
// be moved, or should only be moved if that state can also be moved (e.g. via
// some sort of save / restore). Stateful ops are guaranteed to never be
// optimized away by Common Subexpression Elimination (CSE).
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsStateful(
TF_OpDefinitionBuilder* builder, bool is_stateful);
// Sets the allows_uninitialized_input property of the operation built by this
// builder.
//
// By default, all inputs to an Op must be initialized Tensors. Ops that may
// initialize tensors for the first time should set this field to true, to allow
// the Op to take an uninitialized Tensor as input.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetAllowsUninitializedInput(
TF_OpDefinitionBuilder* builder, bool allows_uninitialized_input);
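// Example property configuration (sketch; the chosen values are illustrative,
// not recommendations):
//
//   TF_OpDefinitionBuilderSetIsCommutative(builder, true);
//   TF_OpDefinitionBuilderSetIsStateful(builder, true);
//   TF_OpDefinitionBuilderSetAllowsUninitializedInput(builder, false);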
// Adds a deprecation warning for the given op. This indicates to the user that
// `version` is the first TensorFlow GraphDef version for which the operation is
// deprecated. `explanation` should contain the reason for the deprecation and
// what to use instead.
//
// This function is only an indicator that the operation may disappear in a
// version of TensorFlow after `version`. It does not affect op registration.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderDeprecated(
TF_OpDefinitionBuilder* builder, int version, const char* explanation);
// Sets the shape inference function for the op.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetShapeInferenceFunction(
TF_OpDefinitionBuilder* builder,
void (*shape_inference_func)(TF_ShapeInferenceContext* ctx,
TF_Status* status));
//----------------------------------------------------
// Functions for TF_ShapeInferenceContext.
//
// Functions for implementing shape inference functions. TensorFlow uses these
// functions to determine the shape of tensors produced by an operation without
// having to actually run the operation. If an operation chooses to provide a
// shape inference function, it will be invoked by TensorFlow as needed.
//
// When invoked by TensorFlow, the shape inference function is provided with a
// TF_ShapeInferenceContext pointer. The function's implementation will use the
// accessor and mutator functions with names beginning with
// TF_ShapeInferenceContext to examine the input state and determine the output
// shape.
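// Example shape inference function (a sketch assuming a single-input,
// single-output op; it mirrors the vectorizing test in ops_test.cc and is not
// part of this header):
//
//   void example_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
//     TF_ShapeHandle* in = TF_NewShapeHandle();
//     TF_ShapeInferenceContextGetInput(ctx, 0, in, status);
//     if (TF_GetCode(status) == TF_OK) {
//       // Emit a vector whose length is the rank of input 0.
//       TF_ShapeHandle* out = TF_ShapeInferenceContextVectorFromSize(
//           ctx, TF_ShapeInferenceContextRank(ctx, in));
//       TF_ShapeInferenceContextSetOutput(ctx, 0, out, status);
//       TF_DeleteShapeHandle(out);
//     }
//     TF_DeleteShapeHandle(in);
//   }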
// Returns the number of inputs in the given shape inference context.
TF_CAPI_EXPORT extern int64_t TF_ShapeInferenceContextNumInputs(
TF_ShapeInferenceContext* ctx);
// Returns a newly allocated shape handle. The shapes represented by these
// handles may be queried or mutated with the corresponding
// TF_ShapeInferenceContext... functions.
TF_CAPI_EXPORT extern TF_ShapeHandle* TF_NewShapeHandle();
// Places the ith input of the given shape inference context into the given
// shape handle, or returns a status other than TF_OK indicating why the input
// could not be retrieved (for example, if
// i < 0 || i >= TF_ShapeInferenceContextNumInputs(ctx)).
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextGetInput(
TF_ShapeInferenceContext* ctx, int i, TF_ShapeHandle* handle,
TF_Status* status);
// Places the given shape handle into the `i`th output position of the given
// context. Internally, the shape handle is copied; the caller may subsequently
// delete `handle`.
TF_CAPI_EXPORT
extern void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx,
int i, TF_ShapeHandle* handle,
TF_Status* status);
// Returns a newly-allocated shape handle representing a vector of the given
// size. The returned handle should be freed with TF_DeleteShapeHandle.
TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
TF_ShapeInferenceContext* ctx, size_t size);
// Returns a newly allocated dimension handle. It must be freed with
// TF_DeleteDimensionHandle.
TF_CAPI_EXPORT extern TF_DimensionHandle* TF_NewDimensionHandle();
// Interprets the named shape inference context attribute as a TF_DataType and
// places it into *val, setting *status to TF_OK on success.
//
// If the attribute cannot be found or cannot be interpreted as a TF_DataType,
// *status is populated with an error.
TF_CAPI_EXPORT extern void TF_ShapeInferenceContext_GetAttrType(
TF_ShapeInferenceContext* ctx, const char* attr_name, TF_DataType* val,
TF_Status* status);
// Returns the rank of the shape represented by the given handle.
TF_CAPI_EXPORT extern int64_t TF_ShapeInferenceContextRank(
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle);
// Returns 1 if `handle` has a known rank, 0 otherwise.
TF_CAPI_EXPORT extern int TF_ShapeInferenceContextRankKnown(
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle);
// If <handle> has rank <rank>, or its rank is unknown, returns TF_OK in
// `status` and places the shape with asserted rank in <*result>. Otherwise an
// error is placed into `status`.
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextWithRank(
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank,
TF_ShapeHandle* result, TF_Status* status);
// If <handle> has rank at least <rank>, or its rank is unknown, returns TF_OK
// in `status` and places the shape with asserted rank in <*result>. Otherwise
// an error is placed into `status`.
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextWithRankAtLeast(
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank,
TF_ShapeHandle* result, TF_Status* status);
// If <handle> has rank at most <rank>, or its rank is unknown, returns TF_OK
// in `status` and places the shape with asserted rank in <*result>. Otherwise
// an error is placed into `status`.
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextWithRankAtMost(
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank,
TF_ShapeHandle* result, TF_Status* status);
// Places a handle to the ith dimension of the given shape into *result.
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextDim(
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* shape_handle, int64_t i,
TF_DimensionHandle* result);
// Returns 1 if the given handle represents a known dimension.
TF_CAPI_EXPORT extern int TF_ShapeInferenceContextDimValueKnown(
TF_ShapeInferenceContext* ctx, TF_DimensionHandle* handle);
// Returns in <*result> a sub-shape of <shape_handle>, with dimensions
// [start:end]. <start> and <end> can be negative, to index from the end of the
// shape. <start> and <end> are clamped to the rank of <shape_handle> if they
// exceed it.
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextSubshape(
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* shape_handle, int64_t start,
int64_t end, TF_ShapeHandle* result, TF_Status* status);
// Places an unknown shape in all outputs for the given inference context. Used
// for shape inference functions with ops whose output shapes are unknown.
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextSetUnknownShape(
TF_ShapeInferenceContext* ctx, TF_Status* status);
// Returns 1 if the given dimension handle represents a known dimension, 0
// otherwise.
TF_CAPI_EXPORT extern int TF_DimensionHandleValueKnown(
TF_DimensionHandle* dim_handle);
// Returns the value of the given dimension.
TF_CAPI_EXPORT extern int64_t TF_DimensionHandleValue(
TF_DimensionHandle* dim_handle);
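// Example dimension query (sketch; assumes `shape` is a TF_ShapeHandle already
// populated from one of the inputs):
//
//   TF_DimensionHandle* dim = TF_NewDimensionHandle();
//   TF_ShapeInferenceContextDim(ctx, shape, 0, dim);
//   if (TF_DimensionHandleValueKnown(dim)) {
//     int64_t size = TF_DimensionHandleValue(dim);
//     // ... use `size` ...
//   }
//   TF_DeleteDimensionHandle(dim);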
// Returns in <*result> the result of appending the dimensions of <second> to
// those of <first>.
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextConcatenateShapes(
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* first,
TF_ShapeHandle* second, TF_ShapeHandle* result, TF_Status* status);
// Frees the given shape handle.
TF_CAPI_EXPORT extern void TF_DeleteShapeHandle(TF_ShapeHandle* handle);
// Frees the given dimension handle.
TF_CAPI_EXPORT extern void TF_DeleteDimensionHandle(TF_DimensionHandle* handle);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_OPS_H_

159
tensorflow/c/ops_test.cc Normal file
View File

@ -0,0 +1,159 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/ops.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
TEST(OpsTest, TestBasicOpRegistration) {
TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("SomeOp");
TF_OpDefinitionBuilderAddStringAttr(builder, "attr1");
TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8);
TF_OpDefinitionBuilderAddInput(builder, "input2", TF_UINT16);
TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT32);
TF_Status* status = TF_NewStatus();
TF_RegisterOpDefinition(builder, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Buffer* op_list_buffer = TF_GetAllOpList();
::tensorflow::OpList op_list;
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
bool found = false;
for (const auto& op : op_list.op()) {
if (op.name() == "SomeOp") {
ASSERT_EQ(2, op.input_arg_size());
ASSERT_EQ("input1", op.input_arg(0).name());
ASSERT_EQ(::tensorflow::DT_UINT8, op.input_arg(0).type());
ASSERT_EQ(1, op.attr_size());
ASSERT_EQ("string", op.attr(0).type());
found = true;
}
}
EXPECT_TRUE(found);
TF_DeleteStatus(status);
TF_DeleteBuffer(op_list_buffer);
}
void identity_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
TF_ShapeHandle* handle = TF_NewShapeHandle();
TF_ShapeInferenceContextGetInput(ctx, 0, handle, status);
ASSERT_EQ(TF_OK, TF_GetCode(status));
TF_ShapeInferenceContextSetOutput(ctx, 0, handle, status);
TF_DeleteShapeHandle(handle);
}
TEST(OpsTest, TestShapeInference_IdentityFunction) {
ShapeInferenceTestOp op("SomeTestOp");
TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("SomeTestOp");
TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8);
TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT8);
TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &identity_shape_fn);
TF_Status* status = TF_NewStatus();
TF_RegisterOpDefinition(builder, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_ASSERT_OK(
shape_inference::ShapeInferenceTestutil::InferShapes(op, "[1,2]", "in0"));
TF_DeleteStatus(status);
}
// Creates an output whose shape is a vector whose length equals the rank of
// the op's first input (as reported by TF_ShapeInferenceContextRank).
void vectorize_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
TF_ShapeHandle* handle = TF_NewShapeHandle();
TF_ShapeInferenceContextGetInput(ctx, 0, handle, status);
ASSERT_EQ(TF_OK, TF_GetCode(status));
TF_ShapeHandle* new_shape = TF_ShapeInferenceContextVectorFromSize(
ctx, TF_ShapeInferenceContextRank(ctx, handle));
TF_ShapeInferenceContextSetOutput(ctx, 0, new_shape, status);
TF_DeleteShapeHandle(handle);
TF_DeleteShapeHandle(new_shape);
}
TEST(OpsTest, TestShapeInference_VectorizeFunction) {
ShapeInferenceTestOp op("VectorizeTestOp");
TF_OpDefinitionBuilder* builder =
TF_NewOpDefinitionBuilder("VectorizeTestOp");
TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8);
TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT8);
TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &vectorize_shape_fn);
TF_Status* status = TF_NewStatus();
TF_RegisterOpDefinition(builder, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_ASSERT_OK(shape_inference::ShapeInferenceTestutil::InferShapes(
op, "[4,5,9]", "[3]"));
TF_DeleteStatus(status);
}
TEST(OpsTest, AttributeAccessors) {
TF_OpDefinitionBuilder* builder =
TF_NewOpDefinitionBuilder("AttributeAccesorsOp");
float values[] = {1, 2, 3, 4};
TF_OpDefinitionBuilderAddFloatListAttrWithDefaultValues(
builder, "foo1", values, sizeof(values));
TF_OpDefinitionBuilderAddStringAttrWithDefaultValue(builder, "foo2",
"my string");
TF_OpDefinitionBuilderSetIsCommutative(builder, true);
TF_OpDefinitionBuilderSetIsAggregate(builder, true);
TF_OpDefinitionBuilderSetAllowsUninitializedInput(builder, true);
std::string deprecation_msg = "use something else instead";
TF_OpDefinitionBuilderDeprecated(builder, 4, deprecation_msg.c_str());
TF_Status* status = TF_NewStatus();
TF_RegisterOpDefinition(builder, status);
ASSERT_EQ(TF_OK, TF_GetCode(status));
TF_Buffer* op_list_buffer = TF_GetAllOpList();
::tensorflow::OpList op_list;
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
bool found = false;
for (const auto& op : op_list.op()) {
if (op.name() == "AttributeAccesorsOp") {
ASSERT_TRUE(op.is_commutative());
ASSERT_TRUE(op.is_aggregate());
ASSERT_TRUE(op.allows_uninitialized_input());
ASSERT_EQ(4, op.deprecation().version());
ASSERT_EQ(deprecation_msg, op.deprecation().explanation());
ASSERT_EQ(2, op.attr_size());
ASSERT_EQ("list(float)", op.attr(0).type());
AttrValue::ListValue l = op.attr(0).default_value().list();
ASSERT_EQ(1, l.f(0));
ASSERT_EQ(2, l.f(1));
ASSERT_EQ(3, l.f(2));
ASSERT_EQ(4, l.f(3));
ASSERT_EQ("string", op.attr(1).type());
ASSERT_EQ("my string", op.attr(1).default_value().s());
found = true;
}
}
ASSERT_TRUE(found);
TF_DeleteStatus(status);
TF_DeleteBuffer(op_list_buffer);
}
} // namespace
} // namespace tensorflow

View File

@ -203,6 +203,7 @@ tf_cc_test(
deps = [ deps = [
":ops", ":ops",
":scope", ":scope",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",

View File

@ -42,14 +42,19 @@ namespace {
const int kRightMargin = 79; const int kRightMargin = 79;
// Converts: // Converts:
// bazel-out/.../genfiles/(external/YYY/)?XX // bazel-out/.../(bin|genfiles)/(external/YYY/)?XX
// to: XX. // to: XX.
string GetPath(const string& dot_h_fname) { string GetPath(const string& dot_h_fname) {
auto pos = dot_h_fname.find("/genfiles/"); auto pos = dot_h_fname.find("/bin/");
string result = dot_h_fname; string result = dot_h_fname;
if (pos != string::npos) { if (pos != string::npos) {
// - 1 account for the terminating null character (\0) in "/genfiles/". // - 1 account for the terminating null character (\0) in "/genfiles/".
result = dot_h_fname.substr(pos + sizeof("/genfiles/") - 1); result = dot_h_fname.substr(pos + sizeof("/bin/") - 1);
} else {
pos = dot_h_fname.find("/genfiles/");
if (pos != string::npos) {
result = dot_h_fname.substr(pos + sizeof("/genfiles/") - 1);
}
} }
if (result.size() > sizeof("external/") && if (result.size() > sizeof("external/") &&
result.compare(0, sizeof("external/") - 1, "external/") == 0) { result.compare(0, sizeof("external/") - 1, "external/") == 0) {

View File

@ -531,4 +531,23 @@ Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner) {
return InternalScope::NewScope(graph, status, refiner); return InternalScope::NewScope(graph, status, refiner);
} }
Status CreateOutputWithScope(string op_name,
absl::Span<const ::tensorflow::Input> inputs,
const Scope& scope, Output* output) {
TF_RETURN_IF_ERROR(scope.status());
const auto unique_name = scope.GetUniqueNameForOp(op_name);
auto builder = ::tensorflow::NodeBuilder(unique_name, op_name);
for (auto input : inputs) {
TF_RETURN_IF_ERROR(scope.status());
builder = builder.Input(input.node());
}
::tensorflow::Node* ret;
scope.UpdateBuilder(&builder);
TF_RETURN_IF_ERROR(scope.status());
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
TF_RETURN_IF_ERROR(scope.status());
*output = Output(ret, 0);
return Status::OK();
}
} // namespace tensorflow } // namespace tensorflow

View File

@ -255,6 +255,12 @@ struct CompositeOpScopes {
Scope last; Scope last;
}; };
// Creates a node of the given operation, with the given inputs, and assigns the
// result to output. This does not support adding additional attributes.
Status CreateOutputWithScope(string op_name,
absl::Span<const ::tensorflow::Input> inputs,
const Scope& scope, Output* output);
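// Example use (sketch; it mirrors the unit test in scope_test.cc and is
// illustrative only):
//
//   Scope root = Scope::NewRootScope();
//   Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
//   Output sum;
//   Status s =
//       CreateOutputWithScope("Add", {a, a}, root.WithOpName("add"), &sum);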
/// @} /// @}
} // namespace tensorflow } // namespace tensorflow

View File

@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
namespace tensorflow { namespace tensorflow {
@ -145,4 +147,14 @@ TEST(ScopeTest, ControlDeps) {
EXPECT_EQ(c_c.control_deps().size(), 3); EXPECT_EQ(c_c.control_deps().size(), 3);
} }
TEST(ScopeTest, CreateOutput) {
Scope root = Scope::NewRootScope();
Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
Output add;
ASSERT_TRUE(
CreateOutputWithScope("Add", {a, a}, root.WithOpName("add"), &add).ok());
EXPECT_EQ(add.node()->name(), "add");
EXPECT_EQ(add.node()->type_string(), "Add");
}
} // namespace tensorflow } // namespace tensorflow

View File

@ -18,27 +18,41 @@ from __future__ import absolute_import as _absolute_import
from __future__ import division as _division from __future__ import division as _division
from __future__ import print_function as _print_function from __future__ import print_function as _print_function
import logging as _logging
import os as _os import os as _os
import sys as _sys import sys as _sys
from tensorflow.python.tools import module_util as _module_util
# pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
# API IMPORTS PLACEHOLDER # API IMPORTS PLACEHOLDER
from tensorflow.python.tools import component_api_helper as _component_api_helper # Hook external TensorFlow modules.
_component_api_helper.package_hook( _current_module = _sys.modules[__name__]
parent_package_str=__name__, try:
child_package_str=('tensorboard.summary._tf.summary'), from tensorboard.summary._tf import summary
error_msg=( _current_module.__path__ = (
"Limited tf.compat.v2.summary API due to missing TensorBoard " [_module_util.get_parent_dir(summary)] + _current_module.__path__)
"installation")) except ImportError:
_component_api_helper.package_hook( _logging.warning(
parent_package_str=__name__, "Limited tf.compat.v2.summary API due to missing TensorBoard "
child_package_str=( "installation.")
'tensorflow_estimator.python.estimator.api._v2.estimator'))
_component_api_helper.package_hook( try:
parent_package_str=__name__, from tensorflow_estimator.python.estimator.api._v2 import estimator
child_package_str=('tensorflow.python.keras.api._v2.keras')) _current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
except ImportError:
pass
try:
from tensorflow.python.keras.api._v2 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
except ImportError:
pass
# We would like the following to work for fully enabling 2.0 in a 1.0 install: # We would like the following to work for fully enabling 2.0 in a 1.0 install:
# #

View File

@ -19,18 +19,30 @@ from __future__ import division as _division
from __future__ import print_function as _print_function from __future__ import print_function as _print_function
import os as _os import os as _os
import sys as _sys
from tensorflow.python.tools import module_util as _module_util
# pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
# API IMPORTS PLACEHOLDER # API IMPORTS PLACEHOLDER
from tensorflow.python.tools import component_api_helper as _component_api_helper # Hook external TensorFlow modules.
_component_api_helper.package_hook( _current_module = _sys.modules[__name__]
parent_package_str=__name__, try:
child_package_str=( from tensorflow_estimator.python.estimator.api._v1 import estimator
'tensorflow_estimator.python.estimator.api._v1.estimator')) _current_module.__path__ = (
_component_api_helper.package_hook( [_module_util.get_parent_dir(estimator)] + _current_module.__path__)
parent_package_str=__name__, except ImportError:
child_package_str=('tensorflow.python.keras.api._v1.keras')) pass
try:
from tensorflow.python.keras.api._v1 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
except ImportError:
pass
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
app.flags = flags # pylint: disable=undefined-variable app.flags = flags # pylint: disable=undefined-variable

View File

@ -263,38 +263,23 @@ Status GenVariableMethods(const tf2xla::Config& config,
void set_var_{{NAME}}_data({{MAYBE_CONST}}{{TYPE}}* data) { void set_var_{{NAME}}_data({{MAYBE_CONST}}{{TYPE}}* data) {
set_arg_data({{I}}, data); set_arg_data({{I}}, data);
} }
)"; {{MAYBE_CONST}}{{TYPE}}* var_{{NAME}}_data() {
const tf2xla::Variable& var = config.variable(i - config.feed_size()); return static_cast<{{MAYBE_CONST}}{{TYPE}}*>(arg_data({{I}}));
rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");
*methods += RewriteWithName(
var.name().empty() ? var.node_name() : var.name(), code, rewrites);
} }
size_t num_results = ps.result().tuple_shapes_size(); {{MAYBE_CONST}}{{TYPE}}& var_{{NAME}}({{DIM_VARS}}) {
int variable_num = -1; return (*static_cast<{{MAYBE_CONST}}{{TYPE}}(*){{DIM_SIZES}}>(
for (int i = config.fetch_size(); i < num_results; ++i) { arg_data({{I}}))){{INDICES}};
std::vector<std::pair<string, string>> rewrites;
TF_RETURN_IF_ERROR(AddRewritesForShape(
i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites));
string code = R"(
{{TYPE}}* var_{{NAME}}_data() {
return static_cast<{{TYPE}}*>(result_data({{I}}));
}
{{TYPE}}& var_{{NAME}}({{DIM_VARS}}) {
return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
result_data({{I}}))){{INDICES}};
} }
const {{TYPE}}* var_{{NAME}}_data() const { const {{TYPE}}* var_{{NAME}}_data() const {
return static_cast<const {{TYPE}}*>(result_data({{I}})); return static_cast<const {{TYPE}}*>(arg_data({{I}}));
} }
const {{TYPE}}& var_{{NAME}}({{DIM_VARS}}) const { const {{TYPE}}& var_{{NAME}}({{DIM_VARS}}) const {
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>( return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
result_data({{I}}))){{INDICES}}; arg_data({{I}}))){{INDICES}};
} }
)"; )";
do { const tf2xla::Variable& var = config.variable(i - config.feed_size());
++variable_num; rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");
} while (config.variable(variable_num).readonly());
const tf2xla::Variable& var = config.variable(variable_num);
*methods += RewriteWithName( *methods += RewriteWithName(
var.name().empty() ? var.node_name() : var.name(), code, rewrites); var.name().empty() ? var.node_name() : var.name(), code, rewrites);
} }
@ -549,7 +534,8 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
return *kStaticData; return *kStaticData;
} }
{{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) {{CLASS}}(AllocMode alloc_mode =
AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS)
: XlaCompiledCpuFunction(StaticData(), alloc_mode) {} : XlaCompiledCpuFunction(StaticData(), alloc_mode) {}
{{CLASS}}(const {{CLASS}}&) = delete; {{CLASS}}(const {{CLASS}}&) = delete;
@ -590,19 +576,29 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
// buffers are managed internally, and may change after each call to Run. // buffers are managed internally, and may change after each call to Run.
{{METHODS_RESULT}} {{METHODS_RESULT}}
// Methods for managing variable buffers. Buffers are in row-major order. The // Methods for managing variable buffers. Buffers are in row-major order.
// input and output buffers may or may not be identical. //
// For read-write variables we generate the following methods:
// //
// void set_var_X_data(T* data) // void set_var_X_data(T* data)
// Sets the buffer for variable X. // Sets the buffer for variable X. Must be called before Run if the
// allocation mode is RESULTS_PROFILES_AND_TEMPS_ONLY.
// //
// T* var_X_data() // T* var_X_data()
// Returns the buffer of type T for variable X. // Returns the buffer of type T for variable X. If the allocation mode is
// RESULTS_PROFILES_AND_TEMPS_ONLY then this buffer is the same as the
// buffer passed to set_var_X_data.
// //
// T& var_X(...dim indices...) // T& var_X(...dim indices...)
// Returns a reference to the value of type T for variable X, // Returns a reference to the value of type T for variable X,
// with dim indices specifying which value. No bounds checking is performed // with dim indices specifying which value. No bounds checking is performed
// on dim indices. // on dim indices.
//
// For readonly variables we generate the same set of methods, except that we
// use `const T` instead of `T`. We use `const T` to avoid erasing the
// constness of the buffer passed to `set_var_X_data` but the underlying
// buffer is not const (and thus the const can be safely const-cast'ed away)
// unless `set_var_X_data` is called with a pointer to constant storage.
{{METHODS_VARIABLE}} {{METHODS_VARIABLE}}
private: private:

View File

@ -91,7 +91,8 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return *kStaticData; return *kStaticData;
} }
MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) MyClass(AllocMode alloc_mode =
AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS)
: XlaCompiledCpuFunction(StaticData(), alloc_mode) {} : XlaCompiledCpuFunction(StaticData(), alloc_mode) {}
MyClass(const MyClass&) = delete; MyClass(const MyClass&) = delete;
@ -214,60 +215,82 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
result_data(0)))[dim0][dim1]; result_data(0)))[dim0][dim1];
} }
// Methods for managing variable buffers. Buffers are in row-major order. The // Methods for managing variable buffers. Buffers are in row-major order.
// input and output buffers may or may not be identical. //
// For read-write variables we generate the following methods:
// //
// void set_var_X_data(T* data) // void set_var_X_data(T* data)
// Sets the buffer for variable X. // Sets the buffer for variable X. Must be called before Run if the
// allocation mode is RESULTS_PROFILES_AND_TEMPS_ONLY.
// //
// T* var_X_data() // T* var_X_data()
// Returns the buffer of type T for variable X. // Returns the buffer of type T for variable X. If the allocation mode is
// RESULTS_PROFILES_AND_TEMPS_ONLY then this buffer is the same as the
// buffer passed to set_var_X_data.
// //
// T& var_X(...dim indices...) // T& var_X(...dim indices...)
// Returns a reference to the value of type T for variable X, // Returns a reference to the value of type T for variable X,
// with dim indices specifying which value. No bounds checking is performed // with dim indices specifying which value. No bounds checking is performed
// on dim indices. // on dim indices.
//
// For readonly variables we generate the same set of methods, except that we
// use `const T` instead of `T`. We use `const T` to avoid erasing the
// constness of the buffer passed to `set_var_X_data` but the underlying
// buffer is not const (and thus the const can be safely const-cast'ed away)
// unless `set_var_X_data` is called with a pointer to constant storage.
void set_var_myvar_readonly_data(const float* data) { void set_var_myvar_readonly_data(const float* data) {
set_arg_data(2, data); set_arg_data(2, data);
} }
const float* var_myvar_readonly_data() {
return static_cast<const float*>(arg_data(2));
}
const float& var_myvar_readonly() {
return (*static_cast<const float(*)[1]>(
arg_data(2)))[0];
}
const float* var_myvar_readonly_data() const {
return static_cast<const float*>(arg_data(2));
}
const float& var_myvar_readonly() const {
return (*static_cast<const float(*)[1]>(
arg_data(2)))[0];
}
void set_var_myvar_data(float* data) { void set_var_myvar_data(float* data) {
set_arg_data(3, data); set_arg_data(3, data);
} }
float* var_myvar_data() {
return static_cast<float*>(arg_data(3));
}
float& var_myvar() {
return (*static_cast<float(*)[1]>(
arg_data(3)))[0];
}
const float* var_myvar_data() const {
return static_cast<const float*>(arg_data(3));
}
const float& var_myvar() const {
return (*static_cast<const float(*)[1]>(
arg_data(3)))[0];
}
void set_var_myvar2_data(tensorflow::int32* data) { void set_var_myvar2_data(tensorflow::int32* data) {
set_arg_data(4, data); set_arg_data(4, data);
} }
float* var_myvar_data() {
return static_cast<float*>(result_data(1));
}
float& var_myvar() {
return (*static_cast<float(*)[1]>(
result_data(1)))[0];
}
const float* var_myvar_data() const {
return static_cast<const float*>(result_data(1));
}
const float& var_myvar() const {
return (*static_cast<const float(*)[1]>(
result_data(1)))[0];
}
tensorflow::int32* var_myvar2_data() { tensorflow::int32* var_myvar2_data() {
return static_cast<tensorflow::int32*>(result_data(2)); return static_cast<tensorflow::int32*>(arg_data(4));
} }
tensorflow::int32& var_myvar2(size_t dim0) { tensorflow::int32& var_myvar2(size_t dim0) {
return (*static_cast<tensorflow::int32(*)[5]>( return (*static_cast<tensorflow::int32(*)[5]>(
result_data(2)))[dim0]; arg_data(4)))[dim0];
} }
const tensorflow::int32* var_myvar2_data() const { const tensorflow::int32* var_myvar2_data() const {
return static_cast<const tensorflow::int32*>(result_data(2)); return static_cast<const tensorflow::int32*>(arg_data(4));
} }
const tensorflow::int32& var_myvar2(size_t dim0) const { const tensorflow::int32& var_myvar2(size_t dim0) const {
return (*static_cast<const tensorflow::int32(*)[5]>( return (*static_cast<const tensorflow::int32(*)[5]>(
result_data(2)))[dim0]; arg_data(4)))[dim0];
} }
private: private:

View File

@ -36,6 +36,7 @@ py_binary(
name = "make_test_graphs", name = "make_test_graphs",
testonly = 1, testonly = 1,
srcs = ["make_test_graphs.py"], srcs = ["make_test_graphs.py"],
python_version = "PY2",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
"//tensorflow/core:protos_all_py", "//tensorflow/core:protos_all_py",

View File

@ -83,7 +83,8 @@ TEST(TFCompileTest, Add) {
// Run tests that use set_argN_data separately, to avoid accidentally re-using // Run tests that use set_argN_data separately, to avoid accidentally re-using
// non-existent buffers. // non-existent buffers.
TEST(TFCompileTest, Add_SetArg) { TEST(TFCompileTest, Add_SetArg) {
AddComp add(AddComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); AddComp add(
XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
int32 arg_x = 10; int32 arg_x = 10;
int32 arg_y = 32; int32 arg_y = 32;
@ -296,7 +297,7 @@ TEST(TFCompileTest, MatMul2_SetArg) {
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
foo::bar::MatMulComp matmul( foo::bar::MatMulComp matmul(
foo::bar::MatMulComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
matmul.set_thread_pool(&device); matmul.set_thread_pool(&device);
// Test using the set_argN_data() methods. // Test using the set_argN_data() methods.
@ -503,8 +504,36 @@ TEST(TFCompileTest, VariableSequentialUpdates) {
// This implements the recursion: // This implements the recursion:
// x[0] = 2.0 // x[0] = 2.0
// x[n+1] = x[n] - 0.1*(x[n-1] + 1.0) // x[n+1] = x[n] - 0.1*(x[n-1] + y)
VariableSequentialUpdatesComp fn; VariableSequentialUpdatesComp fn;
fn.var_x() = 2;
*const_cast<float*>(fn.var_y_data()) = 1;
fn.set_thread_pool(&device);
// First calculate x[3]
fn.Run();
EXPECT_NEAR(fn.var_x(), 1.187f, 1e-6);
const float y = 1;
fn.set_var_y_data(&y);
// Now const_cast<float*>(fn.var_y_data()) is no longer legal since we've set
// the buffer to point to a constant location.
// Then calculate x[6]
fn.Run();
EXPECT_NEAR(fn.var_x(), 0.594322f, 1e-6);
}
TEST(TFCompileTest, VariableSequentialUpdatesNoAlloc) {
Eigen::ThreadPool tp(1);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
// This implements the recursion:
// x[0] = 2.0
// x[n+1] = x[n] - 0.1*(x[n-1] + 1.0)
VariableSequentialUpdatesComp fn(
XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
float x = 2; float x = 2;
float y = 1; float y = 1;
fn.set_var_x_data(&x); fn.set_var_x_data(&x);

View File

@ -174,6 +174,20 @@ def tf_library(
"'" + arg.replace("'", "'\\''") + "'" "'" + arg.replace("'", "'\\''") + "'"
for arg in (tfcompile_flags or []) for arg in (tfcompile_flags or [])
]) ])
# Do this before we append the `select` into `flags`, because doing so
# transforms `flags` into a variable of type `select`, and we can't call
# `find` on such an object.
need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1
# Pass --target_cpu=haswell to tfcompile if compiling for Haswell (bazel
# build --cpu=haswell). We put it at the beginning of the flags list so
# that tfcompile_flags can override it if desired.
flags = select({
"//tools/target_cpu:haswell": "--target_cpu=haswell ",
"//conditions:default": "",
}) + flags
if enable_xla_hlo_profiling: if enable_xla_hlo_profiling:
profiling_flag = "--xla_hlo_profile" profiling_flag = "--xla_hlo_profile"
else: else:
@ -251,7 +265,6 @@ def tf_library(
# The cc_library rule packaging up the header and object file, and needed # The cc_library rule packaging up the header and object file, and needed
# kernel implementations. # kernel implementations.
need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1)
native.cc_library( native.cc_library(
name = name, name = name,
srcs = [function_object_file, metadata_object_file], srcs = [function_object_file, metadata_object_file],

View File

@ -17,15 +17,14 @@ package_group(
package( package(
default_visibility = [ default_visibility = [
":internal", ":internal",
# BEGIN-GOOGLE-INTERNAL
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
# END-GOOGLE-INTERNAL
], ],
) )
# NB! Removing the cc_header_only_library import breaks the OSS build since load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
# copybara injects some build rules that use it.
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
# Target that bundles up the XLA CPU and GPU JIT devices. # Target that bundles up the XLA CPU and GPU JIT devices.
@ -212,6 +211,7 @@ cc_library(
"//tensorflow/core/kernels/data:iterator_ops", "//tensorflow/core/kernels/data:iterator_ops",
"//tensorflow/core/kernels/data:optional_ops", "//tensorflow/core/kernels/data:optional_ops",
"//tensorflow/core/kernels/data:prefetch_dataset_op", "//tensorflow/core/kernels/data:prefetch_dataset_op",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor/platform", "//tensorflow/stream_executor/platform",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
@ -223,6 +223,7 @@ cc_library(
name = "shape_inference_helpers", name = "shape_inference_helpers",
srcs = ["shape_inference_helpers.cc"], srcs = ["shape_inference_helpers.cc"],
hdrs = ["shape_inference_helpers.h"], hdrs = ["shape_inference_helpers.h"],
visibility = [":friends"],
deps = ["//tensorflow/core:graph"], deps = ["//tensorflow/core:graph"],
) )
@ -256,6 +257,11 @@ cc_library(
name = "xla_launch_util", name = "xla_launch_util",
srcs = ["xla_launch_util.cc"], srcs = ["xla_launch_util.cc"],
hdrs = ["xla_launch_util.h"], hdrs = ["xla_launch_util.h"],
# TODO(skyewm): remove this once XlaAllocator is factored out.
visibility = [
":internal",
"//tensorflow/compiler/xla/python:__pkg__",
],
deps = [ deps = [
":common", ":common",
":xla_compilation_cache", ":xla_compilation_cache",
@ -265,7 +271,6 @@ cc_library(
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
@ -273,6 +278,7 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
@ -468,6 +474,9 @@ cc_library(
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:graph", "//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
], ],
@ -518,8 +527,9 @@ cc_library(
"partially_decluster_pass.h", "partially_decluster_pass.h",
], ],
deps = [ deps = [
"compilability_check_util",
":common", ":common",
":device_info_cache", ":device_util",
":encapsulate_util", ":encapsulate_util",
":flags", ":flags",
":resource_operation_safety_analysis", ":resource_operation_safety_analysis",
@ -581,21 +591,35 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
], ],
) )
cc_library( cc_library(
name = "device_info_cache", name = "device_util",
srcs = ["device_info_cache.cc"], srcs = ["device_util.cc"],
hdrs = ["device_info_cache.h"], hdrs = ["device_util.h"],
deps = [ deps = [
":xla_cluster_util",
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
tf_cc_test(
name = "device_util_test",
srcs = ["device_util_test.cc"],
deps = [
":device_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
], ],
) )
@ -661,6 +685,7 @@ tf_cc_test(
"introduce_floating_point_jitter_pass_test.cc", "introduce_floating_point_jitter_pass_test.cc",
"mark_for_compilation_pass_test.cc", "mark_for_compilation_pass_test.cc",
"partially_decluster_pass_test.cc", "partially_decluster_pass_test.cc",
"rearrange_function_argument_pass_test.cc",
], ],
deps = [ deps = [
":common", ":common",
@ -681,6 +706,7 @@ tf_cc_test(
"//tensorflow/cc:scope", "//tensorflow/cc:scope",
"//tensorflow/cc:sendrecv_ops", "//tensorflow/cc:sendrecv_ops",
"//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:rearrange_function_argument",
"//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:test_util", "//tensorflow/compiler/tf2xla:test_util",
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",
@ -764,6 +790,34 @@ tf_cc_test(
], ],
) )
cc_library(
name = "compilability_check_util",
srcs = ["compilability_check_util.cc"],
hdrs = ["compilability_check_util.h"],
deps = [
":common",
":device_util",
":flags",
":resource_operation_safety_analysis",
":union_find",
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_proto_cc",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
tf_custom_op_py_library( tf_custom_op_py_library(
name = "xla_ops_py", name = "xla_ops_py",
kernels = ["//tensorflow/compiler/jit/ops:xla_ops"], kernels = ["//tensorflow/compiler/jit/ops:xla_ops"],

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/jit/build_xla_ops_pass.h" #include "tensorflow/compiler/jit/build_xla_ops_pass.h"
#include "absl/algorithm/container.h" #include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
@ -25,6 +26,7 @@ limitations under the License.
#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/logging_ops.h" #include "tensorflow/cc/ops/logging_ops.h"
#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/jit/xla_cluster_util.h"
@ -231,14 +233,10 @@ void RemoveAllIncomingControlEdges(Graph* g, Node* n) {
} }
// Returns true (into `result`) if a node placed on `device` must be compiled. // Returns true (into `result`) if a node placed on `device` must be compiled.
Status DeviceRequiresCompilation(const string& device, bool* result) { Status DeviceRequiresCompilation(const jit::DeviceInfoCache& device_info_cache,
DeviceType device_type(""); jit::DeviceId device, bool* result) {
TF_RETURN_IF_ERROR(DeviceToDeviceType(device, &device_type)); const XlaOpRegistry::DeviceRegistration* registration =
const XlaOpRegistry::DeviceRegistration* registration = nullptr; device_info_cache.GetCompilationDevice(device);
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
return errors::Internal("Could not find compilation device ",
device_type.type());
}
*result = registration->autoclustering_policy == *result = registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kAlways; XlaOpRegistry::AutoclusteringPolicy::kAlways;
return Status::OK(); return Status::OK();
@ -291,17 +289,20 @@ Status ReplaceFunctionCallWithPartionedCall(
return Status::OK(); return Status::OK();
} }
Status InferDeviceForCluster(Node* n, const string& function_name, xla::StatusOr<jit::DeviceId> InferDeviceForCluster(
const FunctionLibraryDefinition& flib_def, jit::DeviceInfoCache* device_info_cache, Node* n,
string* result) { const string& function_name, const FunctionLibraryDefinition& flib_def) {
const FunctionDef* func_def = flib_def.Find(function_name); const FunctionDef* func_def = flib_def.Find(function_name);
TF_RET_CHECK(func_def) << "Could not find " << function_name; TF_RET_CHECK(func_def) << "Could not find " << function_name;
std::set<string> device_names; jit::DeviceSet device_set;
for (const NodeDef& ndef : func_def->node_def()) { for (const NodeDef& ndef : func_def->node_def()) {
VLOG(3) << ndef.DebugString(); VLOG(3) << ndef.DebugString();
if (!ndef.device().empty()) { if (!ndef.device().empty()) {
device_names.insert(ndef.device()); TF_ASSIGN_OR_RETURN(jit::DeviceId device_id,
device_info_cache->GetIdFor(ndef.device()));
device_set.Insert(device_id);
} }
} }
@ -309,41 +310,47 @@ Status InferDeviceForCluster(Node* n, const string& function_name,
// TODO(sanjoy): We need this because EncapsulateSubgraphsPass drops device // TODO(sanjoy): We need this because EncapsulateSubgraphsPass drops device
// assignment when constant folding. We should fix EncapsulateSubgraphsPass // assignment when constant folding. We should fix EncapsulateSubgraphsPass
// instead. // instead.
device_names.insert(n->assigned_device_name()); TF_ASSIGN_OR_RETURN(jit::DeviceId device_id,
device_info_cache->GetIdFor(n->assigned_device_name()));
device_set.Insert(device_id);
} }
std::vector<string> device_names_vector; TF_ASSIGN_OR_RETURN(jit::DeviceId result,
absl::c_copy(device_names, std::back_inserter(device_names_vector)); PickDeviceForXla(*device_info_cache, device_set,
/*allow_mixing_unknown_and_cpu=*/true));
Status s = PickDeviceForXla(device_names_vector, true, result); VLOG(2) << "For " << function_name << " PickDeviceForXla("
if (s.ok()) { << device_info_cache->DebugString(device_set) << ") -> "
VLOG(2) << "For " << function_name << " PickDeviceForXla(" << device_info_cache->GetNameFor(result);
<< absl::StrJoin(device_names_vector, ", ") << ") -> " << *result; return result;
}
return s;
} }
Status ReplaceNodeWithXlaCompileAndXlaRun( Status ReplaceNodeWithXlaCompileAndXlaRun(
jit::DeviceInfoCache* device_info_cache,
const GraphOptimizationPassOptions& options, const GraphOptimizationPassOptions& options,
const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled, const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled,
bool insert_print_nodes, Graph* g, Node* n) { bool insert_print_nodes, Graph* g, Node* n) {
XlaClusterInfo cluster_info; XlaClusterInfo cluster_info;
TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info)); TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
string device; TF_ASSIGN_OR_RETURN(
TF_RETURN_IF_ERROR(InferDeviceForCluster(n, cluster_info.function.name(), jit::DeviceId device,
flib_def, &device)); InferDeviceForCluster(device_info_cache, n, cluster_info.function.name(),
flib_def));
bool requires_compilation; bool requires_compilation;
TF_RETURN_IF_ERROR(DeviceRequiresCompilation(device, &requires_compilation)); TF_RETURN_IF_ERROR(DeviceRequiresCompilation(*device_info_cache, device,
&requires_compilation));
if (!lazy_compilation_enabled) { if (!lazy_compilation_enabled) {
requires_compilation = true; requires_compilation = true;
} }
string device_name_str = string(device_info_cache->GetNameFor(device));
Status status; Status status;
Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr) Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
.NewSubScope(n->name()) .NewSubScope(n->name())
.WithDevice(n->requested_device()) .WithDevice(n->requested_device())
.WithAssignedDevice(device); .WithAssignedDevice(device_name_str);
ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"), ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"),
/*constants=*/cluster_info.constant_inputs, /*constants=*/cluster_info.constant_inputs,
@ -435,14 +442,16 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
bool lazy_compilation_enabled = bool lazy_compilation_enabled =
enable_lazy_compilation_ enable_lazy_compilation_
? *enable_lazy_compilation_ ? *enable_lazy_compilation_
: GetBuildXlaOpsPassFlags().tf_xla_enable_lazy_compilation; : GetBuildXlaOpsPassFlags()->tf_xla_enable_lazy_compilation;
bool insert_print_nodes = bool insert_print_nodes =
GetBuildXlaOpsPassFlags().tf_xla_print_cluster_outputs; GetBuildXlaOpsPassFlags()->tf_xla_print_cluster_outputs;
jit::DeviceInfoCache device_info_cache;
for (Node* n : xla_compiled_kernels) { for (Node* n : xla_compiled_kernels) {
TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun( TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
options, *options.flib_def, lazy_compilation_enabled, &device_info_cache, options, *options.flib_def,
insert_print_nodes, graph, n)); lazy_compilation_enabled, insert_print_nodes, graph, n));
} }
if (VLOG_IS_ON(1)) { if (VLOG_IS_ON(1)) {

View File

@ -0,0 +1,277 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include <atomic>
#include <deque>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {
namespace {
bool HasResourceInput(const Node& node) {
return absl::c_count(node.input_types(), DT_RESOURCE) != 0;
}
} // anonymous namespace
bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) {
// There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
// is really a kind of function call and will be handled by
// IsCompilableCall().
if (node.type_string() == "SymbolicGradient") return false;
if (node.type_string() == "Const") {
// Skip Const op with type DT_STRING, since XLA doesn't support it, but the
// registered Const KernelDef says that it does, to support no-op Assert for
// tfcompile.
const AttrValue* attr = node.attrs().Find("dtype");
if (attr != nullptr && attr->type() == DT_STRING) {
return false;
}
}
// XLA does not offer guaranteed aliasing between the input and output of the
// XLA cluster so it can't implement the forward-tensor-ref semantic. Leave
// such nodes out of XLA clusters.
if (HasForwardedRefInput(node)) {
VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast.";
return false;
}
return FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr).ok();
}
// Tests whether 'while_node' is a completely compilable loop.
// Every operator in the condition and body functions must be compilable for a
// while loop to be compilable.
bool RecursiveCompilabilityChecker::IsCompilableWhile(
const Node& while_node, int depth, FunctionLibraryRuntime* lib_runtime) {
const NameAttrList* name_attr;
NodeDef call;
Status status;
status = GetNodeAttr(while_node.attrs(), "cond", &name_attr);
if (!status.ok()) {
VLOG(2) << "Rejecting While " << while_node.name()
<< ": missing 'cond' attribute on While node.";
return false;
}
const string cond_func = name_attr->name();
call.set_name("while_cond");
call.set_op(cond_func);
*call.mutable_attr() = name_attr->attr();
if (!IsCompilableCall(call, depth + 1, lib_runtime)) {
VLOG(2) << "Rejecting While " << while_node.name()
<< ": can't compile loop condition: " << cond_func;
return false;
}
status = GetNodeAttr(while_node.attrs(), "body", &name_attr);
if (!status.ok()) {
VLOG(2) << "Rejecting While " << while_node.name()
<< ": missing 'body' attribute on While node.";
return false;
}
const string body_func = name_attr->name();
call.set_name("while_body");
call.set_op(body_func);
*call.mutable_attr() = name_attr->attr();
if (!IsCompilableCall(call, depth + 1, lib_runtime)) {
VLOG(2) << "Rejecting While " << while_node.name()
<< ": can't compile loop body: " << body_func;
return false;
}
return true;
}
// Tests whether 'call_def' is a call to a completely compilable function.
// Every operator in the function must be compilable for a function to be
// compilable.
bool RecursiveCompilabilityChecker::IsCompilableCall(
const NodeDef& call_def, int depth, FunctionLibraryRuntime* lib_runtime) {
if (depth > kMaxRecursionDepth) {
VLOG(2) << "Rejecting " << call_def.op()
<< ": function depth limit exceeded.";
return false;
}
FunctionLibraryRuntime::Handle handle;
Status status = InstantiateFunctionCall(call_def, lib_runtime, &handle);
if (!status.ok()) {
VLOG(2) << "Rejecting " << call_def.DebugString()
<< ": could not instantiate: " << status;
return false;
}
auto release_handle_on_return = gtl::MakeCleanup(
[&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
for (Node* node : fbody->graph->op_nodes()) {
if (!IsCompilableNode(*node, depth + 1, lib_runtime)) {
return false;
}
}
return true;
}
bool LogNotCompilableAndReturn(const Node& node,
absl::string_view reason = "") {
VLOG(3) << "Not clustering " << node.name() << " (op " << node.type_string()
<< ")" << (reason.empty() ? "" : ": ") << reason;
return false;
}
bool RecursiveCompilabilityChecker::OpIsInaccurate(const Node& node) {
// b/127344411: SelfAdjointEigV2 and Svd precision issues.
return node.type_string() == "SelfAdjointEigV2" ||
node.type_string() == "Svd";
}
bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) {
// b/128001705: SelfAdjointEigV2 and Svd performance issues.
return node.type_string() == "SelfAdjointEigV2" ||
node.type_string() == "Svd" || node.type_string() == "Qr";
}
bool RecursiveCompilabilityChecker::IsCompilableNode(
const Node& node, int depth, FunctionLibraryRuntime* lib_runtime) {
if (node.IsSource() || node.IsSink()) {
return LogNotCompilableAndReturn(node, "source or sink node");
}
// _Arg nodes in a top-level function represent feeds and _Retval nodes in a
// top-level function represent fetches.
if (depth == 0 &&
(node.type_string() == "_Arg" || node.type_string() == "_Retval")) {
return LogNotCompilableAndReturn(node, "depth is 0");
}
if (node.attrs().Find("_scoped_allocator") ||
node.attrs().Find("_forward_from")) {
// TODO(b/128858118): XLA does not support _scoped_allocator and
// _forward_from.
return LogNotCompilableAndReturn(
node, "_scoped_allocator or _forward_from attribute");
}
if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), node)) {
if (!IsCompilableCall(node.def(), depth + 1, lib_runtime)) {
return LogNotCompilableAndReturn(node, "unsupported function");
}
} else if (!HasXLAKernel(node)) {
return LogNotCompilableAndReturn(node, "unsupported op");
}
if (node.type_string() == "While" &&
!IsCompilableWhile(node, depth + 1, lib_runtime)) {
return LogNotCompilableAndReturn(node, "unsupported while");
}
if (!op_filter_.allow_stateful_rng_ops &&
IsStatefulRandomOp(node.type_string())) {
return LogNotCompilableAndReturn(node, "stateful random op");
}
if (!op_filter_.allow_control_trigger && node.IsControlTrigger()) {
return LogNotCompilableAndReturn(node);
}
if (!op_filter_.allow_eliding_assert_and_checknumerics_ops &&
IsAssertOrCheckNumerics(node.type_string())) {
return LogNotCompilableAndReturn(node, "Assert or CheckNumerics");
}
if (!op_filter_.allow_ops_producing_or_consuming_variant &&
OpProducesOrConsumesVariant(node)) {
return LogNotCompilableAndReturn(node, "DT_VARIANT producer/consumer");
}
if (!op_filter_.allow_stack_ops && IsStackOp(node)) {
return LogNotCompilableAndReturn(node, "Stack op");
}
if (!op_filter_.allow_tensor_array_ops && IsTensorArrayOp(node)) {
return LogNotCompilableAndReturn(node, "TensorArray op");
}
if (!op_filter_.allow_resource_ops_in_called_functions && depth > 0 &&
HasResourceInput(node)) {
return LogNotCompilableAndReturn(node,
"resource variable op in called function");
}
if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsInaccurate(node)) {
return LogNotCompilableAndReturn(node, "operation with correctness issues");
}
if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsSlow(node)) {
return LogNotCompilableAndReturn(node, "slow operation");
}
return true;
}
RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
const XlaOpRegistry::DeviceRegistration& registration) {
RecursiveCompilabilityChecker::OperationFilter op_filter;
op_filter.allow_resource_ops_in_called_functions =
registration.cluster_resource_variable_ops_unsafely;
op_filter.allow_stack_ops = registration.cluster_stack_ops;
op_filter.allow_tensor_array_ops = registration.cluster_tensor_array_ops;
op_filter.allow_stateful_rng_ops = registration.cluster_stateful_rng_ops;
op_filter.allow_control_trigger = registration.cluster_control_trigger;
op_filter.allow_eliding_assert_and_checknumerics_ops =
registration.elide_assert_and_checknumerics;
op_filter.allow_ops_producing_or_consuming_variant =
registration.cluster_variant_ops;
op_filter.allow_slow_and_inaccurate_ops =
registration.cluster_slow_and_inaccurate_ops;
return op_filter;
}
} // namespace tensorflow

View File

@ -0,0 +1,175 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
#include "absl/algorithm/container.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {
// Checks whether a TF node can be compiled or not. "Recursive" because, for
// call and functional While nodes, it recursively checks whether the callee
// functions can be compiled.
class RecursiveCompilabilityChecker {
public:
// Aggregates information about what kinds of ops are allowed.
struct OperationFilter { // TODO(lzr): Add AllowEverything() helper.
// Whether resource variable ops are allowed in callees. We do
// not allow resource variable ops in called functions (either as direct TF
// calls or as higher order control flow ops) because we do not yet model
// their memory effects in jit/resource_variable_safety_analysis.
bool allow_resource_ops_in_called_functions;
// Whether Stack operations are allowed. We avoid auto-clustering Stack
// operations in general because we do not support snapshotting them.
//
// TODO(b/112837194): This restriction can be lifted with some work.
bool allow_stack_ops;
// Whether TensorArray operations are allowed. We avoid auto-clustering
// TensorArray operations in general because we do not support snapshotting
// them.
//
// TODO(b/112837194): This restriction can be lifted with some work.
bool allow_tensor_array_ops;
// Whether stateful RNG ops are allowed. XLA's RNG does not have the same
// seeding behavior as TensorFlow's RNG (b/34749654). So we avoid
// auto-clustering stateful RNG ops.
bool allow_stateful_rng_ops;
// TODO(b/118970344): Whether ControlTrigger ops are allowed. It is unsound
// to cluster ControlTrigger because of how we use deadness analysis.
bool allow_control_trigger;
// Whether it is okay to "cluster" Assert and CheckNumerics by simply
// removing them (they're not removed during clustering, but their
// XlaOpKernel is a no-op kernel). We avoid auto-clustering these ops so
// that the user is not surprised when XLA is implicitly enabled. If the
// user explicitly specifies to use XLA, it is fine to resort to a dummy
// implementation. Currently Assert and CheckNumerics ops have dummy XLA
// implementations.
bool allow_eliding_assert_and_checknumerics_ops;
// Whether ops that produce or consume DT_VARIANT values are allowed. We
// don't auto-cluster these ops because we don't yet support live-in or
// live-out DT_VARIANT values.
bool allow_ops_producing_or_consuming_variant;
// Whether ops known to be slow or to have correctness issues should be
// auto-clustered.
bool allow_slow_and_inaccurate_ops;
};
RecursiveCompilabilityChecker(const OperationFilter* op_filter,
const DeviceType* jit_device_type)
: op_filter_(*op_filter), jit_device_type_(*jit_device_type) {}
// Returns true if `node` can be compiled by XLA.
bool IsCompilableNode(const Node& node, FunctionLibraryRuntime* lib_runtime) {
return IsCompilableNode(node, /*depth=*/0, lib_runtime);
}
// Returns true if `call_def` can be compiled by XLA. It is assumed that
// `call_def` is a call operation.
bool IsCompilableCall(const NodeDef& call_def,
FunctionLibraryRuntime* lib_runtime) {
return IsCompilableCall(call_def, /*depth=*/0, lib_runtime);
}
// Returns true if XLA supports this Op, but we don't want to cluster it (i.e.,
// due to performance or correctness concerns).
bool OpIsInaccurate(const Node& node);
bool OpIsSlow(const Node& node);
private:
bool IsCompilableNode(const Node& node, int depth,
FunctionLibraryRuntime* lib_runtime);
bool IsCompilableCall(const NodeDef& call_def, int depth,
FunctionLibraryRuntime* lib_runtime);
bool IsCompilableWhile(const Node& while_node, int depth,
FunctionLibraryRuntime* lib_runtime);
bool IsStackOp(const Node& node) {
const XlaResourceOpInfo* op_info =
GetResourceOpInfoForOp(node.type_string());
return op_info && op_info->resource_kind() == XlaResourceKind::kStack;
}
bool IsTensorArrayOp(const Node& node) {
const XlaResourceOpInfo* op_info =
GetResourceOpInfoForOp(node.type_string());
return op_info && op_info->resource_kind() == XlaResourceKind::kTensorArray;
}
bool IsAssertOrCheckNumerics(absl::string_view op_name) {
return op_name == "Assert" || op_name == "CheckNumerics";
}
bool IsStatefulRandomOp(absl::string_view op_name) {
return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" ||
op_name == "TruncatedNormal" || op_name == "Multinomial";
}
bool OpProducesOrConsumesVariant(const Node& node) {
auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; };
return absl::c_any_of(node.input_types(), is_variant) ||
absl::c_any_of(node.output_types(), is_variant);
}
bool HasXLAKernel(const Node& node);
// Make sure we don't recurse infinitely on recursive functions.
const int kMaxRecursionDepth = 10;
const OperationFilter& op_filter_;
const DeviceType& jit_device_type_;
};
RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
const XlaOpRegistry::DeviceRegistration& registration);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
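
To make the intended use of the API above concrete, here is a rough sketch, not part of the commit, of how a caller might drive the checker. The surrounding function and the choice of DEVICE_GPU are illustrative assumptions, not anything this change prescribes.

// Illustrative sketch only: checks every op node of `graph` against the
// clustering policy registered for DEVICE_GPU (an assumed device choice).
bool AllNodesCompilableSketch(Graph* graph, FunctionLibraryRuntime* lib_runtime) {
  const XlaOpRegistry::DeviceRegistration* registration;
  if (!XlaOpRegistry::GetCompilationDevice(DEVICE_GPU, &registration)) {
    return false;  // No XLA backend is registered for this device type.
  }
  // Translate the backend's clustering policy into an OperationFilter.
  RecursiveCompilabilityChecker::OperationFilter filter =
      CreateOperationFilter(*registration);
  DeviceType jit_device_type(registration->compilation_device_name);
  RecursiveCompilabilityChecker checker(&filter, &jit_device_type);
  for (Node* n : graph->op_nodes()) {
    if (!checker.IsCompilableNode(*n, lib_runtime)) return false;
  }
  return true;
}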

View File

@ -371,7 +371,8 @@ class PredicateFactory {
Predicate** predicate) {
TensorId tensor_id(node->name(), output_idx);
-bool is_boolean_tensor = node->output_type(tensor_id.index()) == DT_BOOL;
+bool is_boolean_tensor =
+    BaseType(node->output_type(tensor_id.index())) == DT_BOOL;
TF_RET_CHECK(!must_be_true || is_boolean_tensor);
if (node->type_string() == "Const" && must_be_true) {

View File

@ -1067,5 +1067,25 @@ TEST(DeadnessAnalysisTest, ConstantFalseSwitchCondition) {
EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "#false");
}
TEST(DeadnessAnalysisTest, RefBoolSwitchCondition) {
Scope root = Scope::NewRootScope().ExitOnError();
Output condition_ref_var =
ops::Variable(root.WithOpName("cond_ref"), TensorShape({}), DT_BOOL);
Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
ops::Switch sw(root.WithOpName("switch"), value, condition_ref_var);
Output id_false = ops::Identity(root.WithOpName("id_false"), sw.output_false);
Output id_true = ops::Identity(root.WithOpName("id_true"), sw.output_true);
FixupSourceAndSinkEdges(root.graph());
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
EXPECT_EQ(predicate_map[ControlOutputFor(id_false)], "~*cond_ref:0");
EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "*cond_ref:0");
}
} // namespace
} // namespace tensorflow
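
For context, not part of the commit: the deadness-analysis change above matters because a Switch predicate fed from a ref Variable (as in the new test) has dtype DT_BOOL_REF rather than DT_BOOL, and BaseType() strips the reference qualifier before the comparison. A minimal sketch of that idea, assuming only the standard dtype helpers in tensorflow/core/framework/types.h:

#include "tensorflow/core/framework/types.h"

// Returns true if `dtype` is boolean once any reference qualifier is removed,
// mirroring the is_boolean_tensor check in the hunk above.
bool IsBooleanLike(tensorflow::DataType dtype) {
  // BaseType maps a reference dtype (e.g. DT_BOOL_REF) back to its base type.
  return tensorflow::BaseType(dtype) == tensorflow::DT_BOOL;
}
// IsBooleanLike(tensorflow::DT_BOOL) and IsBooleanLike(tensorflow::DT_BOOL_REF)
// both return true; a plain dtype comparison would reject the ref case.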

View File

@ -1,61 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/device_info_cache.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace tensorflow {
using xla::StatusOr;
StatusOr<const XlaOpRegistry::DeviceRegistration*>
DeviceInfoCache::GetCompilationDevice(absl::string_view device_name) {
auto it = device_to_device_registration_.find(device_name);
if (it != device_to_device_registration_.end()) {
return it->second;
}
string device_name_str = string(device_name);
TF_ASSIGN_OR_RETURN(const DeviceType& device_type,
GetDeviceTypeFor(device_name_str));
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
registration = nullptr;
}
device_to_device_registration_.insert(
{std::move(device_name_str), registration});
return registration;
}
StatusOr<std::reference_wrapper<const DeviceType>>
DeviceInfoCache::GetDeviceTypeFor(absl::string_view device_name) {
auto it = device_to_device_type_.find(device_name);
if (it != device_to_device_type_.end()) {
return std::cref(*it->second);
}
string device_name_str = string(device_name);
auto device_type = absl::make_unique<DeviceType>("");
TF_RETURN_IF_ERROR(DeviceToDeviceType(device_name_str, device_type.get()));
it = device_to_device_type_
.insert({std::move(device_name_str), std::move(device_type)})
.first;
return std::cref(*it->second);
}
} // namespace tensorflow

View File

@ -1,45 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
#define TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
#include <functional>
#include <memory>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/types.h"
namespace tensorflow {
// Caches some miscellaneous information about TF devices. Thread compatible.
class DeviceInfoCache {
public:
xla::StatusOr<const XlaOpRegistry::DeviceRegistration*> GetCompilationDevice(
absl::string_view device_name);
xla::StatusOr<std::reference_wrapper<const DeviceType>> GetDeviceTypeFor(
absl::string_view device_name);
private:
absl::flat_hash_map<string, const XlaOpRegistry::DeviceRegistration*>
device_to_device_registration_;
absl::flat_hash_map<string, std::unique_ptr<DeviceType>>
device_to_device_type_;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_

View File

@ -0,0 +1,206 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/device_util.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace tensorflow {
namespace jit {
using xla::StatusOr;
void DeviceSet::Insert(DeviceId device_id) {
int word_index = device_id.id() / kWordSize;
int bit_index = device_id.id() % kWordSize;
if (word_index >= storage_.size()) {
storage_.resize(word_index + 1, 0);
}
storage_[word_index] |= (1ull << bit_index);
}
void DeviceSet::UnionWith(const DeviceSet& other) {
if (other.storage_.size() > storage_.size()) {
storage_.resize(other.storage_.size(), 0);
}
for (int i = 0; i < other.storage_.size(); i++) {
storage_[i] |= other.storage_[i];
}
}
bool DeviceSet::IsEmpty() const {
return absl::c_all_of(storage_, [&](uint64 val) { return val == 0; });
}
xla::StatusOr<DeviceId> DeviceInfoCache::GetIdFor(absl::string_view name) {
TF_RET_CHECK(!name.empty());
auto it = name_to_id_.find(name);
if (it != name_to_id_.end()) {
return it->second;
}
int new_id = names_.size();
names_.push_back(string(name));
id_to_device_type_.push_back(absl::make_unique<DeviceType>(""));
DeviceType* device_type = id_to_device_type_.back().get();
TF_RETURN_IF_ERROR(DeviceNameToDeviceType(names_.back(), device_type));
is_cpu_.push_back(device_type->type_string() == DEVICE_CPU);
is_gpu_.push_back(device_type->type_string() == DEVICE_GPU);
name_to_id_.emplace(string(name), DeviceId(new_id));
const XlaOpRegistry::DeviceRegistration* compilation_device;
if (!XlaOpRegistry::GetCompilationDevice(device_type->type(),
&compilation_device)) {
compilation_device = nullptr;
}
id_to_compilation_device_.push_back(compilation_device);
return DeviceId(new_id);
}
string DeviceInfoCache::DebugString(const DeviceSet& device_set) const {
std::vector<string> names;
device_set.ForEach([&](DeviceId device_id) {
names.push_back(string(GetNameFor(device_id)));
return false;
});
return absl::StrCat("[", absl::StrJoin(names, ","), "]");
}
} // namespace jit
Status DeviceNameToDeviceType(const string& device, DeviceType* device_type) {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
return errors::Internal("Malformed assigned device '", device, "'");
}
*device_type = DeviceType(parsed.type);
return Status::OK();
}
xla::StatusOr<absl::optional<jit::DeviceId>> PickDeviceForXlaImpl(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu,
bool failure_to_pick_is_error) {
#define FAILED_TO_PICK_DEVICE(failing_status) \
do { \
if (failure_to_pick_is_error) { \
return failing_status; \
} else { \
return {absl::nullopt}; \
} \
} while (false)
absl::optional<jit::DeviceId> maybe_gpu_device;
absl::optional<jit::DeviceId> maybe_cpu_device;
absl::optional<jit::DeviceId> maybe_unknown_device;
bool multiple_cpu_devices = false;
bool multiple_gpu_devices = false;
bool multiple_unknown_devices = false;
devices.ForEach([&](jit::DeviceId device) {
if (device_info_cache.IsGpu(device)) {
if (maybe_gpu_device) {
multiple_gpu_devices = true;
return false;
}
maybe_gpu_device = device;
} else if (device_info_cache.IsCpu(device)) {
if (maybe_cpu_device) {
multiple_cpu_devices = true;
return false;
}
maybe_cpu_device = device;
} else {
if (maybe_unknown_device) {
multiple_unknown_devices = true;
return false;
}
maybe_unknown_device = device;
}
return true;
});
if (multiple_cpu_devices) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Multiple CPU devices ", device_info_cache.DebugString(devices)));
}
if (multiple_gpu_devices) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Multiple GPU devices ", device_info_cache.DebugString(devices)));
}
if (multiple_unknown_devices) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Multiple unknown devices ", device_info_cache.DebugString(devices)));
}
if (maybe_unknown_device && maybe_gpu_device) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Found both unknown and GPU devices: ",
device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
device_info_cache.GetNameFor(*maybe_gpu_device)));
}
if (!allow_mixing_unknown_and_cpu) {
if (maybe_unknown_device && maybe_cpu_device) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Found both unknown and CPU devices: ",
device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
device_info_cache.GetNameFor(*maybe_cpu_device)));
}
}
if (maybe_gpu_device) {
return {*maybe_gpu_device};
} else if (maybe_unknown_device) {
return {*maybe_unknown_device};
} else if (maybe_cpu_device) {
return {*maybe_cpu_device};
}
FAILED_TO_PICK_DEVICE(errors::Internal("Empty device set!"));
#undef FAILED_TO_PICK_DEVICE
}
xla::StatusOr<jit::DeviceId> PickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
TF_ASSIGN_OR_RETURN(absl::optional<jit::DeviceId> device_id,
PickDeviceForXlaImpl(device_info_cache, devices,
allow_mixing_unknown_and_cpu,
/*failure_to_pick_is_error=*/true));
return *device_id;
}
xla::StatusOr<absl::optional<jit::DeviceId>> MaybePickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
return PickDeviceForXlaImpl(device_info_cache, devices,
allow_mixing_unknown_and_cpu,
/*failure_to_pick_is_error=*/false);
}
} // namespace tensorflow
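
Not part of the commit, but as a summary of PickDeviceForXlaImpl above: once the per-type uniqueness checks have passed, the choice reduces to a fixed preference order. A simplified standalone sketch of just that ordering (the enum and function name are invented for illustration):

#include <stdexcept>

enum class DeviceKind { kCpu, kGpu, kUnknown };

// GPU wins, then unknown, then CPU; unknown may not be mixed with GPU, and may
// only be mixed with CPU when the caller explicitly allows it.
DeviceKind PickDeviceKindSketch(bool has_gpu, bool has_unknown, bool has_cpu,
                                bool allow_mixing_unknown_and_cpu) {
  if (has_gpu && has_unknown)
    throw std::runtime_error("found both unknown and GPU devices");
  if (!allow_mixing_unknown_and_cpu && has_unknown && has_cpu)
    throw std::runtime_error("found both unknown and CPU devices");
  if (has_gpu) return DeviceKind::kGpu;
  if (has_unknown) return DeviceKind::kUnknown;
  if (has_cpu) return DeviceKind::kCpu;
  throw std::runtime_error("empty device set");
}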

View File

@ -0,0 +1,211 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
#define TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
#include <functional>
#include <memory>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/types.h"
namespace tensorflow {
namespace jit {
// Instances of DeviceId represent TensorFlow devices as integers.
//
// This helps avoid having to manipulate device names as strings when
// auto-clustering.
class DeviceId {
public:
DeviceId(DeviceId&&) = default;
DeviceId(const DeviceId&) = default;
DeviceId& operator=(const DeviceId&) = default;
bool operator==(const DeviceId& other) const { return id() == other.id(); }
bool operator!=(const DeviceId& other) const { return !(*this == other); }
private:
int id_;
explicit DeviceId(int id) : id_(id) {}
int id() const { return id_; }
friend class DeviceInfoCache;
friend class DeviceSet;
};
// A set of DeviceIds, represented as a bitmap.
class DeviceSet {
public:
void Insert(DeviceId device_id);
void UnionWith(const DeviceSet& other);
bool IsEmpty() const;
// Calls `func` on each DeviceId in the set. Stops iterating early if `func`
// returns false.
//
// TODO(sanjoy): Change this to take a typed std::function if that's
// performance neutral.
template <typename FnTy>
void ForEach(FnTy func) const {
// This is really a poor man's iterator; we should consider writing a proper
// iterator if this ends up being used widely.
for (int word_index = 0; word_index < storage_.size(); word_index++) {
uint64 word = storage_[word_index];
while (word != 0) {
uint64 only_lowest_bit_set = word & -word;
// The number of trailing zeros in a non-zero word is the index of the
// least significant 1.
int bit_index = ctz_uint64(word);
if (!func(DeviceId(word_index * kWordSize + bit_index))) {
return;
}
word ^= only_lowest_bit_set;
}
}
}
private:
static int ctz_uint64(uint64 x) {
DCHECK_NE(x, 0);
#ifdef __GNUC__
return __builtin_ctzl(x);
#else
int result = 0u;
while ((x & 1u) == 0u) {
x >>= 1;
++result;
}
return result;
#endif
}
absl::InlinedVector<uint64, 1> storage_;
const int kWordSize = 64;
};
// Caches some miscellaneous information about TF devices. Thread compatible.
class DeviceInfoCache {
public:
bool IsGpu(DeviceId device) const { return is_gpu_[device.id()]; }
bool IsCpu(DeviceId device) const { return is_cpu_[device.id()]; }
absl::string_view GetNameFor(DeviceId device) const {
return names_[device.id()];
}
xla::StatusOr<DeviceId> GetIdFor(absl::string_view name);
using DeviceRegistration = const XlaOpRegistry::DeviceRegistration;
DeviceRegistration* GetCompilationDevice(DeviceId device) const {
return id_to_compilation_device_[device.id()];
}
xla::StatusOr<DeviceRegistration*> GetCompilationDevice(
absl::string_view name) {
TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(name));
return GetCompilationDevice(device_id);
}
const DeviceType& GetDeviceTypeFor(DeviceId device) const {
return *id_to_device_type_[device.id()];
}
using DeviceTypeConstRef = std::reference_wrapper<const DeviceType>;
xla::StatusOr<DeviceTypeConstRef> GetDeviceTypeFor(
absl::string_view device_name) {
TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(device_name));
return std::cref(*id_to_device_type_[device_id.id()]);
}
string DebugString(const DeviceSet& device_set) const;
private:
absl::flat_hash_map<string, DeviceId> name_to_id_;
// These fields are populated for a device in GetIdFor, *before* we give out a
// DeviceId.
std::vector<const XlaOpRegistry::DeviceRegistration*>
id_to_compilation_device_;
std::vector<std::unique_ptr<DeviceType>> id_to_device_type_;
std::vector<string> names_;
std::vector<bool> is_cpu_;
std::vector<bool> is_gpu_;
};
} // namespace jit
// Returns the DeviceType corresponding to 'device'.
Status DeviceNameToDeviceType(const string& device, DeviceType* device_type);
// Picks the device for which XLA should compile a cluster that contains
// operations placed in devices in `devices`. For instance a cluster that
// contains operations solely placed on the CPU will be compiled into a CPU
// executable by XLA, whereas a cluster that contains operations placed on the
// CPU and also operations placed on the GPU will be compiled into a GPU
// executable.
//
// Returns a non-OK Status if no unambiguous choice of device exists.
//
// We choose the device using the following rules:
//
// - It is an error for `device_names` to contain more than one device of the
// same type.
// - GPU is preferred over CPU.
// - If `allow_mixing_unknown_and_cpu` is true then unknown devices are
// preferred over CPU.
// - XLA devices count as "unrecognized devices".
//
// This set of rules above implicitly assume that XLA:GPU can compile all
// operations in the cluster that XLA:CPU can compile, and if
// `allow_mixing_unknown_and_cpu` then the unrecognized device can also compile
// all operations in the cluster that XLA:CPU can compile.
//
// We provide the `allow_mixing_unknown_and_cpu` knob so that we can do both of
// the following things:
//
// - Let MarkForCompilationPass not inject CPU-placed operations into clusters
// that will run on unknown devices (because the unknown XLA backend may not
// support every operation supported by CPU).
// - Let BuildXlaOpsPass successfully infer a compilation device for a cluster
// that contains nodes placed on both the CPU and on unknown devices. In this
// case it is the responsibility of the optimization pass that injected the
// CPU nodes into the cluster to ensure that these nodes can be compiled by
// the unknown XLA backend.
xla::StatusOr<jit::DeviceId> PickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
// This is like `PickDeviceForXla` except that it returns nullopt (instead of a
// non-OK Status) if no unambiguous choice of device exists.
//
// We return a failing Status for errors unrelated to the device choice
// algorithm itself.
xla::StatusOr<absl::optional<jit::DeviceId>> MaybePickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
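
Not part of the commit: DeviceSet above is a bitmap over device ids, where Insert sets bit (id % 64) of word (id / 64) and ForEach repeatedly isolates the lowest set bit and converts it back to an id via count-trailing-zeros. A standalone sketch of the same technique in plain C++ (__builtin_ctzll assumes GCC/Clang, much as the original's fallback loop acknowledges):

#include <cstdint>
#include <vector>

// Minimal bitmap set over small non-negative ints, mirroring DeviceSet.
class SmallIdSet {
 public:
  void Insert(int id) {
    const size_t word = id / 64, bit = id % 64;
    if (word >= storage_.size()) storage_.resize(word + 1, 0);
    storage_[word] |= (std::uint64_t{1} << bit);
  }

  template <typename Fn>
  void ForEach(Fn fn) const {
    for (size_t w = 0; w < storage_.size(); ++w) {
      std::uint64_t word = storage_[w];
      while (word != 0) {
        const std::uint64_t lowest = word & (~word + 1);  // lowest set bit
        const int bit = __builtin_ctzll(word);            // its index
        if (!fn(static_cast<int>(w * 64 + bit))) return;  // early stop
        word ^= lowest;  // clear that bit and continue with the rest
      }
    }
  }

 private:
  std::vector<std::uint64_t> storage_;
};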

View File

@ -0,0 +1,132 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
Status PickDeviceHelper(bool allow_mixing_unknown_and_cpu,
absl::Span<const absl::string_view> device_names,
string* result) {
jit::DeviceInfoCache cache;
jit::DeviceSet device_set;
for (absl::string_view name : device_names) {
TF_ASSIGN_OR_RETURN(jit::DeviceId device_id, cache.GetIdFor(name));
device_set.Insert(device_id);
}
TF_ASSIGN_OR_RETURN(
jit::DeviceId result_id,
PickDeviceForXla(cache, device_set, allow_mixing_unknown_and_cpu));
*result = string(cache.GetNameFor(result_id));
return Status::OK();
}
void CheckPickDeviceResult(absl::string_view expected_result,
bool allow_mixing_unknown_and_cpu,
absl::Span<const absl::string_view> inputs) {
string result;
TF_ASSERT_OK(PickDeviceHelper(allow_mixing_unknown_and_cpu, inputs, &result))
<< "inputs = [" << absl::StrJoin(inputs, ", ")
<< "], allow_mixing_unknown_and_cpu=" << allow_mixing_unknown_and_cpu
<< ", expected_result=" << expected_result;
EXPECT_EQ(result, expected_result);
}
void CheckPickDeviceHasError(bool allow_mixing_unknown_and_cpu,
absl::Span<const absl::string_view> inputs) {
string result;
EXPECT_FALSE(
PickDeviceHelper(allow_mixing_unknown_and_cpu, inputs, &result).ok());
}
const char* kCPU0 = "/job:localhost/replica:0/task:0/device:CPU:0";
const char* kGPU0 = "/job:localhost/replica:0/task:0/device:GPU:0";
const char* kXPU0 = "/job:localhost/replica:0/task:0/device:XPU:0";
const char* kYPU0 = "/job:localhost/replica:0/task:0/device:YPU:0";
const char* kCPU1 = "/job:localhost/replica:0/task:0/device:CPU:1";
const char* kGPU1 = "/job:localhost/replica:0/task:0/device:GPU:1";
const char* kXPU1 = "/job:localhost/replica:0/task:0/device:XPU:1";
TEST(PickDeviceForXla, UniqueDevice) {
CheckPickDeviceResult(kGPU0, false, {kGPU0, kGPU0});
}
TEST(PickDeviceForXla, DeviceOrder) {
CheckPickDeviceResult(kGPU0, false, {kGPU0, kCPU0});
CheckPickDeviceResult(kGPU0, false, {kCPU0, kGPU0});
CheckPickDeviceResult(kXPU0, true, {kXPU0, kCPU0});
}
TEST(PickDeviceForXla, MultipleUnknownDevices) {
CheckPickDeviceHasError(false, {kXPU0, kYPU0});
}
TEST(PickDeviceForXla, GpuAndUnknown) {
CheckPickDeviceHasError(false, {kGPU0, kXPU1});
}
TEST(PickDeviceForXla, UnknownAndCpu) {
CheckPickDeviceHasError(false, {kXPU0, kCPU1});
}
TEST(PickDeviceForXla, MultipleDevicesOfSameType) {
CheckPickDeviceHasError(true, {kCPU0, kCPU1});
CheckPickDeviceHasError(false, {kCPU0, kCPU1});
CheckPickDeviceHasError(false, {kGPU0, kGPU1});
CheckPickDeviceHasError(false, {kXPU0, kXPU1});
CheckPickDeviceHasError(false, {kCPU0, kCPU1, kGPU0});
}
void SimpleRoundTripTestForDeviceSet(int num_devices) {
jit::DeviceSet device_set;
jit::DeviceInfoCache device_info_cache;
std::vector<string> expected_devices, actual_devices;
for (int i = 0; i < num_devices; i++) {
string device_name =
absl::StrCat("/job:localhost/replica:0/task:0/device:XPU:", i);
TF_ASSERT_OK_AND_ASSIGN(jit::DeviceId device_id,
device_info_cache.GetIdFor(device_name));
device_set.Insert(device_id);
expected_devices.push_back(device_name);
}
device_set.ForEach([&](jit::DeviceId device_id) {
actual_devices.push_back(string(device_info_cache.GetNameFor(device_id)));
return true;
});
EXPECT_EQ(expected_devices, actual_devices);
}
TEST(DeviceSetTest, SimpleRoundTrip_One) { SimpleRoundTripTestForDeviceSet(1); }
TEST(DeviceSetTest, SimpleRoundTrip_Small) {
SimpleRoundTripTestForDeviceSet(8);
}
TEST(DeviceSetTest, SimpleRoundTrip_Large) {
SimpleRoundTripTestForDeviceSet(800);
}
} // namespace
} // namespace tensorflow

View File

@ -2497,8 +2497,6 @@ Status EncapsulateSubgraphsInFunctions(
const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
bool reuse_existing_functions, std::unique_ptr<Graph>* graph_out,
FunctionLibraryDefinition* library) {
-Status s;
Encapsulator encapsulator(std::move(group_attribute),
                          std::move(outside_compilation_attribute),
                          &graph_in);

View File

@ -537,8 +537,9 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
XlaClusterInfo{func, func_name_attrs, xla_computation_node,
               std::map<string, int>{}});
}
+bool modified;
s = ExtractOutsideCompilation("_encapsulate", "_outside", clusters,
-                              graph_out.get(), flr, lib_def.get());
+                              graph_out.get(), flr, lib_def.get(), &modified);
if (!s.ok()) return s;
GraphDef graphdef_out;
@ -1105,7 +1106,9 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{"shapes", absl::Span<const DataType>({})}, {"shapes", absl::Span<const DataType>({})},
{"_outside_compilation_subgraph", "O2"}, {"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes", {"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}, absl::Span<const string>(
{"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}},
{"F"}}, {"F"}},
{{"outside_compilation_O1_host_compute"}, {{"outside_compilation_O1_host_compute"},
"XlaHostCompute", "XlaHostCompute",
@ -1985,7 +1988,9 @@ TEST(EncapsulateSubgraphsTest,
{"shapes", absl::Span<const TensorShapeProto>({})}, {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"}, {"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes", {"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}}, absl::Span<const string>(
{"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}}},
}, },
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
{"h_0_retval_retval", "H:o:0"}}); {"h_0_retval_retval", "H:o:0"}});
@ -2110,7 +2115,9 @@ TEST(EncapsulateSubgraphsTest,
{"shapes", absl::Span<const TensorShapeProto>({})}, {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"}, {"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes", {"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}}, absl::Span<const string>(
{"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}}},
{{"outside_compilation_O1_host_compute"}, {{"outside_compilation_O1_host_compute"},
"XlaHostCompute", "XlaHostCompute",
{"D:o:0"}, {"D:o:0"},
@ -2258,7 +2265,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{"shapes", absl::Span<const TensorShapeProto>({})}, {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"}, {"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes", {"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}, absl::Span<const string>(
{"_xla_token_arg_node", "outside_compilation_O1_host_compute"})}},
{}}, {}},
{{"outside_compilation_O3_host_compute"}, {{"outside_compilation_O3_host_compute"},
"XlaHostCompute", "XlaHostCompute",
@ -2271,7 +2279,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{"shapes", absl::Span<const TensorShapeProto>({})}, {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O3"}, {"_outside_compilation_subgraph", "O3"},
{"_xla_token_input_nodes", {"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}, absl::Span<const string>({"_xla_token_arg_node",
"outside_compilation_O1_host_compute",
"outside_compilation_O2_host_compute"})}},
{}}}, {}}},
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
{"h_0_retval_retval", "H:o:0"}}); {"h_0_retval_retval", "H:o:0"}});

View File

@ -14,9 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include <algorithm>
#include <iterator>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/shape_inference.h"
@ -24,6 +27,9 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
using stream_executor::port::StatusOr;
namespace tensorflow { namespace tensorflow {
@ -333,6 +339,43 @@ Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) {
return Status::OK();
}
StatusOr<std::unique_ptr<absl::flat_hash_map<string, std::vector<string>>>>
OutsideCompilationClusterDependencies(
const Graph* g, const string& outside_compilation_attr_name) {
auto cluster_deps = absl::make_unique<
absl::flat_hash_map<string, absl::flat_hash_set<string>>>();
for (const Edge* e : g->edges()) {
auto src_outside_compilation =
GetStringAttr(*e->src(), outside_compilation_attr_name);
auto dst_outside_compilation =
GetStringAttr(*e->dst(), outside_compilation_attr_name);
if (src_outside_compilation && dst_outside_compilation &&
*src_outside_compilation != *dst_outside_compilation) {
auto dst_deps_it = cluster_deps->find(*dst_outside_compilation);
if (dst_deps_it == cluster_deps->end()) {
cluster_deps->insert(std::make_pair(
*dst_outside_compilation,
absl::flat_hash_set<string>({*src_outside_compilation})));
} else {
dst_deps_it->second.insert(*src_outside_compilation);
}
}
}
auto cluster_deps_ordered =
absl::make_unique<absl::flat_hash_map<string, std::vector<string>>>();
for (auto it = cluster_deps->begin(); it != cluster_deps->end(); it++) {
std::vector<string> ordered_deps(it->second.begin(), it->second.end());
std::sort(ordered_deps.begin(), ordered_deps.end());
cluster_deps_ordered->insert(std::make_pair(it->first, ordered_deps));
}
return std::move(cluster_deps_ordered);
}
Status PreprocessEdgesBetweenOutsideCompilations(
    Graph* g, const string& outside_compilation_attr_name) {
// Remove edges from source node to outside compilation nodes, and edges

View File

@ -19,7 +19,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
@ -89,6 +91,15 @@ struct XlaClusterInfo {
const std::map<string, int> host_compute_core;
};
// Finds dependencies between outside compilation clusters, including both data
// dependencies and control dependencies. cluster_deps maps the name of an
// outside compilation cluster to a set of names of outside compilation clusters
// that it depends on.
stream_executor::port::StatusOr<
std::unique_ptr<absl::flat_hash_map<string, std::vector<string>>>>
OutsideCompilationClusterDependencies(
const Graph* g, const string& outside_compilation_attr_name);
// Preprocesses edges within the same XLA cluster. It will perform the following
// operations in order:
//
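
Not part of the commit: a simplified standalone sketch of what OutsideCompilationClusterDependencies computes, using ordered standard containers instead of the absl ones. Edges are assumed to be already reduced to (source cluster, destination cluster) string pairs.

#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Builds "cluster -> sorted names of clusters it depends on" from edges that
// cross outside compilation clusters.
std::map<std::string, std::vector<std::string>> ClusterDepsSketch(
    const std::vector<std::pair<std::string, std::string>>& cross_edges) {
  std::map<std::string, std::set<std::string>> deps;
  for (const auto& edge : cross_edges) {
    if (edge.first != edge.second) deps[edge.second].insert(edge.first);
  }
  std::map<std::string, std::vector<std::string>> ordered;
  for (const auto& kv : deps) {
    // std::set iterates in sorted order, matching the explicit sort above.
    ordered.emplace(kv.first, std::vector<std::string>(kv.second.begin(),
                                                       kv.second.end()));
  }
  return ordered;
}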

View File

@ -15,12 +15,14 @@ limitations under the License.
#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h" #include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h" #include "absl/strings/match.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/graph_to_functiondef.h"
@ -287,15 +289,20 @@ absl::optional<std::vector<PartialTensorShape>> GetInferredInputShapes(
return results;
}
string host_compute_node_name(const string& original_oc_name) {
return absl::StrCat("outside_compilation_", original_oc_name,
"_host_compute");
}
// Builds XlaHostCompute NodeDef from the outside compilation call node.
xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
-    const Node* call_node, const std::map<string, int>& host_compute_core) {
+    const Node* call_node, const std::map<string, int>& host_compute_core,
+    const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
string original_oc_name;
TF_RETURN_IF_ERROR(GetNodeAttr(
    call_node->attrs(), "_outside_compilation_subgraph", &original_oc_name));
-NodeDefBuilder host_compute_builder(
-    absl::StrCat("outside_compilation_", original_oc_name, "_host_compute"),
-    "XlaHostCompute");
+NodeDefBuilder host_compute_builder(host_compute_node_name(original_oc_name),
+                                    "XlaHostCompute");
// Copy all attributes.
for (auto attr : call_node->attrs()) {
@ -309,9 +316,25 @@ xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
host_compute_builder.Attr("tpu_core", core); host_compute_builder.Attr("tpu_core", core);
} }
// Set input tokens. // Set input tokens and other outside compilation clusters that current
host_compute_builder.Attr(kXlaTokenInputNodesAttrName, // cluster depends in `kXlaTokenArgNodeName`. This is needed because when
std::vector<string>{kXlaTokenArgNodeName}); // outside compilation subgraphs are encapsulated and moved to host graph,
// control/data edges between them will only be reflected in host graph.
// From XLA's perspective, two originally dependent clusters are no longer
// connected, which makes them look like they can be scheduled for execution
// in arbitrary order even though in fact they must be executed in order
// according to their host-side graph dependency. This can cause deadlock.
// Therefore, we hint XLA what the correct ordering of these clusters should
// be to avoid deadlocks.
std::vector<string> xla_token_input_nodes;
xla_token_input_nodes.emplace_back(kXlaTokenArgNodeName);
auto cluster_deps_it = cluster_deps.find(original_oc_name);
if (cluster_deps_it != cluster_deps.end()) {
for (auto dep : cluster_deps_it->second) {
xla_token_input_nodes.emplace_back(host_compute_node_name(dep));
}
}
host_compute_builder.Attr(kXlaTokenInputNodesAttrName, xla_token_input_nodes);
// Populate inputs.
std::vector<DataType> input_dtypes;
@ -371,7 +394,8 @@ Status ValidateOutsideCompilationCallNode(Node* call_node) {
// If the function call node has no input/output edges, we will just remove it
// and not create a XlaHostCompute node.
Status ReplaceOrRemoveOutsideCompilationCallNode(
-    Graph* g, Node* call_node, const std::map<string, int>& host_compute_core) {
+    Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
+    const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
// If the function call node has no input/output edges, just remove it.
bool has_edge = false;
for (auto e : call_node->in_edges()) {
@ -393,8 +417,9 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
}
// Build XlaHostCompute NodeDef.
-TF_ASSIGN_OR_RETURN(NodeDef node_def,
-                    BuildXlaHostComputeNodeDef(call_node, host_compute_core));
+TF_ASSIGN_OR_RETURN(
+    NodeDef node_def,
+    BuildXlaHostComputeNodeDef(call_node, host_compute_core, cluster_deps));
TF_ASSIGN_OR_RETURN(Node * host_compute_node,
                    ReplaceNode(g, call_node, node_def));
VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString();
@ -1589,6 +1614,11 @@ Status ExtractOutsideCompilationForFunction(
// We cannot early return here, because we might have outside compilation in
// If/While function body.
+// Find dependencies between outside compilation clusters.
+TF_ASSIGN_OR_RETURN(auto cluster_deps,
+                    OutsideCompilationClusterDependencies(
+                        fbody->graph, outside_compilation_attr_name));
// Preprocess edges between different outside compilations. They will be
// restored in `ConstructHostGraph()`.
TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
@ -1643,7 +1673,7 @@ Status ExtractOutsideCompilationForFunction(
for (Node* n : outside_compilation_nodes) {
TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode(
-    graph_out.get(), n, host_compute_core));
+    graph_out.get(), n, host_compute_core, *cluster_deps));
}
// Handle nodes with associated functions.
@ -1691,11 +1721,13 @@ Status ExtractOutsideCompilation(
const string& xla_cluster_attr_name,
const string& outside_compilation_attr_name,
const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
-    FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld) {
+    FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
+    bool* modified) {
if (VLOG_IS_ON(4)) {
DumpGraphToFile("extract_outside_compilation_before", *g, fld);
}
+*modified = false;
auto node_name_index = g->BuildNodeNameIndex();
for (auto& iter : clusters) {
string xla_cluster_name = iter.first;
@ -1711,6 +1743,7 @@ Status ExtractOutsideCompilation(
func_name_attrs, func_name_attrs.name(), host_graph_func_name,
host_compute_core, flr, fld, &shape_inference_graphs,
&has_outside_compilation));
+*modified |= has_outside_compilation;
string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
Node* pivot_node = node_name_index[pivot_name];
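
Not part of the commit: the net effect of the hunks above on the "_xla_token_input_nodes" attribute can be sketched as follows; the attribute always starts with the token arg node and then names the host-compute node of every cluster the current one depends on.

#include <string>
#include <vector>

// Sketch of the "_xla_token_input_nodes" value built for one cluster.
std::vector<std::string> TokenInputNodesSketch(
    const std::vector<std::string>& deps) {
  std::vector<std::string> nodes;
  nodes.push_back("_xla_token_arg_node");
  for (const std::string& dep : deps) {
    nodes.push_back("outside_compilation_" + dep + "_host_compute");
  }
  return nodes;
}
// TokenInputNodesSketch({"O1"}) yields
//   {"_xla_token_arg_node", "outside_compilation_O1_host_compute"},
// which is exactly what the updated tests above and below expect.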

View File

@ -101,7 +101,8 @@ Status ExtractOutsideCompilation(
const string& xla_cluster_attr_name,
const string& outside_compilation_attr_name,
const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
-    FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld);
+    FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
+    bool* modified);
} // namespace tensorflow

View File

@ -922,4 +922,145 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) {
}
}
TEST_F(ExtractOutsideCompilationForFunctionTest,
OutsideCompilationClusterDataDependency) {
// Build the XLA computation func.
// "const0"
// "identity0" = "const0" (outside compilation cluster "0")
// "identity1" = "identity0" (outside compilation cluster "1")
// "identity2" = "identity1"
FunctionDefLibrary fdl;
{
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
std::cout << "Graph is " << (*g).ToGraphDefDebug().DebugString()
<< std::endl;
auto node_name_image = g->BuildNodeNameIndex();
node_name_image["identity0"]->AddAttr("_oc", "0");
node_name_image["identity1"]->AddAttr("_oc", "1");
PartialTensorShape shape({2});
node_name_image["identity1"]->AddAttr(
kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
}
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
protobuf::Map<string, tensorflow::AttrValue> attrs;
std::map<string, int> host_compute_core = {{"0", 1}, {"1", 0}};
std::vector<string> shape_inference_graphs;
bool has_outside_compilation;
NameAttrList name_attrs;
name_attrs.set_name("cluster");
*name_attrs.mutable_attr() = attrs;
TF_CHECK_OK(ExtractOutsideCompilationTest(
"_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
host_compute_core, &fld, &shape_inference_graphs,
&has_outside_compilation));
// Get rewritten XLA computation function.
std::unique_ptr<FunctionBody> xla_fbody;
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
AttrSlice(), &fld, &xla_fbody));
auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();
// Check XlaHostCompute nodes.
Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"];
EXPECT_NE(host_compute_0, nullptr);
Node *host_compute_1 = node_name_index["outside_compilation_1_host_compute"];
EXPECT_NE(host_compute_1, nullptr);
// Check XlaHostCompute nodes' "_xla_token_input_nodes" attr.
std::vector<string> token_input_nodes;
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()),
"_xla_token_input_nodes", &token_input_nodes));
std::vector<string> expected_token_input_nodes_0({"_xla_token_arg_node"});
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0);
token_input_nodes.clear();
std::vector<string> expected_token_input_nodes_1(
{"_xla_token_arg_node", "outside_compilation_0_host_compute"});
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
"_xla_token_input_nodes", &token_input_nodes));
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
}
TEST_F(ExtractOutsideCompilationForFunctionTest,
OutsideCompilationClusterControlDependency) {
// Build the XLA computation func.
// "const0"
// "identity0" = "const0" (outside compilation cluster "0")
// "identity1" = "const0" "^identity0" (outside compilation cluster "1",
// control dependent on cluster "0")
// "identity2" = "identity1"
FunctionDefLibrary fdl;
{
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
Output identity1 = ops::Identity(
s.WithOpName("identity1").WithControlDependencies(identity0), const0);
Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
std::cout << "Graph is " << (*g).ToGraphDefDebug().DebugString()
<< std::endl;
auto node_name_image = g->BuildNodeNameIndex();
node_name_image["identity0"]->AddAttr("_oc", "0");
node_name_image["identity1"]->AddAttr("_oc", "1");
PartialTensorShape shape({2});
node_name_image["identity1"]->AddAttr(
kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
}
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
protobuf::Map<string, tensorflow::AttrValue> attrs;
std::map<string, int> host_compute_core = {{"0", 1}, {"1", 0}};
std::vector<string> shape_inference_graphs;
bool has_outside_compilation;
NameAttrList name_attrs;
name_attrs.set_name("cluster");
*name_attrs.mutable_attr() = attrs;
TF_CHECK_OK(ExtractOutsideCompilationTest(
"_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
host_compute_core, &fld, &shape_inference_graphs,
&has_outside_compilation));
// Get rewritten XLA computation function.
std::unique_ptr<FunctionBody> xla_fbody;
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
AttrSlice(), &fld, &xla_fbody));
auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();
// Check XlaHostCompute nodes.
Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"];
EXPECT_NE(host_compute_0, nullptr);
Node *host_compute_1 = node_name_index["outside_compilation_1_host_compute"];
EXPECT_NE(host_compute_1, nullptr);
// Check XlaHostCompute nodes' "_xla_token_input_nodes" attr.
std::vector<string> token_input_nodes;
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()),
"_xla_token_input_nodes", &token_input_nodes));
std::vector<string> expected_token_input_nodes_0({"_xla_token_arg_node"});
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0);
token_input_nodes.clear();
std::vector<string> expected_token_input_nodes_1(
{"_xla_token_arg_node", "outside_compilation_0_host_compute"});
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
"_xla_token_input_nodes", &token_input_nodes));
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
}
} // namespace tensorflow } // namespace tensorflow

View File

@ -36,6 +36,10 @@ std::once_flag flags_init;
bool SetterForXlaAutoJitFlag(const string& value) { bool SetterForXlaAutoJitFlag(const string& value) {
int32 opt_level; int32 opt_level;
// We need to use the mark_for_compilation_flags directly here instead of
// going via GetMarkForCompilationPassFlags() to avoid infinite recursion. The
// latter will try to setup and parse flags, which would bring us back to this
// setter.
if (absl::SimpleAtoi(value, &opt_level)) { if (absl::SimpleAtoi(value, &opt_level)) {
mark_for_compilation_flags->xla_auto_jit_flag mark_for_compilation_flags->xla_auto_jit_flag
.optimization_level_single_gpu = opt_level; .optimization_level_single_gpu = opt_level;
@ -155,9 +159,14 @@ void AllocateAndParseFlags() {
} // namespace } // namespace
const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() { bool SetXlaAutoJitFlagFromFlagString(const string& value) {
std::call_once(flags_init, &AllocateAndParseFlags); std::call_once(flags_init, &AllocateAndParseFlags);
return *build_ops_flags; return SetterForXlaAutoJitFlag(value);
}
BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
return build_ops_flags;
} }
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() { MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {

View File

@ -38,6 +38,12 @@ struct XlaAutoJitFlag {
int32 optimization_level_general; int32 optimization_level_general;
}; };
// Sets the xla_auto_jit_flag based on the given flag string. Supported syntax
// is:
// <number>: sets general and single_gpu setting to the provided number.
// single-gpu(<number>): sets the single_gpu setting to the provided number.
bool SetXlaAutoJitFlagFromFlagString(const string& value);
// Flags associated with the XLA bridge's mark_for_compilation_pass module. // Flags associated with the XLA bridge's mark_for_compilation_pass module.
struct MarkForCompilationPassFlags { struct MarkForCompilationPassFlags {
XlaAutoJitFlag xla_auto_jit_flag; XlaAutoJitFlag xla_auto_jit_flag;
@ -111,7 +117,7 @@ struct IntroduceFloatingPointJitterPassFlags {
// parses TF_XLA_FLAGS for all of them. Those functions which return a pointer // parses TF_XLA_FLAGS for all of them. Those functions which return a pointer
// always return the same pointer. // always return the same pointer.
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags(); MarkForCompilationPassFlags* GetMarkForCompilationPassFlags();
const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags(); BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags();
XlaDeviceFlags* GetXlaDeviceFlags(); XlaDeviceFlags* GetXlaDeviceFlags();
const XlaOpsCommonFlags& GetXlaOpsCommonFlags(); const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
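As the comments above describe, SetXlaAutoJitFlagFromFlagString accepts either a bare optimization level or a single-gpu(<number>) form, and GetBuildXlaOpsPassFlags now hands back a mutable pointer instead of a const reference. A minimal sketch of how a caller might exercise both (the include path and the concrete flag values are assumptions for illustration, not part of this change):
// Sketch only: exercises the declarations above. Header path and flag values
// are illustrative assumptions.
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/platform/logging.h"
void ConfigureAutoJitForIllustration() {
  // "<number>" sets both the general and the single_gpu optimization level.
  CHECK(tensorflow::SetXlaAutoJitFlagFromFlagString("2"));
  // "single-gpu(<number>)" only changes the single_gpu optimization level.
  CHECK(tensorflow::SetXlaAutoJitFlagFromFlagString("single-gpu(1)"));
  // GetBuildXlaOpsPassFlags() now returns a mutable pointer, so callers may
  // adjust build-pass flags in place.
  tensorflow::BuildXlaOpsPassFlags* build_flags =
      tensorflow::GetBuildXlaOpsPassFlags();
  (void)build_flags;
}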

View File

@ -13,8 +13,23 @@ cc_library(
srcs = ["graphcycles.cc"], srcs = ["graphcycles.cc"],
hdrs = ["graphcycles.h"], hdrs = ["graphcycles.h"],
deps = [ deps = [
":ordered_set",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "ordered_set",
hdrs = ["ordered_set.h"],
deps = [
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:span",
], ],
) )
@ -28,3 +43,14 @@ tf_cc_test(
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
], ],
) )
tf_cc_test(
name = "ordered_set_test",
srcs = ["ordered_set_test.cc"],
deps = [
":ordered_set",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@ -34,14 +34,20 @@ limitations under the License.
#include <algorithm> #include <algorithm>
#include <unordered_set> #include <unordered_set>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h" #include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/graphcycles/ordered_set.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
typedef std::unordered_set<int32> NodeSet; using NodeSet = absl::flat_hash_set<int32>;
using OrderedNodeSet = OrderedSet<int32>;
template <typename T> template <typename T>
struct VecStruct { struct VecStruct {
typedef absl::InlinedVector<T, 4> type; typedef absl::InlinedVector<T, 4> type;
@ -50,13 +56,11 @@ template <typename T>
using Vec = typename VecStruct<T>::type; using Vec = typename VecStruct<T>::type;
struct Node { struct Node {
Node() : in(4), out(4) {} // Small hashtables for in/out edges
int32 rank; // rank number assigned by Pearce-Kelly algorithm int32 rank; // rank number assigned by Pearce-Kelly algorithm
bool visited; // Temporary marker used by depth-first-search bool visited; // Temporary marker used by depth-first-search
void* data; // User-supplied data void* data; // User-supplied data
NodeSet in; // List of immediate predecessor nodes in graph OrderedNodeSet in; // List of immediate predecessor nodes in graph
NodeSet out; // List of immediate successor nodes in graph OrderedNodeSet out; // List of immediate successor nodes in graph
}; };
} // namespace } // namespace
@ -93,7 +97,7 @@ bool GraphCycles::CheckInvariants() const {
if (!ranks.insert(nx->rank).second) { if (!ranks.insert(nx->rank).second) {
LOG(FATAL) << "Duplicate occurrence of rank " << nx->rank; LOG(FATAL) << "Duplicate occurrence of rank " << nx->rank;
} }
for (auto y : nx->out) { for (int32 y : nx->out.GetSequence()) {
Node* ny = r->nodes_[y]; Node* ny = r->nodes_[y];
if (nx->rank >= ny->rank) { if (nx->rank >= ny->rank) {
LOG(FATAL) << "Edge " << x << "->" << y << " has bad rank assignment " LOG(FATAL) << "Edge " << x << "->" << y << " has bad rank assignment "
@ -124,14 +128,14 @@ int32 GraphCycles::NewNode() {
void GraphCycles::RemoveNode(int32 node) { void GraphCycles::RemoveNode(int32 node) {
Node* x = rep_->nodes_[node]; Node* x = rep_->nodes_[node];
for (auto y : x->out) { for (int32 y : x->out.GetSequence()) {
rep_->nodes_[y]->in.erase(node); rep_->nodes_[y]->in.Erase(node);
} }
for (auto y : x->in) { for (int32 y : x->in.GetSequence()) {
rep_->nodes_[y]->out.erase(node); rep_->nodes_[y]->out.Erase(node);
} }
x->in.clear(); x->in.Clear();
x->out.clear(); x->out.Clear();
rep_->free_nodes_.push_back(node); rep_->free_nodes_.push_back(node);
} }
@ -144,12 +148,12 @@ void GraphCycles::SetNodeData(int32 node, void* data) {
} }
bool GraphCycles::HasEdge(int32 x, int32 y) const { bool GraphCycles::HasEdge(int32 x, int32 y) const {
return rep_->nodes_[x]->out.find(y) != rep_->nodes_[x]->out.end(); return rep_->nodes_[x]->out.Contains(y);
} }
void GraphCycles::RemoveEdge(int32 x, int32 y) { void GraphCycles::RemoveEdge(int32 x, int32 y) {
rep_->nodes_[x]->out.erase(y); rep_->nodes_[x]->out.Erase(y);
rep_->nodes_[y]->in.erase(x); rep_->nodes_[y]->in.Erase(x);
// No need to update the rank assignment since a previous valid // No need to update the rank assignment since a previous valid
// rank assignment remains valid after an edge deletion. // rank assignment remains valid after an edge deletion.
} }
@ -165,13 +169,13 @@ bool GraphCycles::InsertEdge(int32 x, int32 y) {
if (x == y) return false; if (x == y) return false;
Rep* r = rep_; Rep* r = rep_;
Node* nx = r->nodes_[x]; Node* nx = r->nodes_[x];
if (!nx->out.insert(y).second) { if (!nx->out.Insert(y)) {
// Edge already exists. // Edge already exists.
return true; return true;
} }
Node* ny = r->nodes_[y]; Node* ny = r->nodes_[y];
ny->in.insert(x); ny->in.Insert(x);
if (nx->rank <= ny->rank) { if (nx->rank <= ny->rank) {
// New edge is consistent with existing rank assignment. // New edge is consistent with existing rank assignment.
@ -182,8 +186,8 @@ bool GraphCycles::InsertEdge(int32 x, int32 y) {
// We only need to consider nodes that fall in the range [ny->rank,nx->rank]. // We only need to consider nodes that fall in the range [ny->rank,nx->rank].
if (!ForwardDFS(r, y, nx->rank)) { if (!ForwardDFS(r, y, nx->rank)) {
// Found a cycle. Undo the insertion and tell caller. // Found a cycle. Undo the insertion and tell caller.
nx->out.erase(y); nx->out.Erase(y);
ny->in.erase(x); ny->in.Erase(x);
// Since we do not call Reorder() on this path, clear any visited // Since we do not call Reorder() on this path, clear any visited
// markers left by ForwardDFS. // markers left by ForwardDFS.
ClearVisitedBits(r, r->deltaf_); ClearVisitedBits(r, r->deltaf_);
@ -209,7 +213,7 @@ static bool ForwardDFS(GraphCycles::Rep* r, int32 n, int32 upper_bound) {
nn->visited = true; nn->visited = true;
r->deltaf_.push_back(n); r->deltaf_.push_back(n);
for (auto w : nn->out) { for (auto w : nn->out.GetSequence()) {
Node* nw = r->nodes_[w]; Node* nw = r->nodes_[w];
if (nw->rank == upper_bound) { if (nw->rank == upper_bound) {
return false; // Cycle return false; // Cycle
@ -235,7 +239,7 @@ static void BackwardDFS(GraphCycles::Rep* r, int32 n, int32 lower_bound) {
nn->visited = true; nn->visited = true;
r->deltab_.push_back(n); r->deltab_.push_back(n);
for (auto w : nn->in) { for (auto w : nn->in.GetSequence()) {
Node* nw = r->nodes_[w]; Node* nw = r->nodes_[w];
if (!nw->visited && lower_bound < nw->rank) { if (!nw->visited && lower_bound < nw->rank) {
r->stack_.push_back(w); r->stack_.push_back(w);
@ -321,7 +325,7 @@ int GraphCycles::FindPath(int32 x, int32 y, int max_path_len,
return path_len; return path_len;
} }
for (auto w : r->nodes_[n]->out) { for (auto w : r->nodes_[n]->out.GetSequence()) {
if (seen.insert(w).second) { if (seen.insert(w).second) {
r->stack_.push_back(w); r->stack_.push_back(w);
} }
@ -375,31 +379,94 @@ bool GraphCycles::ContractEdge(int32 a, int32 b) {
} }
Node* nb = rep_->nodes_[b]; Node* nb = rep_->nodes_[b];
std::unordered_set<int32> out = std::move(nb->out); OrderedNodeSet out = std::move(nb->out);
std::unordered_set<int32> in = std::move(nb->in); OrderedNodeSet in = std::move(nb->in);
for (auto y : out) { for (int32 y : out.GetSequence()) {
rep_->nodes_[y]->in.erase(b); rep_->nodes_[y]->in.Erase(b);
} }
for (auto y : in) { for (int32 y : in.GetSequence()) {
rep_->nodes_[y]->out.erase(b); rep_->nodes_[y]->out.Erase(b);
} }
rep_->free_nodes_.push_back(b); rep_->free_nodes_.push_back(b);
for (auto y : out) { rep_->nodes_[a]->out.Reserve(rep_->nodes_[a]->out.Size() + out.Size());
for (int32 y : out.GetSequence()) {
InsertEdge(a, y); InsertEdge(a, y);
} }
for (auto y : in) {
rep_->nodes_[a]->in.Reserve(rep_->nodes_[a]->in.Size() + in.Size());
for (int32 y : in.GetSequence()) {
InsertEdge(y, a); InsertEdge(y, a);
} }
return true; return true;
} }
std::unordered_set<int32> GraphCycles::Successors(int32 node) const { absl::Span<const int32> GraphCycles::Successors(int32 node) const {
return rep_->nodes_[node]->out; return rep_->nodes_[node]->out.GetSequence();
} }
std::unordered_set<int32> GraphCycles::Predecessors(int32 node) const { absl::Span<const int32> GraphCycles::Predecessors(int32 node) const {
return rep_->nodes_[node]->in; return rep_->nodes_[node]->in.GetSequence();
}
std::vector<int32> GraphCycles::SuccessorsCopy(int32 node) const {
absl::Span<const int32> successors = Successors(node);
return std::vector<int32>(successors.begin(), successors.end());
}
std::vector<int32> GraphCycles::PredecessorsCopy(int32 node) const {
absl::Span<const int32> predecessors = Predecessors(node);
return std::vector<int32>(predecessors.begin(), predecessors.end());
}
namespace {
void SortInPostOrder(absl::Span<Node* const> nodes,
std::vector<int32>* to_sort) {
absl::c_sort(*to_sort, [&](int32 a, int32 b) {
DCHECK(a == b || nodes[a]->rank != nodes[b]->rank);
return nodes[a]->rank > nodes[b]->rank;
});
}
} // namespace
std::vector<int32> GraphCycles::AllNodesInPostOrder() const {
absl::flat_hash_set<int32> free_nodes_set;
absl::c_copy(rep_->free_nodes_,
std::inserter(free_nodes_set, free_nodes_set.begin()));
std::vector<int32> all_nodes;
all_nodes.reserve(rep_->nodes_.size() - free_nodes_set.size());
for (int64 i = 0, e = rep_->nodes_.size(); i < e; i++) {
if (!free_nodes_set.contains(i)) {
all_nodes.push_back(i);
}
}
SortInPostOrder(rep_->nodes_, &all_nodes);
return all_nodes;
}
string GraphCycles::DebugString() const {
absl::flat_hash_set<int32> free_nodes_set;
for (int32 free_node : rep_->free_nodes_) {
free_nodes_set.insert(free_node);
}
string result = "digraph {\n";
for (int i = 0; i < rep_->nodes_.size(); i++) {
if (free_nodes_set.contains(i)) {
continue;
}
for (int32 succ : rep_->nodes_[i]->out.GetSequence()) {
absl::StrAppend(&result, " \"", i, "\" -> \"", succ, "\"\n");
}
}
absl::StrAppend(&result, "}\n");
return result;
} }
} // namespace tensorflow } // namespace tensorflow

View File

@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_ #ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_ #define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
#include <vector>
// GraphCycles detects the introduction of a cycle into a directed // GraphCycles detects the introduction of a cycle into a directed
// graph that is being built up incrementally. // graph that is being built up incrementally.
// //
@ -38,8 +40,7 @@ limitations under the License.
// FindPath() is linear in the size of the graph. // FindPath() is linear in the size of the graph.
// The current implementation uses O(|V|+|E|) space. // The current implementation uses O(|V|+|E|) space.
#include <unordered_set> #include "absl/types/span.h"
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
@ -117,8 +118,26 @@ class GraphCycles {
// Expensive: should only be called from graphcycles_test.cc. // Expensive: should only be called from graphcycles_test.cc.
bool CheckInvariants() const; bool CheckInvariants() const;
std::unordered_set<int32> Successors(int32 node) const; // Warning: Do not use these if iterating over the span and modifying the
std::unordered_set<int32> Predecessors(int32 node) const; // GraphCycles at the same time. Instead use SuccessorsCopy/PredecessorsCopy.
absl::Span<const int32> Successors(int32 node) const;
absl::Span<const int32> Predecessors(int32 node) const;
// Return a copy of the successors set. This is needed for code using the
// collection while modifying the GraphCycles.
std::vector<int32> SuccessorsCopy(int32 node) const;
// Return a copy of the predecessors set. This is needed for code using the
// collection while modifying the GraphCycles.
std::vector<int32> PredecessorsCopy(int32 node) const;
// Returns all nodes in post order.
//
// If there is a path from X to Y then X appears after Y in the
// returned vector.
std::vector<int32> AllNodesInPostOrder() const;
// Returns the graph in graphviz format.
string DebugString() const;
// ---------------------------------------------------- // ----------------------------------------------------
struct Rep; struct Rep;

View File

@ -0,0 +1,85 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_
#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
// This is a set data structure that provides a deterministic iteration order.
// The iteration order of elements only depends on the sequence of
// inserts/deletes, so as long as the inserts/deletes happen in the same
// sequence, the set will have the same iteration order.
//
// Assumes that T can be cheaply copied for simplicity.
template <typename T>
class OrderedSet {
public:
// Inserts `value` into the ordered set. Returns true if the value was not
// present in the set before the insertion.
bool Insert(T value) {
bool new_insertion =
value_to_index_.insert({value, value_sequence_.size()}).second;
if (new_insertion) {
value_sequence_.push_back(value);
}
return new_insertion;
}
// Removes `value` from the set. Assumes `value` is already present in the
// set.
void Erase(T value) {
auto it = value_to_index_.find(value);
DCHECK(it != value_to_index_.end());
// Since we don't want to shift values around in `value_sequence_`, we swap
// the value to be deleted with the value in the last position and then
// pop_back.
value_to_index_[value_sequence_.back()] = it->second;
std::swap(value_sequence_[it->second], value_sequence_.back());
value_sequence_.pop_back();
value_to_index_.erase(it);
}
void Reserve(size_t new_size) {
value_to_index_.reserve(new_size);
value_sequence_.reserve(new_size);
}
void Clear() {
value_to_index_.clear();
value_sequence_.clear();
}
bool Contains(T value) const { return value_to_index_.contains(value); }
size_t Size() const { return value_sequence_.size(); }
absl::Span<T const> GetSequence() const { return value_sequence_; }
private:
// The stable order that we maintain through insertions and deletions.
std::vector<T> value_sequence_;
// Maps values to their indices in `value_sequence_`.
absl::flat_hash_map<T, int> value_to_index_;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_
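A short sketch of the ordering contract described above: GetSequence() follows insertion order, and Erase fills the vacated slot with the last element, so erasing from the middle reorders the tail (the unit tests below exercise the same behavior; the helper name here is illustrative):
// Sketch only: illustrates the iteration-order contract of OrderedSet.
#include "tensorflow/compiler/jit/graphcycles/ordered_set.h"
void OrderedSetOrderSketch() {
  tensorflow::OrderedSet<int> set;
  set.Insert(3);
  set.Insert(1);
  set.Insert(2);
  // GetSequence() reflects insertion order: {3, 1, 2}.
  // Erase swaps the erased slot with the last element before popping, so
  // erasing 1 leaves the sequence {3, 2}.
  set.Erase(1);
}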

View File

@ -0,0 +1,117 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/graphcycles/ordered_set.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
TEST(OrderedSetTest, Insert) {
OrderedSet<int> ordered_set;
EXPECT_TRUE(ordered_set.Insert(90));
EXPECT_TRUE(ordered_set.Insert(100));
EXPECT_TRUE(ordered_set.Insert(80));
EXPECT_FALSE(ordered_set.Insert(100));
EXPECT_EQ(ordered_set.Size(), 3);
EXPECT_TRUE(ordered_set.Contains(90));
EXPECT_TRUE(ordered_set.Contains(100));
EXPECT_TRUE(ordered_set.Contains(80));
EXPECT_FALSE(ordered_set.Contains(40));
std::array<int, 3> expected_sequence = {90, 100, 80};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence);
}
TEST(OrderedSetTest, Erase) {
OrderedSet<int> ordered_set;
EXPECT_TRUE(ordered_set.Insert(90));
EXPECT_TRUE(ordered_set.Insert(100));
EXPECT_TRUE(ordered_set.Insert(80));
ordered_set.Erase(100);
EXPECT_EQ(ordered_set.Size(), 2);
EXPECT_TRUE(ordered_set.Contains(90));
EXPECT_FALSE(ordered_set.Contains(100));
EXPECT_TRUE(ordered_set.Contains(80));
std::array<int, 2> expected_sequence_0 = {90, 80};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence_0);
ordered_set.Erase(80);
EXPECT_EQ(ordered_set.Size(), 1);
EXPECT_TRUE(ordered_set.Contains(90));
EXPECT_FALSE(ordered_set.Contains(100));
EXPECT_FALSE(ordered_set.Contains(80));
std::array<int, 1> expected_sequence_1 = {90};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence_1);
ordered_set.Erase(90);
EXPECT_EQ(ordered_set.Size(), 0);
EXPECT_FALSE(ordered_set.Contains(90));
EXPECT_FALSE(ordered_set.Contains(100));
EXPECT_FALSE(ordered_set.Contains(80));
std::array<int, 0> expected_sequence_2 = {};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence_2);
}
TEST(OrderedSetTest, Clear) {
OrderedSet<int> ordered_set;
EXPECT_TRUE(ordered_set.Insert(90));
EXPECT_TRUE(ordered_set.Insert(100));
EXPECT_TRUE(ordered_set.Insert(80));
ordered_set.Clear();
EXPECT_EQ(ordered_set.Size(), 0);
EXPECT_FALSE(ordered_set.Contains(90));
EXPECT_FALSE(ordered_set.Contains(100));
EXPECT_FALSE(ordered_set.Contains(80));
std::array<int, 0> expected_sequence = {};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence);
}
TEST(OrderedSetTest, LargeInsertions) {
const int kSize = 50 * 9000;
OrderedSet<int> ordered_set;
for (int i = 0; i < kSize; i++) {
EXPECT_TRUE(ordered_set.Insert(i + 500));
}
for (int i = 0; i < kSize; i++) {
EXPECT_EQ(ordered_set.GetSequence()[i], i + 500);
}
}
} // namespace
} // namespace tensorflow

View File

@ -62,7 +62,7 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
se::Platform::Id platform_id = nullptr; se::Platform::Id platform_id = nullptr;
const XlaDevice::Metadata* xla_device_metadata = nullptr; const XlaDevice::Metadata* xla_device_metadata = nullptr;
std::unique_ptr<XlaAllocator> xla_allocator; std::unique_ptr<XlaAllocator> xla_allocator;
xla::DeviceMemoryAllocator* device_allocator = nullptr; se::DeviceMemoryAllocator* device_allocator = nullptr;
if (ctx->device_type() == DeviceType(DEVICE_CPU)) { if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
platform_id = se::host::kHostPlatformId; platform_id = se::host::kHostPlatformId;

View File

@ -40,7 +40,7 @@ class XlaPlatformInfo {
se::Platform::Id platform_id, se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata, const XlaDevice::Metadata* xla_device_metadata,
std::unique_ptr<XlaAllocator> xla_allocator, std::unique_ptr<XlaAllocator> xla_allocator,
xla::DeviceMemoryAllocator* device_allocator) se::DeviceMemoryAllocator* device_allocator)
: device_type_(device_type), : device_type_(device_type),
platform_id_(platform_id), platform_id_(platform_id),
xla_device_metadata_(xla_device_metadata), xla_device_metadata_(xla_device_metadata),
@ -55,7 +55,7 @@ class XlaPlatformInfo {
return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams(); return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
} }
xla::DeviceMemoryAllocator* allocator() const { se::DeviceMemoryAllocator* allocator() const {
return device_allocator_ ? device_allocator_ : xla_allocator_.get(); return device_allocator_ ? device_allocator_ : xla_allocator_.get();
} }
DeviceType device_type() const { return device_type_; } DeviceType device_type() const { return device_type_; }
@ -86,7 +86,7 @@ class XlaPlatformInfo {
// then device_allocator_ is null and xla_allocator_ points to an appropriate // then device_allocator_ is null and xla_allocator_ points to an appropriate
// XlaAllocator instance. // XlaAllocator instance.
std::unique_ptr<XlaAllocator> xla_allocator_; std::unique_ptr<XlaAllocator> xla_allocator_;
xla::DeviceMemoryAllocator* device_allocator_; se::DeviceMemoryAllocator* device_allocator_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo); TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
}; };

File diff suppressed because it is too large

View File

@ -270,11 +270,11 @@ TEST(XlaCompilationTest, FunctionCalls) {
auto clusters = GetClusters(*graph); auto clusters = GetClusters(*graph);
EXPECT_EQ(2, clusters.size()); EXPECT_EQ(2, clusters.size());
EXPECT_FALSE(clusters["B"].empty()); EXPECT_FALSE(clusters["C"].empty());
EXPECT_EQ(clusters["B"], clusters["C"]); EXPECT_EQ(clusters["C"], clusters["E"]);
EXPECT_TRUE(clusters.find("A") == clusters.cend()); EXPECT_TRUE(clusters.find("A") == clusters.cend());
EXPECT_TRUE(clusters.find("B") == clusters.cend());
EXPECT_TRUE(clusters.find("D") == clusters.cend()); EXPECT_TRUE(clusters.find("D") == clusters.cend());
EXPECT_TRUE(clusters.find("E") == clusters.cend());
} }
TEST(XlaCompilationTest, CallXlaDeviceFuncWithResourceOp) { TEST(XlaCompilationTest, CallXlaDeviceFuncWithResourceOp) {
@ -332,31 +332,6 @@ TEST(XlaCompilationTest, CallXlaDeviceFuncWithResourceOp) {
EXPECT_NE(clusters["A"], ""); EXPECT_NE(clusters["A"], "");
} }
// Metadata-only operators such as Shape/Rank/Size may not be the root of a
// cluster. This is partially to work around b/26800664, and partially because
// we should probably prefer to compile metadata operators with their producers
// wherever possible, rather than their consumers.
TEST(XlaCompilationTest, MetadataOpsDontStartClusters) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a =
ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
// While all of the following ops are notionally compilable, none is
// permitted
// to start a cluster. So nothing should be compiled.
Node* b = ops::UnaryOp("Shape", a, builder.opts().WithName("B"));
Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C"));
Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D"));
ops::UnaryOp("Shape", d, builder.opts().WithName("E"));
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(0, clusters.size()); // Nothing should be compiled.
}
static Status GradForUnaryCwise(FunctionDef* g, static Status GradForUnaryCwise(FunctionDef* g,
std::vector<FunctionDefHelper::Node> nodes) { std::vector<FunctionDefHelper::Node> nodes) {
for (auto& n : nodes) { for (auto& n : nodes) {
@ -1137,6 +1112,45 @@ TEST(XlaCompilationTest, DontClusterMergingNodes) {
EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]); EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
} }
TEST(XlaCompilationTest, DontClusterMergingNodesOnCPU) {
// This is similar to the 'DontClusterMergingNodes' above, except
// MatMulCombined is placed on the CPU.
Scope root = Scope::NewRootScope().ExitOnError();
absl::string_view xla_gpu_dev0 = "/job:worker/replica:0/task:0/device:GPU:0";
absl::string_view xla_gpu_dev1 = "/job:worker/replica:0/task:0/device:GPU:1";
absl::string_view xla_cpu_dev0 = "/job:worker/replica:0/task:0/device:CPU:0";
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);
Output combined =
ops::MatMul(root.WithOpName("MatMulCombined_cpu"), matmul0, matmul1);
TF_ASSERT_OK(root.ToGraph(graph.get()));
for (Node* n : graph->nodes()) {
if (absl::EndsWith(n->name(), /*suffix=*/"cpu")) {
n->set_assigned_device_name(string(xla_cpu_dev0));
} else if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
n->set_assigned_device_name(string(xla_gpu_dev0));
} else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
n->set_assigned_device_name(string(xla_gpu_dev1));
}
}
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
// Each of the MatMuls should be in a separate cluster.
std::unordered_map<string, string> clusters = GetClusters(*graph);
EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul0_dev0"]);
EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul1_dev1"]);
EXPECT_EQ(clusters["A_dev0"], clusters["MatMul0_dev0"]);
EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
}
// TODO(b/117085735): This form of clustering should be prevented. // TODO(b/117085735): This form of clustering should be prevented.
TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) { TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) {
// MatMulSource below creates data for nodes on GPU0 and GPU1 and is placed // MatMulSource below creates data for nodes on GPU0 and GPU1 and is placed
@ -1534,5 +1548,59 @@ TEST(XlaCompilationTest, DontClusterNodesWithForwardFromAttr) {
EXPECT_EQ(clusters["test/z"], ""); EXPECT_EQ(clusters["test/z"], "");
} }
// Note: this test relies on other implementation details to test the
// specific heuristic we care about here, so other changes might be at fault if
// this CL breaks. What we care about is that if a ShapeConsumingOp can be
// connected with a producer or consumer and cannot be clustered with both, it
// should be clustered with the producer.
TEST(XlaCompilationTest, ClusterShapeConsumerWithProducer) {
Scope root = Scope::NewRootScope().ExitOnError();
Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
Output x = ops::MatMul(root.WithOpName("test/x"), a, b);
Output y = ops::Size(root.WithOpName("test/y"), x);
Output z = ops::Add(root.WithOpName("test/z"), y, y);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
// Ensure that the "Size" op can only be clustered with either the producer or
// consumer by putting them on different devices.
FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kCPU0);
FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU1);
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
EXPECT_NE(clusters["test/y"], "");
EXPECT_EQ(clusters["test/x"], clusters["test/y"]);
EXPECT_NE(clusters["test/z"], clusters["test/y"]);
}
// Test that ShapeConsuming ops are still fully clustered whenever possible.
TEST(XlaCompilationTest, ClusterShapeConsumerWithProducerAndConsumer) {
Scope root = Scope::NewRootScope().ExitOnError();
Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
Output x = ops::MatMul(root.WithOpName("test/x"), a, b);
Output y = ops::Size(root.WithOpName("test/y"), x);
Output z = ops::Add(root.WithOpName("test/z"), y, y);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
EXPECT_NE(clusters["test/y"], "");
EXPECT_EQ(clusters["test/y"], clusters["test/x"]);
EXPECT_EQ(clusters["test/y"], clusters["test/z"]);
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -14,9 +14,11 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "absl/algorithm/container.h" #include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h" #include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@ -49,6 +51,15 @@ Status FindNodesToDecluster(const Graph& graph,
continue; continue;
} }
// Assume the benefit of not outputting a larger tensor outweighs the
// benefit of this check.
// TODO(tpopp): Only apply this if the value being consumed is not output
// from the cluster to another consumer.
// TODO(tpopp): See if XlaRun can be modified to avoid this issue
// completely.
if (IsShapeConsumerOp(*n)) {
continue;
}
// We assume the only XLA-auto-clusterable operations with side effects are // We assume the only XLA-auto-clusterable operations with side effects are
// resource variable updates. We can't execute these twice. // resource variable updates. We can't execute these twice.
if (HasResourceInputOrOutput(*n)) { if (HasResourceInputOrOutput(*n)) {
@ -57,7 +68,7 @@ Status FindNodesToDecluster(const Graph& graph,
DeviceType device_type(""); DeviceType device_type("");
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
DeviceToDeviceType(n->assigned_device_name(), &device_type)); DeviceNameToDeviceType(n->assigned_device_name(), &device_type));
TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type, TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type,
n->def(), &input_mtypes, n->def(), &input_mtypes,
&output_mtypes)); &output_mtypes));
@ -77,8 +88,8 @@ Status FindNodesToDecluster(const Graph& graph,
} else { } else {
MemoryTypeVector dst_input_mtypes, dst_output_mtypes; MemoryTypeVector dst_input_mtypes, dst_output_mtypes;
DeviceType dst_device_type(""); DeviceType dst_device_type("");
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(DeviceNameToDeviceType(dst->assigned_device_name(),
DeviceToDeviceType(dst->assigned_device_name(), &dst_device_type)); &dst_device_type));
TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type, TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type,
dst->def(), &dst_input_mtypes, dst->def(), &dst_input_mtypes,
&dst_output_mtypes)); &dst_output_mtypes));
@ -237,7 +248,7 @@ bool IsMustCompileDevice(const DeviceType& device_type) {
Status MustCompileNode(const Node* n, bool* must_compile) { Status MustCompileNode(const Node* n, bool* must_compile) {
DeviceType device_type(""); DeviceType device_type("");
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
DeviceToDeviceType(n->assigned_device_name(), &device_type)); DeviceNameToDeviceType(n->assigned_device_name(), &device_type));
if (IsMustCompileDevice(device_type)) { if (IsMustCompileDevice(device_type)) {
*must_compile = true; *must_compile = true;
@ -340,6 +351,40 @@ Status PartiallyDeclusterGraph(Graph* graph,
return Status::OK(); return Status::OK();
} }
} // namespace reduce_recompilation } // namespace reduce_recompilation
namespace decluster_root_shape_consumers {
Status PartiallyDeclusterGraph(Graph* graph) {
std::vector<Node*> reverse_post_order;
GetReversePostOrder(*graph, &reverse_post_order,
/*stable_comparator=*/NodeComparatorName(),
/*edge_filter=*/NotBackedge);
for (Node* n : reverse_post_order) {
if (!IsShapeConsumerOp(*n)) {
continue;
}
absl::optional<absl::string_view> cluster = GetXlaClusterForNode(*n);
if (!cluster.has_value()) {
continue;
}
auto input_belongs_to_same_cluster = [&](const Edge* e) {
return cluster == GetXlaClusterForNode(*e->src());
};
if (absl::c_any_of(n->in_edges(), input_belongs_to_same_cluster)) {
continue;
}
VLOG(2) << "Declustering " << n->name()
<< " because it is a root shape consumer";
RemoveFromXlaCluster(n);
}
return Status::OK();
}
} // namespace decluster_root_shape_consumers
} // namespace } // namespace
Status PartiallyDeclusterPass::Run( Status PartiallyDeclusterPass::Run(
@ -367,6 +412,9 @@ Status PartiallyDeclusterPass::Run(
TF_RETURN_IF_ERROR(reduce_recompilation::PartiallyDeclusterGraph( TF_RETURN_IF_ERROR(reduce_recompilation::PartiallyDeclusterGraph(
graph, options.flib_def, options.session_options->env)); graph, options.flib_def, options.session_options->env));
TF_RETURN_IF_ERROR(
decluster_root_shape_consumers::PartiallyDeclusterGraph(graph));
return Status::OK(); return Status::OK();
} }
} // namespace tensorflow } // namespace tensorflow

View File

@ -40,20 +40,20 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace { namespace {
REGISTER_OP("FakeNullary").Output("out: float"); REGISTER_OP("FakeNullary").Output("out: int32");
REGISTER_OP("FakeBinary") REGISTER_OP("FakeBinary")
.Input("host_in: float") .Input("host_in: int32")
.Input("device_in: float") .Input("device_in: int32")
.Output("host_out: float") .Output("host_out: int32")
.Output("device_out: float"); .Output("device_out: int32");
REGISTER_OP("FakeResourceVar").Output("out: resource"); REGISTER_OP("FakeResourceVar").Output("out: resource");
REGISTER_OP("FakeResourceUpdate") REGISTER_OP("FakeResourceUpdate")
.Input("in: resource") .Input("in: resource")
.Output("out: resource") .Output("out: resource")
.Output("something_else: float"); .Output("something_else: int32");
class FakeBinaryOp : public OpKernel { class FakeBinaryOp : public OpKernel {
public: public:
@ -467,5 +467,61 @@ TEST(PartiallyDeclusterPassTest, EliminatedUnusedNodes) {
EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer1Name), nullptr); EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer1Name), nullptr);
} }
TEST(PartiallyDeclusterPassTest, MetadataOpsDontStartClusters) {
tensorflow::Scope root = tensorflow::Scope::NewRootScope();
tensorflow::Scope in_cluster_and = root.WithXlaCluster("cluster_0");
Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
Output b = ops::Shape(in_cluster_and.WithOpName("b"), a);
Output c = ops::Rank(in_cluster_and.WithOpName("c"), b);
Output d = ops::Size(in_cluster_and.WithOpName("d"), c);
(void)ops::Shape(in_cluster_and.WithOpName("e"), d);
std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(PartiallyDecluster(&graph));
Node* n_b = FindNodeByName(*graph, "b");
ASSERT_NE(n_b, nullptr);
EXPECT_EQ(GetXlaClusterForNode(*n_b), absl::nullopt);
Node* n_c = FindNodeByName(*graph, "c");
ASSERT_NE(n_c, nullptr);
EXPECT_EQ(GetXlaClusterForNode(*n_c), absl::nullopt);
Node* n_d = FindNodeByName(*graph, "d");
ASSERT_NE(n_d, nullptr);
EXPECT_EQ(GetXlaClusterForNode(*n_d), absl::nullopt);
Node* n_e = FindNodeByName(*graph, "e");
ASSERT_NE(n_e, nullptr);
EXPECT_EQ(GetXlaClusterForNode(*n_e), absl::nullopt);
}
TEST(PartiallyDeclusterPassTest, MetaConsumersArentDeclustered) {
tensorflow::Scope root = tensorflow::Scope::NewRootScope();
tensorflow::Scope in_cluster_and = root.WithXlaCluster("cluster_0");
std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
Output b = ops::Add(in_cluster_and.WithOpName("b"), a, a);
Output c = ops::Rank(in_cluster_and.WithOpName("c"), b);
Output e;
TF_ASSERT_OK(
CreateOutputWithScope("FakeBinary", {c, c}, root.WithOpName("e"), &e));
TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(PartiallyDecluster(&graph));
Node* n_b = FindNodeByName(*graph, "b");
ASSERT_NE(n_b, nullptr);
EXPECT_EQ(GetXlaClusterForNode(*n_b), "cluster_0");
Node* n_c = FindNodeByName(*graph, "c");
ASSERT_NE(n_c, nullptr);
EXPECT_EQ(GetXlaClusterForNode(*n_c), "cluster_0");
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -0,0 +1,238 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
TEST(RearrangeFunctionArgumentForFunctionTest, Basic) {
FunctionDefLibrary fdl;
{
// Function for StatefulPartitionedCall's "f", If's
// "then_branch"/"else_branch".
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL)
// "ret0" = "arg1"
// "ret1" = "arg0"
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg1, 0);
auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg0, 1);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "f1", xla_fdef));
}
{
// Function for While's "body".
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL)
// "ret0" = "arg0"
// "ret1" = "arg1"
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg0, 0);
auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg1, 0);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "f2", xla_fdef));
}
{
// Function for While's "cond".
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL)
// "ret0" = "arg1"
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg1, 0);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "f3", xla_fdef));
}
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
// Build the XLA computation graph.
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_INT32)
// "arg0", "arg1" -> "if" (If) -> "ret0", "ret1"
// "arg0", "arg1" -> "while" (While) -> "ret2", "ret3"
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
NameAttrList f;
f.set_name("f1");
auto if_op = ops::If(s.WithOpName("if"), arg1,
std::initializer_list<Input>{arg0, arg1},
{DT_BOOL, DT_RESOURCE}, f, f);
auto ret0 = ops::_Retval(s.WithOpName("ret0"), if_op.output[0], 0);
auto ret1 = ops::_Retval(s.WithOpName("ret1"), if_op.output[1], 1);
NameAttrList cond_fn, body_fn;
cond_fn.set_name("f3");
body_fn.set_name("f2");
auto while_op =
ops::While(s.WithOpName("while"),
std::initializer_list<Input>{arg0, arg1}, cond_fn, body_fn);
auto ret2 = ops::_Retval(s.WithOpName("ret2"), while_op.output[0], 2);
auto ret3 = ops::_Retval(s.WithOpName("ret3"), while_op.output[1], 3);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
std::vector<std::unique_ptr<FunctionBody>> fbodies;
TF_CHECK_OK(RearrangeFunctionArguments(
[&](const NameAttrList &function, const FunctionBody **fbody) {
std::unique_ptr<FunctionBody> new_fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld.Find(function.name()),
AttrSlice(&function.attr()),
&fld, &new_fbody));
*fbody = new_fbody.get();
fbodies.push_back(std::move(new_fbody));
return Status::OK();
},
g.get(), &fld));
// Check function f1_rearrange_0, input types should be {DT_BOOL, DT_RESOURCE}
// and output types should be {DT_BOOL}.
const FunctionDef *f1_rewritten = fld.Find("f1_rearrange_0");
CHECK_NE(f1_rewritten, nullptr);
ASSERT_EQ(f1_rewritten->signature().input_arg_size(), 2);
EXPECT_EQ(f1_rewritten->signature().input_arg(0).type(), DT_BOOL);
EXPECT_EQ(f1_rewritten->signature().input_arg(1).type(), DT_RESOURCE);
ASSERT_EQ(f1_rewritten->signature().output_arg_size(), 1);
EXPECT_EQ(f1_rewritten->signature().output_arg(0).type(), DT_BOOL);
// Check node "if" input and output edges.
auto node_name_index = g->BuildNodeNameIndex();
const Node *if_node = node_name_index.at("if");
ASSERT_NE(if_node, nullptr);
const Node *input_node;
TF_CHECK_OK(if_node->input_node(1, &input_node));
EXPECT_EQ(input_node->name(), "arg1");
TF_CHECK_OK(if_node->input_node(2, &input_node));
EXPECT_EQ(input_node->name(), "arg0");
const Node *ret0_node = node_name_index.at("ret0");
ASSERT_NE(ret0_node, nullptr);
TF_CHECK_OK(ret0_node->input_node(0, &input_node));
EXPECT_EQ(input_node->name(), "if");
const Node *ret1_node = node_name_index.at("ret1");
ASSERT_NE(ret1_node, nullptr);
TF_CHECK_OK(ret1_node->input_node(0, &input_node));
EXPECT_EQ(input_node->name(), "arg0");
// Check node "while" input and output edges.
const Node *while_node = node_name_index.at("while");
ASSERT_NE(while_node, nullptr);
TF_CHECK_OK(while_node->input_node(0, &input_node));
EXPECT_EQ(input_node->name(), "arg1");
TF_CHECK_OK(while_node->input_node(1, &input_node));
EXPECT_EQ(input_node->name(), "arg0");
const Node *ret2_node = node_name_index.at("ret2");
ASSERT_NE(ret2_node, nullptr);
TF_CHECK_OK(ret2_node->input_node(0, &input_node));
EXPECT_EQ(input_node->name(), "arg0");
const Node *ret3_node = node_name_index.at("ret3");
ASSERT_NE(ret3_node, nullptr);
TF_CHECK_OK(ret3_node->input_node(0, &input_node));
EXPECT_EQ(input_node->name(), "while");
}
TEST(RearrangeFunctionArgumentForFunctionTest,
WhileResourceRetvalFromDifferentArgUnimplemented) {
FunctionDefLibrary fdl;
{
// Function for While's "body".
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
// "ret0" = "arg1"
// "ret1" = "arg0"
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg1, 0);
auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg0, 1);
auto ret2 = ops::_Retval(s.WithOpName("ret2"), arg2, 2);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "f2", xla_fdef));
}
{
// Function for While's "cond".
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
// "ret0" = true
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
Output cond = ops::Const(s.WithOpName("const"), true, TensorShape({}));
auto ret0 = ops::_Retval(s.WithOpName("ret0"), cond, 0);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "f1", xla_fdef));
}
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
// Build the XLA computation graph.
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
// "arg0", "arg1" -> "while" (While)
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
NameAttrList cond_fn, body_fn;
cond_fn.set_name("f1");
body_fn.set_name("f2");
auto while_op = ops::While(s.WithOpName("while"),
std::initializer_list<Input>{arg0, arg1, arg2},
cond_fn, body_fn);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
std::vector<std::unique_ptr<FunctionBody>> fbodies;
Status status = RearrangeFunctionArguments(
[&](const NameAttrList &function, const FunctionBody **fbody) {
std::unique_ptr<FunctionBody> new_fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld.Find(function.name()),
AttrSlice(&function.attr()),
&fld, &new_fbody));
*fbody = new_fbody.get();
fbodies.push_back(std::move(new_fbody));
return Status::OK();
},
g.get(), &fld);
EXPECT_EQ(status.code(), error::UNIMPLEMENTED);
}
} // namespace tensorflow

View File

@ -84,15 +84,6 @@ bool AlwaysForwardsRefInput(const Node& node) { return node.IsIdentity(); }
} // namespace } // namespace
Status DeviceToDeviceType(const string& device, DeviceType* device_type) {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
return errors::Internal("Malformed assigned device '", device, "'");
}
*device_type = DeviceType(parsed.type);
return Status::OK();
}
bool HasForwardedRefInput(const Node& node) { bool HasForwardedRefInput(const Node& node) {
if (AlwaysForwardsRefInput(node)) { if (AlwaysForwardsRefInput(node)) {
for (const Edge* incoming_edge : node.in_edges()) { for (const Edge* incoming_edge : node.in_edges()) {
@ -226,108 +217,6 @@ void RemoveFromXlaCluster(NodeDef* node_def) {
void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); } void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }
Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
bool allow_mixing_unknown_and_cpu,
bool* out_can_pick_device,
string* out_device_picked) {
if (out_can_pick_device) {
*out_can_pick_device = true;
}
#define FAILED_TO_PICK_DEVICE(failing_status) \
do { \
if (out_can_pick_device) { \
*out_can_pick_device = false; \
return Status::OK(); \
} else { \
return failing_status; \
} \
} while (false)
TF_RET_CHECK(!device_names.empty()) << "No devices to choose from";
DCHECK_NE(out_can_pick_device == nullptr, out_device_picked == nullptr);
absl::flat_hash_set<absl::string_view> device_names_set;
for (absl::string_view device_name : device_names) {
if (!device_name.empty()) {
device_names_set.insert(device_name);
}
}
absl::optional<absl::string_view> maybe_gpu_device;
absl::optional<absl::string_view> maybe_cpu_device;
absl::optional<absl::string_view> maybe_unknown_device;
for (absl::string_view device_name : device_names_set) {
DeviceNameUtils::ParsedName parsed_name;
TF_RET_CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_name))
<< device_name;
if (parsed_name.type == "GPU") {
if (maybe_gpu_device) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Multiple GPU devices ", absl::StrJoin(device_names, ", ")));
}
maybe_gpu_device = device_name;
} else if (parsed_name.type == "CPU") {
if (maybe_cpu_device) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Multiple CPU devices ", absl::StrJoin(device_names, ", ")));
}
maybe_cpu_device = device_name;
} else {
if (maybe_unknown_device) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Multiple unknown devices ", absl::StrJoin(device_names, ", ")));
}
maybe_unknown_device = device_name;
}
}
if (maybe_unknown_device && maybe_gpu_device) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Found both unknown and GPU devices: ", *maybe_unknown_device, ", ",
*maybe_gpu_device));
}
if (!allow_mixing_unknown_and_cpu) {
if (maybe_unknown_device && maybe_cpu_device) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Found both unknown and CPU devices: ", *maybe_unknown_device, ", ",
*maybe_cpu_device));
}
}
if (out_device_picked) {
if (maybe_gpu_device) {
*out_device_picked = string(*maybe_gpu_device);
} else if (maybe_unknown_device) {
*out_device_picked = string(*maybe_unknown_device);
} else {
*out_device_picked = string(*maybe_cpu_device);
}
}
return Status::OK();
#undef FAILED_TO_PICK_DEVICE
}
Status PickDeviceForXla(absl::Span<const string> device_names,
bool allow_mixing_unknown_and_cpu,
string* out_device_picked) {
return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu,
/*out_can_pick_device=*/nullptr,
out_device_picked);
}
Status CanPickDeviceForXla(absl::Span<const string> device_names,
bool allow_mixing_unknown_and_cpu,
bool* out_can_pick_device) {
return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu,
out_can_pick_device,
/*out_device_picked=*/nullptr);
}
namespace {
struct XlaGlobalJitLevel {
  OptimizerOptions::GlobalJitLevel single_gpu;
@@ -425,4 +314,8 @@ bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def) {
        return name_attr_pair.second.has_func();
      });
}
bool IsShapeConsumerOp(const Node& node) {
return node.type_string() == "Shape" || node.type_string() == "Rank" ||
node.type_string() == "Size";
}
}  // namespace tensorflow

View File

@@ -46,9 +46,6 @@ extern const char* const kXlaCompileTimeConstantInputsAttr;
using OrderedNodeSet = std::set<Node*, NodeComparatorID>;
// Returns the DeviceType corresponding to 'device'.
Status DeviceToDeviceType(const string& device, DeviceType* device_type);
// Returns true if `node` has a ref tensor input that it forwards to its output.
bool HasForwardedRefInput(const Node& node);
@@ -74,51 +71,6 @@ void RemoveFromXlaCluster(Node* node);
// Returns true if `node` has a DT_RESOURCE typed input or output.
bool HasResourceInputOrOutput(const Node& node);
// Picks the device for which XLA should compile a cluster that contains
// operations placed in devices in `device_names`. For instance a cluster that
// contains operations solely placed on the CPU will be compiled into a CPU
// executable by XLA, whereas a cluster that contains operations placed on the
// CPU and also operations placed on the GPU will be compiled into a GPU
// executable.
//
// Returns a non-OK Status if no unambiguous choice of device exists.
//
// We choose the device using the following rules:
//
// - It is an error for `device_names` to contain more than one device of the
// same type.
// - GPU is preferred over CPU.
// - If `allow_mixing_unknown_and_cpu` is true then unknown devices are
// preferred over CPU.
// - XLA devices count as "unrecognized devices".
//
// This set of rules above implicitly assumes that XLA:GPU can compile all
// operations in the cluster that XLA:CPU can compile, and if
// `allow_mixing_unknown_and_cpu` then the unrecognized device can also compile
// all operations in the cluster that XLA:CPU can compile.
//
// We provide the `allow_mixing_unknown_and_cpu` knob so that we can do both of
// the following things:
//
// - Let MarkForCompilationPass not inject CPU-placed operations into clusters
// that will run on unknown devices (because the unknown XLA backend may not
// support every operation supported by CPU).
// - Let BuildXlaOpsPass successfully infer a compilation device for a cluster
// that contains nodes placed on both the CPU and on unknown devices. In this
// case it is the responsibility of the optimization pass that injected the
// CPU nodes into the cluster to ensure that these nodes can be compiled by
// the unknown XLA backend.
Status PickDeviceForXla(absl::Span<const string> device_names,
bool allow_mixing_unknown_and_cpu,
string* out_device_picked);
// This is like `PickDeviceForXla` except that it returns false (instead of a
// non-OK Status) in `out_can_pick_device` if no unambiguous choice of device
// exists.
Status CanPickDeviceForXla(absl::Span<const string> device_names,
bool allow_mixing_unknown_and_cpu,
bool* out_can_pick_device);
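The comment block above spells out the device-picking rules behind PickDeviceForXla and CanPickDeviceForXla: at most one distinct device per type, GPU and unknown devices may never be mixed, unknown and CPU devices may be mixed only when `allow_mixing_unknown_and_cpu` is set, and the pick prefers GPU, then unknown, then CPU. As a rough standalone sketch of those rules, using only the standard library (PickDeviceSketch and DeviceKind are illustrative names, not the TensorFlow API; returning nullopt plays the role of CanPickDeviceForXla reporting that no unambiguous choice exists):

#include <optional>
#include <string>
#include <utility>
#include <vector>

enum class DeviceKind { kCpu, kGpu, kUnknown };

// Illustrative sketch only; nullopt means "no unambiguous choice exists".
std::optional<std::string> PickDeviceSketch(
    const std::vector<std::pair<std::string, DeviceKind>>& devices,
    bool allow_mixing_unknown_and_cpu) {
  std::optional<std::string> cpu, gpu, unknown;
  for (const auto& [name, kind] : devices) {
    std::optional<std::string>& slot =
        kind == DeviceKind::kGpu ? gpu
                                 : kind == DeviceKind::kCpu ? cpu : unknown;
    // More than one distinct device of the same type is ambiguous.
    if (slot && *slot != name) return std::nullopt;
    slot = name;
  }
  // Unknown (e.g. out-of-tree XLA) devices may never be mixed with GPU, and
  // may be mixed with CPU only when the caller explicitly allows it.
  if (unknown && gpu) return std::nullopt;
  if (unknown && cpu && !allow_mixing_unknown_and_cpu) return std::nullopt;
  // Preference order: GPU, then unknown, then CPU.
  if (gpu) return gpu;
  if (unknown) return unknown;
  return cpu;  // May itself be nullopt if the input was empty.
}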
// Determines the global jit level based on GraphOptimizationPassOptions,
// --tf_xla_auto_jit and whether the graph is a single GPU graph.
OptimizerOptions::GlobalJitLevel GetGlobalJitLevelForGraph(
@@ -131,6 +83,10 @@ bool IsSingleGpuGraph(const Graph& g);
// Returns true if it is possible (but not guaranteed) that `n` calls a
// function.
bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def);
// Returns true if `node` is an operator that consumes only the shape of its input,
// not the data itself.
bool IsShapeConsumerOp(const Node& node);
}  // namespace tensorflow
#endif  // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_

View File

@@ -91,67 +91,9 @@ TEST(CreateCycleDetectionGraph, ReachingEnterExit) {
  EXPECT_FALSE(ok);
}
void CheckPickDeviceResult(absl::string_view expected_result,
bool allow_mixing_unknown_and_cpu,
absl::Span<const absl::string_view> inputs) {
std::vector<string> inputs_string;
absl::c_transform(inputs, std::back_inserter(inputs_string),
[](absl::string_view sv) { return string(sv); });
string result;
TF_ASSERT_OK(
PickDeviceForXla(inputs_string, allow_mixing_unknown_and_cpu, &result))
<< "inputs = [" << absl::StrJoin(inputs, ", ")
<< "], allow_mixing_unknown_and_cpu=" << allow_mixing_unknown_and_cpu
<< ", expected_result=" << expected_result;
EXPECT_EQ(result, expected_result);
}
void CheckPickDeviceHasError(bool allow_mixing_unknown_and_cpu,
absl::Span<const absl::string_view> inputs) {
std::vector<string> inputs_string;
absl::c_transform(inputs, std::back_inserter(inputs_string),
[](absl::string_view sv) { return string(sv); });
string result;
EXPECT_FALSE(
PickDeviceForXla(inputs_string, allow_mixing_unknown_and_cpu, &result)
.ok());
}
const char* kCPU0 = "/job:localhost/replica:0/task:0/device:CPU:0";
const char* kGPU0 = "/job:localhost/replica:0/task:0/device:GPU:0";
const char* kXPU0 = "/job:localhost/replica:0/task:0/device:XPU:0";
const char* kCPU1 = "/job:localhost/replica:0/task:0/device:CPU:1";
const char* kGPU1 = "/job:localhost/replica:0/task:0/device:GPU:1";
const char* kXPU1 = "/job:localhost/replica:0/task:0/device:XPU:1";
TEST(PickDeviceForXla, UniqueDevice) {
CheckPickDeviceResult(kGPU0, false, {kGPU0, kGPU0});
}
TEST(PickDeviceForXla, DeviceOrder) {
CheckPickDeviceResult(kGPU0, false, {kGPU0, kCPU0});
CheckPickDeviceResult(kXPU0, true, {kXPU0, kCPU0});
}
TEST(PickDeviceForXla, MultipleUnknownDevices) {
CheckPickDeviceHasError(false, {kXPU0, kXPU1});
}
TEST(PickDeviceForXla, GpuAndUnknown) {
CheckPickDeviceHasError(false, {kGPU0, kXPU1});
}
TEST(PickDeviceForXla, UnknownAndCpu) {
CheckPickDeviceHasError(false, {kXPU0, kCPU1});
}
TEST(PickDeviceForXla, MultipleDevicesOfSameType) {
CheckPickDeviceHasError(false, {kCPU0, kCPU1});
CheckPickDeviceHasError(false, {kGPU0, kGPU1});
CheckPickDeviceHasError(false, {kXPU0, kXPU1});
CheckPickDeviceHasError(false, {kCPU0, kCPU1, kGPU0});
}
TEST(IsSingleGpuGraph, ReturnsTrue) {
  Scope root = Scope::NewRootScope().WithAssignedDevice(kGPU0).ExitOnError();

View File

@@ -60,6 +60,7 @@ Status XlaCpuDeviceFactory::CreateDevices(
  registration.cluster_control_trigger = true;
  registration.elide_assert_and_checknumerics = true;
  registration.cluster_variant_ops = true;
registration.cluster_slow_and_inaccurate_ops = true;
  XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration);
  static XlaDeviceOpRegistrations* registrations =

View File

@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device.h"
#include <stdlib.h>
#include <unordered_set>
#include "absl/memory/memory.h"
@@ -47,6 +48,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h" #include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/device_name_utils.h"
@@ -380,14 +382,17 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                             AsyncOpKernel::DoneCallback done) {
  VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
          << op_kernel->type_string();
  tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(),
                                   op_kernel->IsExpensive());
  profiler::TraceMe activity(
      [&] {
        return absl::StrCat(op_kernel->name(), ":", op_kernel->type_string());
      },
      profiler::GetTFTraceMeLevel(op_kernel->IsExpensive()));
  op_kernel->ComputeAsync(context, done);
}
Status XlaDevice::Sync() {
  VLOG(1) << "XlaDevice::Sync";
  tracing::ScopedActivity activity("XlaDevice::Sync", /*is_expensive=*/true);
  profiler::TraceMe activity("XlaDevice::Sync", profiler::TraceMeLevel::kInfo);
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
@@ -428,13 +433,12 @@ void XlaDevice::Sync(const DoneCallback& done) {
  // that everything enqueued onto the stream (i.e., the device) at this very
  // moment--when ThenEnqueueOnBackgroundThread is called--will have finished.
  // This achieves a device-wide sync.
  stream->ThenEnqueueOnBackgroundThread(
      [stream, done](se::StreamExecutor*) {
        tracing::ScopedActivity activity("XlaDevice::Sync::Callback",
                                         /*is_expensive=*/true);
        done(stream->ok() ? Status::OK()
                          : errors::Internal("XlaDevice::Sync() failed."));
      });
  stream->ThenEnqueueOnBackgroundThread([stream, done](se::StreamExecutor*) {
    profiler::TraceMe activity("XlaDevice::Sync::Callback",
                               profiler::TraceMeLevel::kInfo);
    done(stream->ok() ? Status::OK()
                      : errors::Internal("XlaDevice::Sync() failed."));
  });
}
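Both hunks above swap tracing::ScopedActivity for profiler::TraceMe, with the trace name supplied either as a literal or via a lambda so the name string is only built when the trace level is actually active. A simplified standalone illustration of that deferred-name pattern (ScopedTrace and TracingActiveAtLevel are stand-ins, not the real profiler API):

#include <iostream>
#include <string>

// Illustrative only: stands in for the profiler's "is this level active?" check.
inline bool TracingActiveAtLevel(int level) { return level <= 1; }

// Simplified stand-in for profiler::TraceMe: the name comes from a callable,
// so the string concatenation is skipped when tracing is filtered out.
class ScopedTrace {
 public:
  template <typename NameGenerator>
  ScopedTrace(NameGenerator&& gen, int level) {
    if (TracingActiveAtLevel(level)) {
      name_ = gen();
      std::cout << "begin " << name_ << "\n";
      active_ = true;
    }
  }
  ~ScopedTrace() {
    if (active_) std::cout << "end " << name_ << "\n";
  }

 private:
  bool active_ = false;
  std::string name_;
};

int main() {
  // Mirrors the XlaDevice::ComputeAsync change: the name concatenation lives
  // inside the lambda and never runs when the level is filtered out.
  ScopedTrace activity([&] { return std::string("MyKernel:MatMul"); },
                       /*level=*/2);
}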
Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
@@ -458,11 +462,13 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
    Allocator* allocator = GetAllocatorLocked(alloc_attrs);
    Tensor copy(allocator, parsed.dtype(), parsed.shape());
    Notification n;
    device_context->CopyCPUTensorToDevice(&parsed, this, &copy,
                                          [&n, &status](const Status& s) {
                                            status = s;
                                            n.Notify();
                                          });
    device_context->CopyCPUTensorToDevice(
        &parsed, this, &copy,
        [&n, &status](const Status& s) {
          status = s;
          n.Notify();
        },
        true /*sync_dst_compute*/);
    n.WaitForNotification();
    *tensor = copy;
  }

View File

@@ -65,6 +65,9 @@ absl::optional<AllocatorStats> XlaDeviceAllocator::GetStats() {
  tf_stats.peak_bytes_in_use = se_stats->peak_bytes_in_use;
  tf_stats.largest_alloc_size = se_stats->largest_alloc_size;
  tf_stats.bytes_limit = se_stats->bytes_limit;
tf_stats.bytes_reserved = se_stats->bytes_reserved;
tf_stats.peak_bytes_reserved = se_stats->peak_bytes_reserved;
tf_stats.bytes_reservable_limit = se_stats->bytes_reservable_limit;
  return tf_stats;
}
@@ -106,7 +109,8 @@ void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor,
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
                                             Device* device,
                                             Tensor* device_tensor,
                                             StatusCallback done) const {
                                             StatusCallback done,
                                             bool sync_dst_compute) const {
  if (cpu_tensor->NumElements() == 0) {
    VLOG(2) << "CopyCPUTensorToDevice empty tensor";
    done(Status::OK());
@@ -242,16 +246,25 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
      cpu_tensor, &literal));
  TensorReference ref(*device_tensor);
const bool device_allows_sync_on_completion =
device->AllowsSyncOnCompletion();
  // Explicitly capture device_to_host_stream to make sure the stream is alive
  // before the transfer finishes.
  transfer_manager_->TransferLiteralFromDevice(
      device_to_host_stream.get(), xla_tensor->shaped_buffer(), literal,
      [ref, xla_tensor, done, device_to_host_stream](xla::Status status) {
        done([&]() -> Status {
          VLOG(2) << "Transfer from device as literal: "
                  << xla_tensor->shaped_buffer().ToString();
          return status;
        }());
      [ref, xla_tensor, done, device_to_host_stream,
       device_allows_sync_on_completion](xla::Status status) {
        Status done_status = status;
        VLOG(2) << "Transfer from device as literal: "
                << xla_tensor->shaped_buffer().ToString();
        // For devices that don't allow sync on completion, the device
        // execution is deferred. We check the execution stream status here to
        // avoid wrong results from a failed stream being propagated to
        // following host-side ops.
        if (!device_allows_sync_on_completion) {
          done_status.Update(xla_tensor->RefreshStatusOfStreams());
        }
        done(done_status);
        ref.Unref();
      });
}

View File

@@ -61,8 +61,8 @@ class XlaDeviceContext : public DeviceContext {
                   thread::ThreadPool* thread_pool);
  void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
                             Tensor* device_tensor,
                             StatusCallback done) const override;
                             Tensor* device_tensor, StatusCallback done,
                             bool sync_dst_compute) const override;
  void CopyDeviceTensorToCPU(const Tensor* device_tensor,
                             absl::string_view tensor_name, Device* device,
                             Tensor* cpu_tensor, StatusCallback done) override;

View File

@@ -95,6 +95,7 @@ Status XlaGpuDeviceFactory::CreateDevices(
  registration.cluster_control_trigger = true;
  registration.elide_assert_and_checknumerics = true;
  registration.cluster_variant_ops = true;
registration.cluster_slow_and_inaccurate_ops = true;
  XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);
  static XlaDeviceOpRegistrations* registrations =

View File

@@ -63,6 +63,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
  registration.cluster_control_trigger = true;
  registration.elide_assert_and_checknumerics = true;
  registration.cluster_variant_ops = true;
registration.cluster_slow_and_inaccurate_ops = true;
  XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER,
                                           registration);

View File

@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/stream_executor_util.h" #include "tensorflow/core/util/stream_executor_util.h"
namespace tensorflow { namespace tensorflow {
@ -132,7 +133,8 @@ Status LockVariables(absl::Span<VariableInfo> variables) {
// cluster because we would not handle variable updates correctly. Any // cluster because we would not handle variable updates correctly. Any
// locks we have already acquired will be released when the VariableInfo // locks we have already acquired will be released when the VariableInfo
// objects are destroyed. // objects are destroyed.
return errors::Internal("Duplicate variable passed to XLA cluster"); // TODO(b/128495870) Add support for passing aliased resource variables.
return errors::Unimplemented("Duplicate variable passed to XLA cluster");
    }
    VLOG(4) << "Acquiring lock for variable "
            << reinterpret_cast<void*>(variable);
@@ -166,11 +168,11 @@ Status SnapshotResourceVariables(OpKernelContext* ctx,
}
XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped)
    : xla::DeviceMemoryAllocator(platform), wrapped_(wrapped) {}
    : se::DeviceMemoryAllocator(platform), wrapped_(wrapped) {}
XlaAllocator::~XlaAllocator() {}
xla::StatusOr<xla::OwningDeviceMemory> XlaAllocator::Allocate(
xla::StatusOr<se::OwningDeviceMemory> XlaAllocator::Allocate(
    int device_ordinal, uint64 size, bool retry_on_failure) {
  AllocationAttributes attrs;
  attrs.no_retry_on_failure = !retry_on_failure;
@@ -182,8 +184,8 @@ xla::StatusOr<xla::OwningDeviceMemory> XlaAllocator::Allocate(
          "Out of memory while trying to allocate ", size, " bytes.");
    }
  }
  return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
  return se::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
                                device_ordinal, this);
}
Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
@@ -192,7 +194,7 @@ Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
}
XlaComputationLaunchContext::XlaComputationLaunchContext(
    xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
    xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
    bool allocate_xla_tensors, bool use_multiple_streams)
    : client_(client),
      xla_allocator_(xla_allocator),
@@ -242,7 +244,8 @@ void XlaComputationLaunchContext::PopulateInputs(
      CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
      arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
    } else {
      CHECK(xla::ShapeUtil::Equal(shape, on_device_shape))
      CHECK(xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape,
                                                           on_device_shape))
          << "On-device shape "
          << xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
          << " not the same as on-host shape "
@@ -371,7 +374,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
    } else {
      Tensor output_tensor = XlaTensorBuffer::MakeTensor(
          ctx->expected_output_dtype(i), shape, buffer, allocator);
      output.set_buffer(xla::OwningDeviceMemory(), {output_num});
      output.set_buffer(se::OwningDeviceMemory(), {output_num});
      ctx->set_output(i, output_tensor);
    }
    ++output_num;
@@ -432,7 +435,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
      *variable_infos[i].var()->tensor() = output_tensor;
    } else {
      se::DeviceMemoryBase buffer = output.buffer({output_num});
      output.set_buffer(xla::OwningDeviceMemory(), {output_num});
      output.set_buffer(se::OwningDeviceMemory(), {output_num});
      Tensor output_tensor = XlaTensorBuffer::MakeTensor(
          write.type, write.shape, buffer, allocator);
      *variable_infos[i].var()->tensor() = output_tensor;

View File

@@ -23,14 +23,13 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/owning_device_memory.h"
#include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/resource_var.h" #include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
namespace tensorflow {
class XlaAllocator;
@@ -108,11 +107,11 @@ Status LockVariables(absl::Span<VariableInfo> variables)
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
// Assumes that the Tensorflow allocator permits asynchronous deallocation:
// see comment on `AllowsAsynchronousDeallocation()`.
class XlaAllocator : public xla::DeviceMemoryAllocator {
class XlaAllocator : public se::DeviceMemoryAllocator {
 public:
  XlaAllocator(const se::Platform* platform, Allocator* wrapped);
  ~XlaAllocator() override;
  xla::StatusOr<xla::OwningDeviceMemory> Allocate(
  xla::StatusOr<se::OwningDeviceMemory> Allocate(
      int device_ordinal, uint64 size, bool retry_on_failure) override;
  Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override;
@@ -129,6 +128,50 @@ class XlaAllocator : public xla::DeviceMemoryAllocator {
  Allocator* wrapped_;
};
// Adapter class that wraps per-device TF allocators as an XLA allocator.
// Assumes that the Tensorflow allocator permits asynchronous deallocation;
// see comment on `AllowsAsynchronousDeallocation()`.
class MultiDeviceAdapter : public se::DeviceMemoryAllocator {
public:
MultiDeviceAdapter(
const se::Platform* platform,
std::vector<std::unique_ptr<tensorflow::Allocator>> tf_allocators)
: DeviceMemoryAllocator(platform),
tf_allocators_(std::move(tf_allocators)) {
for (const auto& tf_allocator : tf_allocators_) {
per_device_allocators_.emplace_back(platform, tf_allocator.get());
}
}
xla::StatusOr<se::OwningDeviceMemory> Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) override {
CHECK_LT(device_ordinal, per_device_allocators_.size());
return per_device_allocators_[device_ordinal].Allocate(device_ordinal, size,
retry_on_failure);
}
Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override {
CHECK_LT(device_ordinal, per_device_allocators_.size());
return per_device_allocators_[device_ordinal].Deallocate(device_ordinal,
mem);
}
// The Tensorflow BFC allocator used on GPU allows host-side deallocation
// before GPU execution takes place. Tensorflow uses the ordering of the main
// compute stream to enforce a happens-before relationship between a memory
// allocation and code that reuses the same memory. If Tensorflow adds
// support for multiple GPU streams or allocators with different ordering
// requirements, this code may need to change.
// (This attribute has no effect on CPU.)
bool AllowsAsynchronousDeallocation() const override { return true; }
private:
std::vector<tensorflow::XlaAllocator> per_device_allocators_;
// The wrapped TF allocators backing per_device_allocators_ (XlaAllocator does
// not take ownership of its underlying Allocator).
std::vector<std::unique_ptr<tensorflow::Allocator>> tf_allocators_;
};
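MultiDeviceAdapter above keeps one wrapped TF allocator per device ordinal and forwards every Allocate/Deallocate call to the allocator for that ordinal; the comment about BFC-allocator stream ordering is what justifies AllowsAsynchronousDeallocation() returning true. A toy standalone sketch of the same dispatch-by-ordinal shape (ToyAllocator, ToyHeapAllocator, and ToyMultiDeviceAdapter are illustrative stand-ins, not TensorFlow types):

#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Toy allocator interface standing in for the wrapped per-device allocators.
class ToyAllocator {
 public:
  virtual ~ToyAllocator() = default;
  virtual void* Allocate(size_t bytes) = 0;
  virtual void Deallocate(void* ptr) = 0;
};

class ToyHeapAllocator : public ToyAllocator {
 public:
  void* Allocate(size_t bytes) override { return ::operator new(bytes); }
  void Deallocate(void* ptr) override { ::operator delete(ptr); }
};

// Same shape as MultiDeviceAdapter: one allocator per ordinal, each call
// dispatched by ordinal.
class ToyMultiDeviceAdapter {
 public:
  explicit ToyMultiDeviceAdapter(
      std::vector<std::unique_ptr<ToyAllocator>> per_device)
      : per_device_(std::move(per_device)) {}

  void* Allocate(int device_ordinal, size_t bytes) {
    assert(device_ordinal >= 0 &&
           static_cast<size_t>(device_ordinal) < per_device_.size());
    return per_device_[device_ordinal]->Allocate(bytes);
  }
  void Deallocate(int device_ordinal, void* ptr) {
    assert(device_ordinal >= 0 &&
           static_cast<size_t>(device_ordinal) < per_device_.size());
    per_device_[device_ordinal]->Deallocate(ptr);
  }

 private:
  std::vector<std::unique_ptr<ToyAllocator>> per_device_;
};

int main() {
  std::vector<std::unique_ptr<ToyAllocator>> allocs;
  allocs.push_back(std::make_unique<ToyHeapAllocator>());  // ordinal 0
  allocs.push_back(std::make_unique<ToyHeapAllocator>());  // ordinal 1
  ToyMultiDeviceAdapter adapter(std::move(allocs));
  void* p = adapter.Allocate(/*device_ordinal=*/1, 128);
  adapter.Deallocate(/*device_ordinal=*/1, p);
}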
// Helper class to perform the marshalling of TensorFlow inputs and outputs to
// ShapedBuffers suitable for passing to an XLA computation.
class XlaComputationLaunchContext {
@@ -142,7 +185,7 @@ class XlaComputationLaunchContext {
  // because we track inter-stream dependencies through events inside XlaTensor
  // objects.
  XlaComputationLaunchContext(xla::LocalClient* client,
                              xla::DeviceMemoryAllocator* xla_allocator,
                              se::DeviceMemoryAllocator* xla_allocator,
                              bool allocate_xla_tensors,
                              bool use_multiple_streams);
@@ -186,7 +229,7 @@ class XlaComputationLaunchContext {
 private:
  xla::LocalClient* client_;
  xla::DeviceMemoryAllocator* xla_allocator_;
  se::DeviceMemoryAllocator* xla_allocator_;
  bool allocate_xla_tensors_;
  bool use_multiple_streams_;
  std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_;

View File

@@ -59,11 +59,11 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype,
        xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
    uint64 size =
        client->backend().transfer_manager()->GetByteSizeRequirement(subshape);
    TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer,
    TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory buffer,
                        client->backend().memory_allocator()->Allocate(
                            device_ordinal, size, /*retry_on_failure=*/false));
    // Move our buffer into shaped_buffer, which takes ownership of it.
    index_to_buffer.second = buffer.Forget();
    index_to_buffer.second = buffer.Release();
  }
  VLOG(4) << shaped_buffer.ToString();
@@ -97,6 +97,15 @@ void XlaTensor::ResetDefinitionEvent(std::shared_ptr<se::Event> event,
  streams_defined_on_ = {stream};
}
Status XlaTensor::RefreshStatusOfStreams() {
mutex_lock lock(mu_);
Status status;
for (se::Stream* stream : streams_defined_on_) {
status.Update(stream->RefreshStatus());
}
return status;
}
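RefreshStatusOfStreams above walks every stream the tensor was defined on and folds their statuses together with Status::Update, so the first failure wins; this is what CopyDeviceTensorToCPU relies on for devices that defer execution instead of syncing on completion. A tiny standalone sketch of that first-error-wins accumulation (ToyStatus is an illustrative stand-in, not tensorflow::Status):

#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Keeps only the first error it is updated with, mirroring Status::Update.
class ToyStatus {
 public:
  ToyStatus() = default;
  explicit ToyStatus(std::string error) : error_(std::move(error)) {}
  bool ok() const { return error_.empty(); }
  void Update(const ToyStatus& other) {
    if (ok() && !other.ok()) error_ = other.error_;  // first error wins
  }
  const std::string& error() const { return error_; }

 private:
  std::string error_;
};

int main() {
  // Streams 0 and 2 are healthy, stream 1 failed; the accumulated status
  // reports the first failure, as RefreshStatusOfStreams does.
  std::vector<ToyStatus> stream_statuses = {
      ToyStatus(), ToyStatus("stream 1 failed"), ToyStatus()};
  ToyStatus accumulated;
  for (const ToyStatus& s : stream_statuses) accumulated.Update(s);
  std::cout << (accumulated.ok() ? "OK" : accumulated.error()) << "\n";
}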
// The pointer tag, OR-ed into the XlaTensor's address to distinguish it from
// device-side tensors, which are either CPU or GPU memory pointers. This works
// because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits.

View File

@@ -102,6 +102,10 @@ class XlaTensor {
  void ResetDefinitionEvent(std::shared_ptr<se::Event> event,
                            se::Stream* stream);
// Refresh the status of streams_defined_on_. Return the first not-OK stream's
// status or OK.
Status RefreshStatusOfStreams();
  // Convert from a raw pointer to an XlaTensor, removing the pointer tag.
  static XlaTensor* FromOpaquePointer(void* ptr);
  // Convert to a raw pointer from an XlaTensor, adding the pointer tag.

View File

@@ -65,6 +65,7 @@ py_test(
    name = "xla_test_test",
    size = "small",
    srcs = ["xla_test_test.py"],
python_version = "PY2",
    deps = [
        ":xla_test",
    ],
@@ -458,10 +459,6 @@ tf_xla_py_test(
    name = "extract_image_patches_op_test",
    size = "small",
    srcs = ["extract_image_patches_op_test.py"],
tags = [
"manual",
"notap",
],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",

View File

@@ -41,7 +41,7 @@ class AdadeltaOptimizerTest(xla_test.XLATestCase):
    all_lr = [1.0, 0.5, 0.1]
    for dtype in self.float_types:
      with self.cached_session(), self.test_scope():
      with self.session(), self.test_scope():
        for grad in all_grad:
          for lr in all_lr:
            var0_init = [1.0, 2.0]

View File

@@ -33,7 +33,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
  def testAdagradDAWithoutRegularizationBasic1(self):
    for dtype in self.float_types:
      with self.cached_session(), self.test_scope():
      with self.session(), self.test_scope():
        global_step = resource_variable_ops.ResourceVariable(
            0, dtype=dtypes.int64)
        var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
@@ -69,7 +69,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
  def testAdagradDAwithoutRegularizationBasic2(self):
    for dtype in self.float_types:
      with self.cached_session(), self.test_scope():
      with self.session(), self.test_scope():
        global_step = resource_variable_ops.ResourceVariable(
            0, dtype=dtypes.int64)
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
@@ -100,7 +100,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
  def testAdagradDAWithL1(self):
    for dtype in self.float_types:
      with self.cached_session(), self.test_scope():
      with self.session(), self.test_scope():
        global_step = resource_variable_ops.ResourceVariable(
            0, dtype=dtypes.int64)
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
@@ -131,7 +131,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
  def testAdagradDAWithL1_L2(self):
    for dtype in self.float_types:
      with self.cached_session(), self.test_scope():
      with self.session(), self.test_scope():
        global_step = resource_variable_ops.ResourceVariable(
            0, dtype=dtypes.int64)
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)

View File

@@ -32,7 +32,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
  def testBasic(self):
    for dtype in self.float_types:
      with self.cached_session(), self.test_scope():
      with self.session(), self.test_scope():
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
        var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -59,7 +59,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
  def testTensorLearningRate(self):
    for dtype in self.float_types:
      with self.cached_session(), self.test_scope():
      with self.session(), self.test_scope():
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
        var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -87,7 +87,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
  def testSharing(self):
    for dtype in self.float_types:
      with self.cached_session(), self.test_scope():
      with self.session(), self.test_scope():
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
        var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)

View File

@@ -56,7 +56,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
      # TODO: test fails for float16 due to excessive precision requirements.
      if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
        continue
      with self.cached_session(), self.test_scope():
      with self.session(), self.test_scope():
        variable_scope.get_variable_scope().set_use_resource(True)
        # Initialize variables for numpy implementation.
@@ -99,7 +99,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
      # TODO: test fails for float16 due to excessive precision requirements.
      if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
        continue
      with self.cached_session(), self.test_scope():
      with self.session(), self.test_scope():
        variable_scope.get_variable_scope().set_use_resource(True)
        # Initialize variables for numpy implementation.
@@ -142,7 +142,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
      # TODO: test fails for float16 due to excessive precision requirements.
      if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
        continue
      with self.cached_session(), self.test_scope():
      with self.session(), self.test_scope():
        variable_scope.get_variable_scope().set_use_resource(True)
        # Initialize variables for numpy implementation.

Some files were not shown because too many files have changed in this diff.