Merge branch 'master' into master
This commit is contained in:
commit
2eb90433f9
@ -18,10 +18,11 @@ about: Use this template for reporting a bug or a performance issue.
|
||||
- CUDA/cuDNN version:
|
||||
- GPU model and memory:
|
||||
|
||||
|
||||
You can collect some of this information using our environment capture [script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
|
||||
You can also obtain the TensorFlow version with
|
||||
python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"
|
||||
You can collect some of this information using our environment capture
|
||||
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
|
||||
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
|
||||
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
|
||||
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
|
||||
|
||||
**Describe the current behavior**
|
||||
|
||||
|
52
.github/ISSUE_TEMPLATE/20-documentation-issue.md
vendored
52
.github/ISSUE_TEMPLATE/20-documentation-issue.md
vendored
@ -1,17 +1,55 @@
|
||||
---
|
||||
name: Documentation Issue
|
||||
about: Use this template for documentation related issues
|
||||
about: Use this template for documentation related
|
||||
labels: 'type:docs'
|
||||
|
||||
---
|
||||
|
||||
<em>Please make sure that this is a documentation issue. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:doc_template</em>
|
||||
Thank you for submitting a TensorFlow documentation issue. Per our GitHub
|
||||
policy, we only address code/doc bugs, performance issues, feature requests, and
|
||||
build/installation issues on GitHub.
|
||||
|
||||
The TensorFlow docs are open source! To get involved, read the documentation
|
||||
contributor guide: https://www.tensorflow.org/community/contribute/docs
|
||||
|
||||
**System information**
|
||||
- TensorFlow version:
|
||||
- Doc Link:
|
||||
## URL(s) with the issue:
|
||||
|
||||
Please provide a link to the documentation entry, for example:
|
||||
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/MyMethod
|
||||
|
||||
**Describe the documentation issue**
|
||||
## Description of issue (what needs changing):
|
||||
|
||||
**We welcome contributions by users. Will you be able to update submit a PR (use the [doc style guide](https://www.tensorflow.org/community/documentation)) to fix the doc Issue?**
|
||||
### Clear description
|
||||
|
||||
For example, why should someone use this method? How is it useful?
|
||||
|
||||
### Correct links
|
||||
|
||||
Is the link to the source code correct?
|
||||
|
||||
### Parameters defined
|
||||
|
||||
Are all parameters defined and formatted correctly?
|
||||
|
||||
### Returns defined
|
||||
|
||||
Are return values defined?
|
||||
|
||||
### Raises listed and defined
|
||||
|
||||
Are the errors defined? For example,
|
||||
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/feature_column/categorical_column_with_vocabulary_file#raises
|
||||
|
||||
### Usage example
|
||||
|
||||
Is there a usage example?
|
||||
|
||||
### Request visuals, if applicable
|
||||
|
||||
Are there currently visuals? If not, will it clarify the content?
|
||||
|
||||
### Submit a pull request?
|
||||
|
||||
Are you planning to also submit a pull request to fix the issue? See the docs
|
||||
contributor guide: https://www.tensorflow.org/community/contribute/docs and the
|
||||
docs style guide: https://www.tensorflow.org/community/contribute/docs_style
|
||||
|
21
.gitignore
vendored
21
.gitignore
vendored
@ -20,18 +20,8 @@ tensorflow/contrib/cmake/_build/
|
||||
[Bb]uild/
|
||||
/tensorflow/core/util/version_info.cc
|
||||
/tensorflow/python/framework/fast_tensor_util.cpp
|
||||
Pods
|
||||
Podfile.lock
|
||||
*.pbxproj
|
||||
*.xcworkspacedata
|
||||
/*.podspec
|
||||
/tensorflow/lite/experimental/objc/BUILD
|
||||
/tensorflow/lite/experimental/swift/BUILD
|
||||
/tensorflow/lite/examples/ios/simple/data/*.txt
|
||||
/tensorflow/lite/examples/ios/simple/data/*.tflite
|
||||
/tensorflow/lite/gen/**
|
||||
/tensorflow/lite/tools/make/downloads/**
|
||||
xcuserdata/**
|
||||
/api_init_files_list.txt
|
||||
/estimator_api_init_files_list.txt
|
||||
*.whl
|
||||
@ -42,3 +32,14 @@ xcuserdata/**
|
||||
*.iml
|
||||
local.properties
|
||||
gradleBuild
|
||||
|
||||
# iOS
|
||||
*.pbxproj
|
||||
*.xcworkspace
|
||||
/*.podspec
|
||||
/tensorflow/lite/**/[ios|objc|swift]*/BUILD
|
||||
/tensorflow/lite/examples/ios/simple/data/*.tflite
|
||||
/tensorflow/lite/examples/ios/simple/data/*.txt
|
||||
Podfile.lock
|
||||
Pods
|
||||
xcuserdata
|
||||
|
21
README.md
21
README.md
@ -85,7 +85,7 @@ guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's
|
||||
uphold this code.**
|
||||
|
||||
**We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for
|
||||
tracking requests and bugs, so please see
|
||||
tracking requests and bugs, please see
|
||||
[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss)
|
||||
for general questions and discussion, and please direct specific questions to
|
||||
[Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).**
|
||||
@ -114,15 +114,16 @@ The TensorFlow project strives to abide by generally accepted best practices in
|
||||
|
||||
### Community Supported Builds
|
||||
|
||||
Build Type | Status | Artifacts
|
||||
-------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
|
||||
**IBM s390x** | [](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA
|
||||
**Linux ppc64le CPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
|
||||
**Linux ppc64le CPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/)
|
||||
**Linux ppc64le GPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
|
||||
**Linux ppc64le GPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/)
|
||||
**Linux CPU with Intel® MKL-DNN** Nightly | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
|
||||
**Linux CPU with Intel® MKL-DNN** <br> **Supports Python 2.7, 3.4, 3.5 and 3.6** | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.13.1 pypi](https://pypi.org/project/intel-tensorflow/)
|
||||
Build Type | Status | Artifacts
|
||||
--------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
|
||||
**IBM s390x** | [](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA
|
||||
**Linux ppc64le CPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
|
||||
**Linux ppc64le CPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/)
|
||||
**Linux ppc64le GPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
|
||||
**Linux ppc64le GPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/)
|
||||
**Linux CPU with Intel® MKL-DNN** Nightly | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
|
||||
**Linux CPU with Intel® MKL-DNN** <br> **Supports Python 2.7, 3.4, 3.5, and 3.6** | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.13.1 pypi](https://pypi.org/project/intel-tensorflow/)
|
||||
**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br> Python 2.7, 3.6 | [](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 pypi](https://tensorflow.pypi.thoth-station.ninja/index/)
|
||||
|
||||
## For more information
|
||||
|
||||
|
278
RELEASE.md
278
RELEASE.md
@ -1,3 +1,10 @@
|
||||
# Release 1.12.2
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
* Fixes a potential security vulnerability where carefully crafted GIF images
|
||||
can produce a null pointer dereference during decoding.
|
||||
|
||||
# Release 1.13.0
|
||||
|
||||
## Major Features and Improvements
|
||||
@ -14,98 +21,185 @@
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
* Documentation
|
||||
* Update the doc with the details about the rounding mode used in quantize_and_dequantize_v2.
|
||||
* Clarify that tensorflow::port::InitMain() _should_ be called before using the TensorFlow library. Programs failing to do this are not portable to all platforms.
|
||||
* Deprecations and Symbol renames.
|
||||
* Removing deprecations for the following endpoints: `tf.acos`, `tf.acosh`, `tf.add`, `tf.as_string`, `tf.asin`, `tf.asinh`, `tf.atan`, `tf.atan2`, `tf.atanh`, `tf.cos`, `tf.cosh`, `tf.equal`, `tf.exp`, `tf.floor`, `tf.greater`, `tf.greater_equal`, `tf.less`, `tf.less_equal`, `tf.log`, `tf.logp1`, `tf.logical_and`, `tf.logical_not`, `tf.logical_or`, `tf.maximum`, `tf.minimum`, `tf.not_equal`, `tf.sin`, `tf.sinh`, `tf.tan`
|
||||
* Deprecate `tf.data.Dataset.shard`.
|
||||
* Deprecate `saved_model.loader.load` which is replaced by `saved_model.load` and `saved_model.main_op`, which will be replaced by `saved_model.main_op` in V2.
|
||||
* Deprecate tf.QUANTIZED_DTYPES. The official new symbol is tf.dtypes.QUANTIZED_DTYPES.
|
||||
* Update sklearn imports for deprecated packages.
|
||||
* Deprecate `Variable.count_up_to` and `tf.count_up_to` in favor of `Dataset.range`.
|
||||
* Export `confusion_matrix` op as `tf.math.confusion_matrix` instead of `tf.train.confusion_matrix`.
|
||||
* Add `tf.dtypes.` endpoint for every constant in dtypes.py; moving endpoints in versions.py to corresponding endpoints in `tf.sysconfig.` and `tf.version.`; moving all constants under `tf.saved_model` submodules to `tf.saved_model` module. New endpoints are added in V1 and V2 but existing endpoint removals are only applied in V2.
|
||||
* Deprecates behavior where device assignment overrides collocation constraints inside a collocation context manager.
|
||||
* Keras & Python API
|
||||
* Add to Keras functionality analogous to `tf.register_tensor_conversion_function`.
|
||||
* Subclassed Keras models can now be saved through `tf.contrib.saved_model.save_keras_model`.
|
||||
* `LinearOperator.matmul` now returns a new `LinearOperator`.
|
||||
* New ops and improved op functionality
|
||||
* Add a Nearest Neighbor Resize op.
|
||||
* Add an `ignore_unknown` argument to `parse_values` which suppresses ValueError for unknown hyperparameter types. Such * Add `tf.linalg.matvec` convenience function.
|
||||
* `tf.einsum()`raises `ValueError` for unsupported equations like `"ii->"`.
|
||||
* Add DCT-I and IDCT-I in `tf.signal.dct` and `tf.signal.idct`.
|
||||
* Add LU decomposition op.
|
||||
* Add quantile loss to gradient boosted trees in estimator.
|
||||
* Add `round_mode` to `QuantizeAndDequantizeV2` op to select rounding algorithm.
|
||||
* Add `unicode_encode`, `unicode_decode`, `unicode_decode_with_offsets`, `unicode_split`, `unicode_split_with_offset`, and `unicode_transcode` ops. Amongst other things, this Op adds the ability to encode, decode, and transcode a variety of input text encoding formats into the main Unicode encodings (UTF-8, UTF-16-BE, UTF-32-BE)
|
||||
* Add "unit" attribute to the substr op, which allows obtaining the substring of a string containing unicode characters.
|
||||
* Broadcasting support for Ragged Tensors.
|
||||
* `SpaceToDepth` supports uint8 data type.
|
||||
* Support multi-label quantile regression in estimator.
|
||||
* We now use "div" as the default partition_strategy in `tf.nn.safe_embedding_lookup_sparse`, `tf.nn.sampled_softmax` and `tf.nn.nce_loss`.
|
||||
hyperparameter are ignored.
|
||||
* Performance
|
||||
* Improve performance of GPU cumsum/cumprod by up to 300x.
|
||||
* Added support for weight decay in most TPU embedding optimizers, including AdamW and MomentumW.
|
||||
* TensorFlow 2.0 Development
|
||||
* Add a command line tool to convert to TF2.0, tf_upgrade_v2
|
||||
* Merge `tf.spectral` into `tf.signal` for TensorFlow 2.0.
|
||||
* Change the default recurrent activation function for LSTM from 'hard_sigmoid' to 'sigmoid' in 2.0. Historically recurrent activation is 'hard_sigmoid' since it is fast than 'sigmoid'. With new unified backend between CPU and GPU mode, since the CuDNN kernel is using sigmoid, we change the default for CPU mode to sigmoid as well. With that, the default LSTM will be compatible with both CPU and GPU kernel. This will enable user with GPU to use CuDNN kernel by default and get a 10x performance boost in training. Note that this is checkpoint breaking change. If user want to use their 1.x pre-trained checkpoint, please construct the layer with LSTM(recurrent_activation='hard_sigmoid') to fallback to 1.x behavior.
|
||||
* TensorFlow Lite
|
||||
* Move from `tensorflow/contrib/lite` to `tensorflow/lite`.
|
||||
* Add experimental Java API for injecting TensorFlow Lite delegates
|
||||
* Add support for strings in TensorFlow Lite Java API.
|
||||
* `tf.contrib`:
|
||||
* Add Apache Ignite Filesystem plugin to support accessing Apache IGFS.
|
||||
* Dropout now takes `rate` argument, `keep_prob` is deprecated.
|
||||
* Estimator occurrences references `tf.contrib.estimator` were changed to `tf.estimator`:
|
||||
* `tf.contrib.estimator.BaselineEstimator` with `tf.estimator.BaselineEstimator`
|
||||
* `tf.contrib.estimator.DNNLinearCombinedEstimator` with `tf.estimator.DNNLinearCombinedEstimator`
|
||||
* `tf.contrib.estimator.DNNEstimator` with `tf.estimator.DNNEstimator`
|
||||
* `tf.contrib.estimator.LinearEstimator` with `tf.estimator.LinearEstimator`
|
||||
* `tf.contrib.estimator.InMemoryEvaluatorHook` and tf.estimator.experimental.InMemoryEvaluatorHook`.
|
||||
* `tf.contrib.estimator.make_stop_at_checkpoint_step_hook` with `tf.estimator.experimental.make_stop_at_checkpoint_step_hook`.
|
||||
* Expose `tf.distribute.Strategy as the new name for tf.contrib.distribute.DistributionStrategy.
|
||||
* Migrate linear optimizer from contrib to core.
|
||||
* Move `tf.contrib.signal` to `tf.signal` (preserving aliases in tf.contrib.signal).
|
||||
* Users of `tf.contrib.estimator.export_all_saved_models` and related should switch to `tf.estimator.Estimator.experimental_export_all_saved_models`.
|
||||
* tf.data:
|
||||
* Add `tf.data.experimental.StatsOptions()`, to configure options to collect statistics from `tf.data.Dataset` pipeline using `StatsAggregator`. Add nested option, `experimental_stats` (which takes a `tf.data.experimen tal.StatsOptions` object), to `tf.data.Options`. Deprecates `tf.data.experimental.set_stats_agregator`.
|
||||
* Performance optimizations:
|
||||
* Add `tf.data.experimental.OptimizationOptions()`, to configure options to enable `tf.data` performance optimizations. Add nested option, `experimental_optimization` (which takes a `tf.data.experimental.OptimizationOptions` object), to `tf.data.Options`. Remove performance optimization options from `tf.data.Options`, and add them under `tf.data.experimental.OptimizationOptions` instead.
|
||||
* Enable `map_and_batch_fusion` and `noop_elimination` optimizations by default. They can be disabled by configuring `tf.data.experimental.OptimizationOptions` to set `map_and_batch = False` or `noop_elimination = False` respectively. To disable all default optimizations, set `apply_default_optimizations = False`.
|
||||
* Support parallel map in `map_and_filter_fusion`.
|
||||
* Disable static optimizations for input pipelines that use non-resource `tf.Variable`s.
|
||||
* Add NUMA-aware MapAndBatch dataset.
|
||||
* Deprecate `tf.data.Dataset.make_one_shot_iterator()` in V1, removed it from V2, and added tf.compat.v1.data.make_one_shot_iterator()`.
|
||||
* Deprecate `tf.data.Dataset.make_initializable_iterator()` in V1, removed it from V2, and added `tf.compat.v1.data.make_initializable_iterator()`.
|
||||
* Enable nested dataset support in core `tf.data` transformations.
|
||||
* For `tf.data.Dataset` implementers: Added `tf.data.Dataset._element_structured property` to replace `Dataset.output_{types,shapes,classes}`.
|
||||
* Make `num_parallel_calls` of `tf.data.Dataset.interleave` and `tf.data.Dataset.map` work in Eager mode.
|
||||
* Toolchains
|
||||
* Fixed OpenSSL compatibility by avoiding `EVP_MD_CTX_destroy`.
|
||||
* Added bounds checking to printing deprecation warnings.
|
||||
* Upgraded CUDA dependency to 10.0
|
||||
* To build with Android NDK r14b, add "#include <linux/compiler.h>" to android-ndk-r14b/platforms/android-14/arch-*/usr/include/linux/futex.h
|
||||
* Removed `:android_tensorflow_lib_selective_registration*` targets, use `:android_tensorflow_lib_lite*` targets instead.
|
||||
* XLA
|
||||
* Move `RoundToEven` function to xla/client/lib/math.h.
|
||||
* A new environment variable `TF_XLA_DEBUG_OPTIONS_PASSTHROUGH` set to "1" or "true" allows the debug options passed within an XRTCompile op to be passed directly to the XLA compilation backend. If such variable is not set (service side), only a restricted set will be passed through.
|
||||
* Allow the XRTCompile op to return the ProgramShape resulted form the XLA compilation as a second return argument.
|
||||
* XLA HLO graphs can now be rendered as SVG/HTML.
|
||||
* Estimator
|
||||
* Replace all occurences of `tf.contrib.estimator.BaselineEstimator` with `tf.estimator.BaselineEstimator`
|
||||
* Replace all occurences of `tf.contrib.estimator.DNNLinearCombinedEstimator` with `tf.estimator.DNNLinearCombinedEstimator`
|
||||
* Replace all occurrences of `tf.contrib.estimator.DNNEstimator` with `tf.estimator.DNNEstimator`
|
||||
* Replace all occurrences of `tf.contrib.estimator.LinearEstimator` with `tf.estimator.LinearEstimator`
|
||||
* Users of `tf.contrib.estimator.export_all_saved_models` and related should switch to `tf.estimator.Estimator.experimental_export_all_saved_models`.
|
||||
* Update `regression_head` to the new Head API for Canned Estimator V2.
|
||||
* Switch `multi_class_head` to Head API for Canned Estimator V2.
|
||||
* Replace all occurences of `tf.contrib.estimator.InMemoryEvaluatorHook` and `tf.contrib.estimator.make_stop_at_checkpoint_step_hook` with `tf.estimator.experimental.InMemoryEvaluatorHook` and `tf.estimator.experimental.make_stop_at_checkpoint_step_hook`
|
||||
* Migrate linear optimizer from contrib to core.
|
||||
|
||||
* Documentation
|
||||
* Update the doc with the details about the rounding mode used in
|
||||
quantize_and_dequantize_v2.
|
||||
* Clarify that tensorflow::port::InitMain() _should_ be called before
|
||||
using the TensorFlow library. Programs failing to do this are not
|
||||
portable to all platforms.
|
||||
* Deprecations and Symbol renames.
|
||||
* Removing deprecations for the following endpoints: `tf.acos`,
|
||||
`tf.acosh`, `tf.add`, `tf.as_string`, `tf.asin`, `tf.asinh`, `tf.atan`,
|
||||
`tf.atan2`, `tf.atanh`, `tf.cos`, `tf.cosh`, `tf.equal`, `tf.exp`,
|
||||
`tf.floor`, `tf.greater`, `tf.greater_equal`, `tf.less`,
|
||||
`tf.less_equal`, `tf.log`, `tf.logp1`, `tf.logical_and`,
|
||||
`tf.logical_not`, `tf.logical_or`, `tf.maximum`, `tf.minimum`,
|
||||
`tf.not_equal`, `tf.sin`, `tf.sinh`, `tf.tan`
|
||||
* Deprecate `tf.data.Dataset.shard`.
|
||||
* Deprecate `saved_model.loader.load` which is replaced by
|
||||
`saved_model.load` and `saved_model.main_op`, which will be replaced by
|
||||
`saved_model.main_op` in V2.
|
||||
* Deprecate tf.QUANTIZED_DTYPES. The official new symbol is
|
||||
tf.dtypes.QUANTIZED_DTYPES.
|
||||
* Update sklearn imports for deprecated packages.
|
||||
* Deprecate `Variable.count_up_to` and `tf.count_up_to` in favor of
|
||||
`Dataset.range`.
|
||||
* Export `confusion_matrix` op as `tf.math.confusion_matrix` instead of
|
||||
`tf.train.confusion_matrix`.
|
||||
* Add `tf.dtypes.` endpoint for every constant in dtypes.py. Moving
|
||||
endpoints in versions.py to corresponding endpoints in `tf.sysconfig.`
|
||||
and `tf.version.`. Moving all constants under `tf.saved_model`
|
||||
submodules to `tf.saved_model` module. New endpoints are added in V1 and
|
||||
V2 but existing endpoint removals are only applied in V2.
|
||||
* Deprecates behavior where device assignment overrides collocation
|
||||
constraints inside a collocation context manager.
|
||||
* Keras & Python API
|
||||
* Add to Keras functionality analogous to
|
||||
`tf.register_tensor_conversion_function`.
|
||||
* Subclassed Keras models can now be saved through
|
||||
`tf.contrib.saved_model.save_keras_model`.
|
||||
* `LinearOperator.matmul` now returns a new `LinearOperator`.
|
||||
* New ops and improved op functionality
|
||||
* Add a Nearest Neighbor Resize op.
|
||||
* Add an `ignore_unknown` argument to `parse_values` which suppresses
|
||||
ValueError for unknown hyperparameter types. Such * Add
|
||||
`tf.linalg.matvec` convenience function.
|
||||
* `tf.einsum()`raises `ValueError` for unsupported equations like
|
||||
`"ii->"`.
|
||||
* Add DCT-I and IDCT-I in `tf.signal.dct` and `tf.signal.idct`.
|
||||
* Add LU decomposition op.
|
||||
* Add quantile loss to gradient boosted trees in estimator.
|
||||
* Add `round_mode` to `QuantizeAndDequantizeV2` op to select rounding
|
||||
algorithm.
|
||||
* Add `unicode_encode`, `unicode_decode`, `unicode_decode_with_offsets`,
|
||||
`unicode_split`, `unicode_split_with_offset`, and `unicode_transcode`
|
||||
ops. Amongst other things, this Op adds the ability to encode, decode,
|
||||
and transcode a variety of input text encoding formats into the main
|
||||
Unicode encodings (UTF-8, UTF-16-BE, UTF-32-BE)
|
||||
* Add "unit" attribute to the substr op, which allows obtaining the
|
||||
substring of a string containing unicode characters.
|
||||
* Broadcasting support for Ragged Tensors.
|
||||
* `SpaceToDepth` supports uint8 data type.
|
||||
* Support multi-label quantile regression in estimator.
|
||||
* We now use "div" as the default partition_strategy in
|
||||
`tf.nn.safe_embedding_lookup_sparse`, `tf.nn.sampled_softmax` and
|
||||
`tf.nn.nce_loss`. hyperparameter are ignored.
|
||||
* Performance
|
||||
* Improve performance of GPU cumsum/cumprod by up to 300x.
|
||||
* Added support for weight decay in most TPU embedding optimizers,
|
||||
including AdamW and MomentumW.
|
||||
* TensorFlow 2.0 Development
|
||||
* Add a command line tool to convert to TF2.0, tf_upgrade_v2
|
||||
* Merge `tf.spectral` into `tf.signal` for TensorFlow 2.0.
|
||||
* Change the default recurrent activation function for LSTM from
|
||||
'hard_sigmoid' to 'sigmoid' in 2.0. Historically recurrent activation is
|
||||
'hard_sigmoid' since it is fast than 'sigmoid'. With new unified backend
|
||||
between CPU and GPU mode, since the CuDNN kernel is using sigmoid, we
|
||||
change the default for CPU mode to sigmoid as well. With that, the
|
||||
default LSTM will be compatible with both CPU and GPU kernel. This will
|
||||
enable user with GPU to use CuDNN kernel by default and get a 10x
|
||||
performance boost in training. Note that this is checkpoint breaking
|
||||
change. If user want to use their 1.x pre-trained checkpoint, please
|
||||
construct the layer with LSTM(recurrent_activation='hard_sigmoid') to
|
||||
fallback to 1.x behavior.
|
||||
* TensorFlow Lite
|
||||
* Move from `tensorflow/contrib/lite` to `tensorflow/lite`.
|
||||
* Add experimental Java API for injecting TensorFlow Lite delegates
|
||||
* Add support for strings in TensorFlow Lite Java API.
|
||||
* `tf.contrib`:
|
||||
* Add Apache Ignite Filesystem plugin to support accessing Apache IGFS.
|
||||
* Dropout now takes `rate` argument, `keep_prob` is deprecated.
|
||||
* Estimator occurrences references `tf.contrib.estimator` were changed to
|
||||
`tf.estimator`:
|
||||
* `tf.contrib.estimator.BaselineEstimator` with
|
||||
`tf.estimator.BaselineEstimator`
|
||||
* `tf.contrib.estimator.DNNLinearCombinedEstimator` with
|
||||
`tf.estimator.DNNLinearCombinedEstimator`
|
||||
* `tf.contrib.estimator.DNNEstimator` with `tf.estimator.DNNEstimator`
|
||||
* `tf.contrib.estimator.LinearEstimator` with
|
||||
`tf.estimator.LinearEstimator`
|
||||
* `tf.contrib.estimator.InMemoryEvaluatorHook` and
|
||||
tf.estimator.experimental.InMemoryEvaluatorHook`.
|
||||
* `tf.contrib.estimator.make_stop_at_checkpoint_step_hook` with
|
||||
`tf.estimator.experimental.make_stop_at_checkpoint_step_hook`.
|
||||
* Expose `tf.distribute.Strategy as the new name for
|
||||
tf.contrib.distribute.DistributionStrategy.
|
||||
* Migrate linear optimizer from contrib to core.
|
||||
* Move `tf.contrib.signal` to `tf.signal` (preserving aliases in
|
||||
tf.contrib.signal).
|
||||
* Users of `tf.contrib.estimator.export_all_saved_models` and related
|
||||
should switch to
|
||||
`tf.estimator.Estimator.experimental_export_all_saved_models`.
|
||||
* tf.data:
|
||||
* Add `tf.data.experimental.StatsOptions()`, to configure options to
|
||||
collect statistics from `tf.data.Dataset` pipeline using
|
||||
`StatsAggregator`. Add nested option, `experimental_stats` (which takes
|
||||
a `tf.data.experimen tal.StatsOptions` object), to `tf.data.Options`.
|
||||
Deprecates `tf.data.experimental.set_stats_agregator`.
|
||||
* Performance optimizations:
|
||||
* Add `tf.data.experimental.OptimizationOptions()`, to configure options
|
||||
to enable `tf.data` performance optimizations. Add nested option,
|
||||
`experimental_optimization` (which takes a
|
||||
`tf.data.experimental.OptimizationOptions` object), to
|
||||
`tf.data.Options`. Remove performance optimization options from
|
||||
`tf.data.Options`, and add them under
|
||||
`tf.data.experimental.OptimizationOptions` instead.
|
||||
* Enable `map_and_batch_fusion` and `noop_elimination` optimizations by
|
||||
default. They can be disabled by configuring
|
||||
`tf.data.experimental.OptimizationOptions` to set `map_and_batch =
|
||||
False` or `noop_elimination = False` respectively. To disable all
|
||||
default optimizations, set `apply_default_optimizations = False`.
|
||||
* Support parallel map in `map_and_filter_fusion`.
|
||||
* Disable static optimizations for input pipelines that use non-resource
|
||||
`tf.Variable`s.
|
||||
* Add NUMA-aware MapAndBatch dataset.
|
||||
* Deprecate `tf.data.Dataset.make_one_shot_iterator()` in V1, removed it
|
||||
from V2, and added tf.compat.v1.data.make_one_shot_iterator()`.
|
||||
* Deprecate `tf.data.Dataset.make_initializable_iterator()` in V1, removed
|
||||
it from V2, and added `tf.compat.v1.data.make_initializable_iterator()`.
|
||||
* Enable nested dataset support in core `tf.data` transformations.
|
||||
* For `tf.data.Dataset` implementers: Added
|
||||
`tf.data.Dataset._element_structured property` to replace
|
||||
`Dataset.output_{types,shapes,classes}`.
|
||||
* Make `num_parallel_calls` of `tf.data.Dataset.interleave` and
|
||||
`tf.data.Dataset.map` work in Eager mode.
|
||||
* Toolchains
|
||||
* Fixed OpenSSL compatibility by avoiding `EVP_MD_CTX_destroy`.
|
||||
* Added bounds checking to printing deprecation warnings.
|
||||
* Upgraded CUDA dependency to 10.0
|
||||
* To build with Android NDK r14b, add "#include <linux/compiler.h>" to
|
||||
android-ndk-r14b/platforms/android-14/arch-*/usr/include/linux/futex.h
|
||||
* Removed `:android_tensorflow_lib_selective_registration*` targets, use
|
||||
`:android_tensorflow_lib_lite*` targets instead.
|
||||
* XLA
|
||||
* Move `RoundToEven` function to xla/client/lib/math.h.
|
||||
* A new environment variable `TF_XLA_DEBUG_OPTIONS_PASSTHROUGH` set to "1"
|
||||
or "true" allows the debug options passed within an XRTCompile op to be
|
||||
passed directly to the XLA compilation backend. If such variable is not
|
||||
set (service side), only a restricted set will be passed through.
|
||||
* Allow the XRTCompile op to return the ProgramShape resulted form the XLA
|
||||
compilation as a second return argument.
|
||||
* XLA HLO graphs can now be rendered as SVG/HTML.
|
||||
* Estimator
|
||||
* Replace all occurences of `tf.contrib.estimator.BaselineEstimator` with
|
||||
`tf.estimator.BaselineEstimator`
|
||||
* Replace all occurences of
|
||||
`tf.contrib.estimator.DNNLinearCombinedEstimator` with
|
||||
`tf.estimator.DNNLinearCombinedEstimator`
|
||||
* Replace all occurrences of `tf.contrib.estimator.DNNEstimator` with
|
||||
`tf.estimator.DNNEstimator`
|
||||
* Replace all occurrences of `tf.contrib.estimator.LinearEstimator` with
|
||||
`tf.estimator.LinearEstimator`
|
||||
* Users of `tf.contrib.estimator.export_all_saved_models` and related
|
||||
should switch to
|
||||
`tf.estimator.Estimator.experimental_export_all_saved_models`.
|
||||
* Update `regression_head` to the new Head API for Canned Estimator V2.
|
||||
* Switch `multi_class_head` to Head API for Canned Estimator V2.
|
||||
* Replace all occurences of `tf.contrib.estimator.InMemoryEvaluatorHook`
|
||||
and `tf.contrib.estimator.make_stop_at_checkpoint_step_hook` with
|
||||
`tf.estimator.experimental.InMemoryEvaluatorHook` and
|
||||
`tf.estimator.experimental.make_stop_at_checkpoint_step_hook`
|
||||
* Migrate linear optimizer from contrib to core.
|
||||
|
||||
## Thanks to our Contributors
|
||||
|
||||
|
12
WORKSPACE
12
WORKSPACE
@ -43,8 +43,8 @@ remote_config_workspace()
|
||||
# Apple and Swift rules.
|
||||
http_archive(
|
||||
name = "build_bazel_rules_apple",
|
||||
sha256 = "8f32e2839fba28d549e1670dbed83606dd339a9f7489118e481814d61738270f",
|
||||
urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.14.0/rules_apple.0.14.0.tar.gz"],
|
||||
sha256 = "23792cd999f97fc97284d1c44cb1324bfdd0bc54aa68ad513fa3705aca3b1f9e",
|
||||
urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.15.0/rules_apple.0.15.0.tar.gz"],
|
||||
) # https://github.com/bazelbuild/rules_apple/releases
|
||||
http_archive(
|
||||
name = "build_bazel_apple_support",
|
||||
@ -58,14 +58,14 @@ http_archive(
|
||||
) # https://github.com/bazelbuild/bazel-skylib/releases
|
||||
http_archive(
|
||||
name = "build_bazel_rules_swift",
|
||||
sha256 = "31aad005a9c4e56b256125844ad05eb27c88303502d74138186f9083479f93a6",
|
||||
urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.8.0/rules_swift.0.8.0.tar.gz"],
|
||||
sha256 = "9efe9699e9765e6b4a5e063e4a08f6b163cccaf0443f775d935baf5c3cd6ed0e",
|
||||
urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.9.0/rules_swift.0.9.0.tar.gz"],
|
||||
) # https://github.com/bazelbuild/rules_swift/releases
|
||||
http_archive(
|
||||
name = "com_github_apple_swift_swift_protobuf",
|
||||
type = "zip",
|
||||
strip_prefix = "swift-protobuf-1.4.0/",
|
||||
urls = ["https://github.com/apple/swift-protobuf/archive/1.4.0.zip"],
|
||||
strip_prefix = "swift-protobuf-1.5.0/",
|
||||
urls = ["https://github.com/apple/swift-protobuf/archive/1.5.0.zip"],
|
||||
) # https://github.com/apple/swift-protobuf/releases
|
||||
http_file(
|
||||
name = "xctestrunner",
|
||||
|
49
configure.py
49
configure.py
@ -293,9 +293,9 @@ def get_var(environ_cp,
|
||||
|
||||
Args:
|
||||
environ_cp: copy of the os.environ.
|
||||
var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
|
||||
query_item: string for feature related to the variable, e.g. "Hadoop File
|
||||
System".
|
||||
var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
|
||||
query_item: string for feature related to the variable, e.g. "CUDA for
|
||||
Nvidia GPUs".
|
||||
enabled_by_default: boolean for default behavior.
|
||||
question: optional string for how to ask for user input.
|
||||
yes_reply: optional string for reply when feature is enabled.
|
||||
@ -376,9 +376,9 @@ def set_build_var(environ_cp,
|
||||
|
||||
Args:
|
||||
environ_cp: copy of the os.environ.
|
||||
var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
|
||||
query_item: string for feature related to the variable, e.g. "Hadoop File
|
||||
System".
|
||||
var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
|
||||
query_item: string for feature related to the variable, e.g. "CUDA for
|
||||
Nvidia GPUs".
|
||||
option_name: string for option to define in .bazelrc.
|
||||
enabled_by_default: boolean for default behavior.
|
||||
bazel_config_name: Name for Bazel --config argument to enable build feature.
|
||||
@ -411,9 +411,9 @@ def set_action_env_var(environ_cp,
|
||||
|
||||
Args:
|
||||
environ_cp: copy of the os.environ.
|
||||
var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
|
||||
query_item: string for feature related to the variable, e.g. "Hadoop File
|
||||
System".
|
||||
var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
|
||||
query_item: string for feature related to the variable, e.g. "CUDA for
|
||||
Nvidia GPUs".
|
||||
enabled_by_default: boolean for default behavior.
|
||||
question: optional string for how to ask for user input.
|
||||
yes_reply: optional string for reply when feature is enabled.
|
||||
@ -456,8 +456,8 @@ def check_bazel_version(min_version, max_version):
|
||||
"""Check installed bazel version is between min_version and max_version.
|
||||
|
||||
Args:
|
||||
min_version: string for minimum bazel version.
|
||||
max_version: string for maximum bazel version.
|
||||
min_version: string for minimum bazel version (must exist!).
|
||||
max_version: string for maximum bazel version (must exist!).
|
||||
|
||||
Returns:
|
||||
The bazel version detected.
|
||||
@ -570,7 +570,7 @@ def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var,
|
||||
|
||||
Args:
|
||||
environ_cp: copy of the os.environ.
|
||||
var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
|
||||
var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
|
||||
ask_for_var: string for how to ask for user input.
|
||||
var_default: default value string.
|
||||
|
||||
@ -1039,10 +1039,8 @@ def set_other_cuda_vars(environ_cp):
|
||||
# If CUDA is enabled, always use GPU during build and test.
|
||||
if environ_cp.get('TF_CUDA_CLANG') == '1':
|
||||
write_to_bazelrc('build --config=cuda_clang')
|
||||
write_to_bazelrc('test --config=cuda_clang')
|
||||
else:
|
||||
write_to_bazelrc('build --config=cuda')
|
||||
write_to_bazelrc('test --config=cuda')
|
||||
|
||||
|
||||
def set_host_cxx_compiler(environ_cp):
|
||||
@ -1261,7 +1259,8 @@ def set_windows_build_flags(environ_cp):
|
||||
write_to_bazelrc('build --copt=-w --host_copt=-w')
|
||||
# Fix winsock2.h conflicts
|
||||
write_to_bazelrc(
|
||||
'build --copt=-DWIN32_LEAN_AND_MEAN --host_copt=-DWIN32_LEAN_AND_MEAN')
|
||||
'build --copt=-DWIN32_LEAN_AND_MEAN --host_copt=-DWIN32_LEAN_AND_MEAN '
|
||||
'--copt=-DNOGDI --host_copt=-DNOGDI')
|
||||
# Output more verbose information when something goes wrong
|
||||
write_to_bazelrc('build --verbose_failures')
|
||||
# The host and target platforms are the same in Windows build. So we don't
|
||||
@ -1324,9 +1323,9 @@ def validate_cuda_config(environ_cp):
|
||||
|
||||
cuda_libraries = ['cuda', 'cudnn']
|
||||
if is_linux():
|
||||
if 'TF_TENSORRT_VERSION' in environ_cp: # if env variable exists
|
||||
if int(environ_cp.get('TF_NEED_TENSORRT', False)):
|
||||
cuda_libraries.append('tensorrt')
|
||||
if environ_cp.get('TF_NCCL_VERSION', None): # if env variable not empty
|
||||
if environ_cp.get('TF_NCCL_VERSION', None):
|
||||
cuda_libraries.append('nccl')
|
||||
|
||||
proc = subprocess.Popen(
|
||||
@ -1387,7 +1386,7 @@ def main():
|
||||
# environment variables.
|
||||
environ_cp = dict(os.environ)
|
||||
|
||||
current_bazel_version = check_bazel_version('0.24.1', '0.25.0')
|
||||
current_bazel_version = check_bazel_version('0.24.1', '0.25.2')
|
||||
_TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)
|
||||
|
||||
reset_tf_configure_bazelrc()
|
||||
@ -1453,8 +1452,12 @@ def main():
|
||||
cuda_env_names = [
|
||||
'TF_CUDA_VERSION', 'TF_CUBLAS_VERSION', 'TF_CUDNN_VERSION',
|
||||
'TF_TENSORRT_VERSION', 'TF_NCCL_VERSION', 'TF_CUDA_PATHS',
|
||||
'CUDA_TOOLKIT_PATH'
|
||||
# Items below are for backwards compatibility when not using
|
||||
# TF_CUDA_PATHS.
|
||||
'CUDA_TOOLKIT_PATH', 'CUDNN_INSTALL_PATH', 'NCCL_INSTALL_PATH',
|
||||
'NCCL_HDR_PATH', 'TENSORRT_INSTALL_PATH'
|
||||
]
|
||||
# Note: set_action_env_var above already writes to bazelrc.
|
||||
for name in cuda_env_names:
|
||||
if name in environ_cp:
|
||||
write_action_env_to_bazelrc(name, environ_cp[name])
|
||||
@ -1493,7 +1496,6 @@ def main():
|
||||
else:
|
||||
# Use downloaded LLD for linking.
|
||||
write_to_bazelrc('build:cuda_clang --config=download_clang_use_lld')
|
||||
write_to_bazelrc('test:cuda_clang --config=download_clang_use_lld')
|
||||
else:
|
||||
# Set up which gcc nvcc should use as the host compiler
|
||||
# No need to set this on Windows
|
||||
@ -1506,7 +1508,6 @@ def main():
|
||||
set_tf_download_clang(environ_cp)
|
||||
if environ_cp.get('TF_DOWNLOAD_CLANG') == '1':
|
||||
write_to_bazelrc('build --config=download_clang')
|
||||
write_to_bazelrc('test --config=download_clang')
|
||||
|
||||
# SYCL / ROCm / CUDA are mutually exclusive.
|
||||
# At most 1 GPU platform can be configured.
|
||||
@ -1546,12 +1547,6 @@ def main():
|
||||
set_action_env_var(environ_cp, 'TF_CONFIGURE_IOS', 'iOS', False)
|
||||
if environ_cp.get('TF_CONFIGURE_IOS') == '1':
|
||||
configure_ios()
|
||||
else:
|
||||
# TODO(pcloudy): Remove BAZEL_USE_CPP_ONLY_TOOLCHAIN after Bazel is upgraded
|
||||
# to 0.24.0.
|
||||
# For working around https://github.com/bazelbuild/bazel/issues/7607
|
||||
if is_macos():
|
||||
write_to_bazelrc('build --action_env=BAZEL_USE_CPP_ONLY_TOOLCHAIN=1')
|
||||
|
||||
print('Preconfigured Bazel build configs. You can use any of the below by '
|
||||
'adding "--config=<>" to your build command. See .bazelrc for more '
|
||||
|
@ -184,6 +184,12 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "linux_aarch64",
|
||||
values = {"cpu": "aarch64"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "linux_x86_64",
|
||||
values = {"cpu": "k8"},
|
||||
@ -420,6 +426,9 @@ config_setting(
|
||||
values = {"cpu": "x64_windows"},
|
||||
)
|
||||
|
||||
# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
|
||||
# Instead, please use public APIs or public build rules TF provides.
|
||||
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = [
|
||||
|
@ -32,10 +32,13 @@ from __future__ import print_function as _print_function
|
||||
|
||||
import distutils as _distutils
|
||||
import inspect as _inspect
|
||||
import logging as _logging
|
||||
import os as _os
|
||||
import site as _site
|
||||
import sys as _sys
|
||||
|
||||
from tensorflow.python.tools import module_util as _module_util
|
||||
|
||||
# API IMPORTS PLACEHOLDER
|
||||
|
||||
# Make sure directory containing top level submodules is in
|
||||
@ -49,25 +52,29 @@ if not hasattr(_current_module, '__path__'):
|
||||
elif _tf_api_dir not in __path__:
|
||||
__path__.append(_tf_api_dir)
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
from tensorflow.python.tools import component_api_helper as _component_api_helper
|
||||
_component_api_helper.package_hook(
|
||||
parent_package_str=__name__,
|
||||
child_package_str=('tensorboard.summary._tf.summary'),
|
||||
error_msg="Limited tf.summary API due to missing TensorBoard installation")
|
||||
_component_api_helper.package_hook(
|
||||
parent_package_str=__name__,
|
||||
child_package_str=(
|
||||
'tensorflow_estimator.python.estimator.api._v2.estimator'))
|
||||
# Hook external TensorFlow modules.
|
||||
try:
|
||||
from tensorboard.summary._tf import summary
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
|
||||
except ImportError:
|
||||
_logging.warning(
|
||||
"Limited tf.summary API due to missing TensorBoard installation.")
|
||||
|
||||
try:
|
||||
from tensorflow_estimator.python.estimator.api._v2 import estimator
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from tensorflow.python.keras.api._v2 import keras
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if not hasattr(_current_module, 'estimator'):
|
||||
_component_api_helper.package_hook(
|
||||
parent_package_str=__name__,
|
||||
child_package_str=(
|
||||
'tensorflow_estimator.python.estimator.api.estimator'))
|
||||
_component_api_helper.package_hook(
|
||||
parent_package_str=__name__,
|
||||
child_package_str=('tensorflow.python.keras.api._v2.keras'))
|
||||
|
||||
# Enable TF2 behaviors
|
||||
from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top
|
||||
|
@ -26,24 +26,37 @@ import sys as _sys
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
||||
from tensorflow.python.tools import module_util as _module_util
|
||||
|
||||
# API IMPORTS PLACEHOLDER
|
||||
|
||||
from tensorflow.python.tools import component_api_helper as _component_api_helper
|
||||
_component_api_helper.package_hook(
|
||||
parent_package_str=__name__,
|
||||
child_package_str=(
|
||||
'tensorflow_estimator.python.estimator.api._v1.estimator'))
|
||||
|
||||
# Make sure directory containing top level submodules is in
|
||||
# the __path__ so that "from tensorflow.foo import bar" works.
|
||||
# We're using bitwise, but there's nothing special about that.
|
||||
_API_MODULE = bitwise # pylint: disable=undefined-variable
|
||||
_current_module = _sys.modules[__name__]
|
||||
if not hasattr(_current_module, 'estimator'):
|
||||
_component_api_helper.package_hook(
|
||||
parent_package_str=__name__,
|
||||
child_package_str=(
|
||||
'tensorflow_estimator.python.estimator.api.estimator'))
|
||||
_component_api_helper.package_hook(
|
||||
parent_package_str=__name__,
|
||||
child_package_str=('tensorflow.python.keras.api._v1.keras'))
|
||||
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
|
||||
if not hasattr(_current_module, '__path__'):
|
||||
__path__ = [_tf_api_dir]
|
||||
elif _tf_api_dir not in __path__:
|
||||
__path__.append(_tf_api_dir)
|
||||
|
||||
# Hook external TensorFlow modules.
|
||||
try:
|
||||
from tensorflow_estimator.python.estimator.api._v1 import estimator
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from tensorflow.python.keras.api._v1 import keras
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
|
||||
_CONTRIB_WARNING = """
|
||||
WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.
|
||||
@ -66,17 +79,6 @@ from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-
|
||||
# The 'app' module will be imported as part of the placeholder section above.
|
||||
app.flags = flags # pylint: disable=undefined-variable
|
||||
|
||||
# Also use 'app' module (choice is arbitrary) to derive the API directory below.
|
||||
_API_MODULE = app # pylint: disable=undefined-variable
|
||||
|
||||
# Make sure directory containing top level submodules is in
|
||||
# the __path__ so that "from tensorflow.foo import bar" works.
|
||||
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
|
||||
if not hasattr(_current_module, '__path__'):
|
||||
__path__ = [_tf_api_dir]
|
||||
elif _tf_api_dir not in __path__:
|
||||
__path__.append(_tf_api_dir)
|
||||
|
||||
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
|
||||
# running under pip.
|
||||
# TODO(gunan): Enable setting an environment variable to define arbitrary plugin
|
||||
|
@ -104,6 +104,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"@com_google_absl//absl/strings",
|
||||
"//tensorflow/cc/saved_model:loader_lite",
|
||||
"//tensorflow/cc:gradients",
|
||||
"//tensorflow/cc:ops",
|
||||
@ -145,6 +146,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:lib_platform",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
@ -243,6 +245,28 @@ tf_cuda_library(
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "ops",
|
||||
srcs = [
|
||||
"ops.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"ops.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":tf_status_helper",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
}) + [":c_api_internal"],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tests
|
||||
|
||||
@ -291,7 +315,6 @@ tf_cuda_cc_test(
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
tags = [
|
||||
"no_oss", # http://b/119522529
|
||||
"noasan",
|
||||
],
|
||||
# We must ensure that the dependencies can be dynamically linked since
|
||||
@ -445,6 +468,27 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "ops_test",
|
||||
size = "small",
|
||||
srcs = ["ops_test.cc"],
|
||||
linkopts = select({
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
tags = ["noasan"],
|
||||
# We must ensure that the dependencies can be dynamically linked since
|
||||
# the shared library must be able to use core:framework.
|
||||
# linkstatic = tf_kernel_tests_linkstatic(),
|
||||
deps = [
|
||||
":c_api",
|
||||
":ops",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Python API target
|
||||
|
||||
|
@ -30,8 +30,8 @@ limitations under the License.
|
||||
#include "tensorflow/cc/ops/while_loop.h"
|
||||
#include "tensorflow/cc/saved_model/loader.h"
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/framework/logging.h"
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#include "tensorflow/core/kernels/logging_ops.h"
|
||||
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
@ -66,6 +67,24 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
|
||||
}
|
||||
}
|
||||
|
||||
unsigned char TF_SetXlaEnableLazyCompilation(unsigned char enable) {
|
||||
tensorflow::BuildXlaOpsPassFlags* flags =
|
||||
tensorflow::GetBuildXlaOpsPassFlags();
|
||||
bool original = flags->tf_xla_enable_lazy_compilation;
|
||||
flags->tf_xla_enable_lazy_compilation = enable;
|
||||
return original;
|
||||
}
|
||||
|
||||
void TF_SetXLaAutoJitMode(const char* mode) {
|
||||
tensorflow::SetXlaAutoJitFlagFromFlagString(mode);
|
||||
}
|
||||
|
||||
void TF_SetXlaMinClusterSize(int size) {
|
||||
tensorflow::MarkForCompilationPassFlags* flags =
|
||||
tensorflow::GetMarkForCompilationPassFlags();
|
||||
flags->tf_xla_min_cluster_size = size;
|
||||
}
|
||||
|
||||
TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
|
||||
unsigned char gpu_memory_allow_growth,
|
||||
unsigned int num_cpu_devices) {
|
||||
@ -676,7 +695,7 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
|
||||
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
|
||||
|
||||
LOG_AND_RETURN_IF_ERROR(ctx->context.StoreCollectiveOpsServer(
|
||||
LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
|
||||
std::move(server), grpc_server->worker_env()->device_mgr,
|
||||
grpc_server->worker_env()->collective_executor_mgr));
|
||||
|
||||
|
@ -62,6 +62,20 @@ extern "C" {
|
||||
TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options,
|
||||
unsigned char enable);
|
||||
|
||||
// Set XLA's internal BuildXlaOpsPassFlags.tf_xla_enable_lazy_compilation to the
|
||||
// value of 'enabled'. Also returns the original value of that flag.
|
||||
//
|
||||
// Use in tests to allow XLA to fallback to TF classic. This has global effect.
|
||||
TF_CAPI_EXPORT unsigned char TF_SetXlaEnableLazyCompilation(
|
||||
unsigned char enable);
|
||||
|
||||
// Sets XLA's auto jit mode according to the specified string, which is parsed
|
||||
// as if passed in XLA_FLAGS. This has global effect.
|
||||
TF_CAPI_EXPORT void TF_SetXLaAutoJitMode(const char* mode);
|
||||
|
||||
// Sets XLA's minimum cluster size. This has global effect.
|
||||
TF_CAPI_EXPORT void TF_SetXlaMinClusterSize(int size);
|
||||
|
||||
// Create a serialized tensorflow.ConfigProto proto, where:
|
||||
//
|
||||
// a) ConfigProto.optimizer_options.global_jit_level is set to to ON_1 if
|
||||
|
@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
@ -295,7 +295,8 @@ Status FillFunctionBody(
|
||||
}
|
||||
|
||||
// Graph to FunctionDef conversion. This code is closely modeled on the Python
|
||||
// code in tensorflow/python/framework/function.py.
|
||||
// function graph_to_function_def(), which is located in
|
||||
// tensorflow/python/framework/graph_to_function_def.py.
|
||||
Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
|
||||
bool append_hash_to_fn_name,
|
||||
const std::vector<const Node*>& body_nodes,
|
||||
@ -352,6 +353,16 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
|
||||
argdef->set_type(node->output_type(idx));
|
||||
const string& input_name = node_names.GetInputName(node->name());
|
||||
argdef->set_name(input_name);
|
||||
auto& arg_attrs = (*fdef->mutable_arg_attr())[i];
|
||||
for (const auto& attr : node->attrs()) {
|
||||
// Only copy internal attributes. These attributes will be applied to
|
||||
// _Arg/Placeholder nodes when this FunctionDef is converted to graph, and
|
||||
// normal attributes for nodes cannot be applied to those _Arg/Placeholder
|
||||
// nodes.
|
||||
if (absl::StartsWith(attr.first, "_")) {
|
||||
arg_attrs.mutable_attr()->insert(attr);
|
||||
}
|
||||
}
|
||||
tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
|
||||
}
|
||||
|
||||
|
@ -1278,6 +1278,46 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
|
||||
EXPECT_EQ(func_->fdef.signature().attr(1).type(), "int");
|
||||
}
|
||||
|
||||
void NodeWithAttrHelper(TF_Graph* graph, TF_Status* s, const char* name,
|
||||
const char* attr_name, const char* attr_value,
|
||||
TF_Operation** op) {
|
||||
TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
|
||||
TF_SetAttrType(desc, "dtype", TF_INT32);
|
||||
TF_SetAttrString(desc, attr_name, attr_value, strlen(attr_value));
|
||||
*op = TF_FinishOperation(desc, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
ASSERT_NE(*op, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(CApiFunctionTest, GraphToFunctionDefWithArgAttr) {
|
||||
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
|
||||
TF_NewGraph(), TF_DeleteGraph);
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
|
||||
TF_DeleteStatus);
|
||||
|
||||
TF_Operation* node;
|
||||
NodeWithAttrHelper(func_graph.get(), s.get(), "node", "_test_attr", "value",
|
||||
&node);
|
||||
|
||||
TF_Output inputs[] = {{node, 0}};
|
||||
TF_Output outputs[] = {};
|
||||
func_ = TF_GraphToFunction(
|
||||
func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
|
||||
/*opers=*/nullptr, 1, inputs, 0, outputs,
|
||||
/*output_names=*/nullptr,
|
||||
/*opts=*/nullptr, /*description=*/nullptr, s.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
|
||||
ASSERT_NE(func_, nullptr);
|
||||
|
||||
// Verify that FunctionDef ArgDef has attributes.
|
||||
ASSERT_EQ(func_->fdef.arg_attr_size(), 1);
|
||||
auto arg_attrs = func_->fdef.arg_attr().find(0);
|
||||
ASSERT_NE(arg_attrs, func_->fdef.arg_attr().end());
|
||||
auto iter = arg_attrs->second.attr().find("_test_attr");
|
||||
ASSERT_NE(iter, arg_attrs->second.attr().end());
|
||||
EXPECT_EQ(iter->second.s(), "value");
|
||||
}
|
||||
|
||||
TEST_F(CApiFunctionTest, SetGradientAndRun) {
|
||||
// Define the function and its grad
|
||||
DefineFunction(func_name_, &func_);
|
||||
|
@ -24,8 +24,10 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
// clang-format off
|
||||
// Required for IS_MOBILE_PLATFORM
|
||||
#include "tensorflow/core/platform/platform.h" // NO_LINT
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
// clang-format on
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
|
@ -29,8 +29,7 @@ namespace checkpoint {
|
||||
|
||||
class TensorSliceReader;
|
||||
|
||||
CheckpointReader::CheckpointReader(const string& filename,
|
||||
TF_Status* out_status)
|
||||
CheckpointReader::CheckpointReader(const string& filename, TF_Status* status)
|
||||
: reader_(nullptr),
|
||||
v2_reader_(nullptr),
|
||||
var_to_shape_map_(nullptr),
|
||||
@ -43,7 +42,7 @@ CheckpointReader::CheckpointReader(const string& filename,
|
||||
v2_reader_.reset(
|
||||
new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */));
|
||||
if (!v2_reader_->status().ok()) {
|
||||
Set_TF_Status_from_Status(out_status, v2_reader_->status());
|
||||
Set_TF_Status_from_Status(status, v2_reader_->status());
|
||||
return;
|
||||
}
|
||||
auto result = BuildV2VarMaps();
|
||||
@ -52,7 +51,7 @@ CheckpointReader::CheckpointReader(const string& filename,
|
||||
} else {
|
||||
reader_.reset(new TensorSliceReader(filename));
|
||||
if (!reader_->status().ok()) {
|
||||
Set_TF_Status_from_Status(out_status, reader_->status());
|
||||
Set_TF_Status_from_Status(status, reader_->status());
|
||||
return;
|
||||
}
|
||||
var_to_shape_map_.reset(
|
||||
|
@ -39,7 +39,7 @@ class TensorSliceReader;
|
||||
// variables.
|
||||
class CheckpointReader {
|
||||
public:
|
||||
CheckpointReader(const string& filepattern, TF_Status* out_status);
|
||||
CheckpointReader(const string& filename, TF_Status* status);
|
||||
|
||||
bool HasTensor(const string& name) const;
|
||||
const string DebugString() const;
|
||||
|
@ -1,4 +1,5 @@
|
||||
# Experimental extensions to the C API for eager execution of kernels.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load(
|
||||
@ -258,3 +259,22 @@ filegroup(
|
||||
srcs = ["c_api.h"],
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
# TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime
|
||||
# right now, remove this public rule when no longer needed (it should be
|
||||
# replaced by TF Lite)
|
||||
filegroup(
|
||||
name = "srcs",
|
||||
srcs = glob(
|
||||
[
|
||||
"*.cc",
|
||||
"*.h",
|
||||
],
|
||||
exclude = [
|
||||
"c_api_experimental.cc",
|
||||
"c_api_experimental.h",
|
||||
"*test*",
|
||||
],
|
||||
),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
75
tensorflow/c/eager/c_api.cc
Executable file → Normal file
75
tensorflow/c/eager/c_api.cc
Executable file → Normal file
@ -21,6 +21,11 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// clang-format off
|
||||
// Required for IS_MOBILE_PLATFORM
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
// clang-format on
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
@ -38,11 +43,15 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/eager/execute.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
#include "tensorflow/core/distributed_runtime/remote_device.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
@ -88,6 +97,7 @@ string DeviceName(const tensorflow::Device* d) {
|
||||
return (d == nullptr) ? "cpu:0" : d->name();
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::Status GetAllRemoteDevices(
|
||||
const std::vector<string>& remote_workers,
|
||||
tensorflow::WorkerCacheInterface* worker_cache,
|
||||
@ -220,7 +230,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
|
||||
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
|
||||
remote_workers, rendezvous_id, keep_alive_secs, server_def,
|
||||
remote_eager_workers.get(), ctx->context.Async(), &remote_contexts));
|
||||
remote_eager_workers.get(), ctx->context->Async(), &remote_contexts));
|
||||
|
||||
tensorflow::RemoteRendezvous* r =
|
||||
grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
|
||||
@ -239,12 +249,13 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
|
||||
auto* device_mgr = grpc_server->worker_env()->device_mgr;
|
||||
|
||||
return ctx->context.InitializeRemote(
|
||||
return ctx->context->InitializeRemote(
|
||||
std::move(server), std::move(remote_eager_workers),
|
||||
std::move(remote_device_mgr), remote_contexts, r, device_mgr,
|
||||
keep_alive_secs);
|
||||
#undef LOG_AND_RETURN_IF_ERROR
|
||||
}
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,
|
||||
TFE_TensorHandle* input) {
|
||||
@ -341,7 +352,7 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
|
||||
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
|
||||
unsigned char enable,
|
||||
TF_Status* status) {
|
||||
status->status = ctx->context.SetAsyncForThread(enable);
|
||||
status->status = ctx->context->SetAsyncForThread(enable);
|
||||
}
|
||||
|
||||
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
|
||||
@ -381,16 +392,14 @@ void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
|
||||
|
||||
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
|
||||
TF_DeviceList* list = new TF_DeviceList;
|
||||
ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response);
|
||||
if (ctx->context.remote_device_mgr()) {
|
||||
ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response);
|
||||
ctx->context->local_device_mgr()->ListDeviceAttributes(&list->response);
|
||||
if (ctx->context->remote_device_mgr()) {
|
||||
ctx->context->remote_device_mgr()->ListDeviceAttributes(&list->response);
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
void TFE_ContextClearCaches(TFE_Context* ctx, TF_Status* status) {
|
||||
status->status = ctx->context.ClearCaches();
|
||||
}
|
||||
void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context->ClearCaches(); }
|
||||
|
||||
// Set server_def on the context, possibly updating it.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
@ -398,6 +407,10 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
const void* proto,
|
||||
size_t proto_len,
|
||||
TF_Status* status) {
|
||||
#if defined(IS_MOBILE_PLATFORM)
|
||||
status->status = tensorflow::errors::Unimplemented(
|
||||
"TFE_ContextSetServerDef not supported on mobile");
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::ServerDef server_def;
|
||||
if (!server_def.ParseFromArray(proto, proto_len)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
@ -406,11 +419,12 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
}
|
||||
status->status =
|
||||
UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx);
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
|
||||
ctx->context.SetThreadLocalDevicePlacementPolicy(
|
||||
ctx->context->SetThreadLocalDevicePlacementPolicy(
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
|
||||
}
|
||||
|
||||
@ -420,19 +434,19 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
||||
TFE_Context* ctx) {
|
||||
return static_cast<TFE_ContextDevicePlacementPolicy>(
|
||||
ctx->context.GetDevicePlacementPolicy());
|
||||
ctx->context->GetDevicePlacementPolicy());
|
||||
}
|
||||
|
||||
void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) {
|
||||
status->status = ctx->context.AsyncWait();
|
||||
status->status = ctx->context->AsyncWait();
|
||||
}
|
||||
|
||||
void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) {
|
||||
status->status = ctx->context.GetStatus();
|
||||
status->status = ctx->context->GetStatus();
|
||||
}
|
||||
|
||||
void TFE_ContextAsyncClearError(TFE_Context* ctx) {
|
||||
ctx->context.ClearAsyncError();
|
||||
ctx->context->ClearAsyncError();
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
|
||||
@ -592,7 +606,7 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
return new TFE_Op(ctx, name, false, types,
|
||||
new TFE_OpInferenceContext(op_def));
|
||||
}
|
||||
if (!ctx->context.FindFunctionByName(name)) {
|
||||
if (!ctx->context->FindFunctionByName(name)) {
|
||||
status->status = tensorflow::errors::NotFound(
|
||||
"'", name,
|
||||
"' is neither a type of a primitive operation nor a name "
|
||||
@ -890,7 +904,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
const char* device_name,
|
||||
TF_Status* status) {
|
||||
tensorflow::TensorHandle* handle;
|
||||
status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context,
|
||||
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
|
||||
device_name, &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle(handle);
|
||||
@ -907,26 +921,31 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
|
||||
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
|
||||
return;
|
||||
}
|
||||
status->status = ctx->context.AddFunctionDef(function_def);
|
||||
status->status = ctx->context->AddFunctionDef(function_def);
|
||||
}
|
||||
|
||||
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
|
||||
TF_Status* status) {
|
||||
status->status = ctx->context.AddFunctionDef(function->fdef);
|
||||
status->status = ctx->context->AddFunctionDef(function->fdef);
|
||||
}
|
||||
|
||||
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
|
||||
TF_Status* status) {
|
||||
status->status = ctx->context->RemoveFunction(name);
|
||||
}
|
||||
|
||||
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
|
||||
return ctx->context.FindFunctionDef(name) != nullptr;
|
||||
return ctx->context->FindFunctionDef(name) != nullptr;
|
||||
}
|
||||
|
||||
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
|
||||
ctx->context.SetShouldStoreGraphs(true);
|
||||
ctx->context.SetShouldStoreStepStats(true);
|
||||
ctx->context->SetShouldStoreGraphs(true);
|
||||
ctx->context->SetShouldStoreStepStats(true);
|
||||
}
|
||||
|
||||
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
||||
ctx->context.SetShouldStoreGraphs(false);
|
||||
ctx->context.SetShouldStoreStepStats(false);
|
||||
ctx->context->SetShouldStoreGraphs(false);
|
||||
ctx->context->SetShouldStoreStepStats(false);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
@ -955,9 +974,9 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
|
||||
TF_Status* status) {
|
||||
TFE_ContextAsyncWait(ctx, status);
|
||||
if (!status->status.ok()) return;
|
||||
tensorflow::mutex_lock ml(*ctx->context.MetadataMu());
|
||||
status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf);
|
||||
ctx->context.ClearRunMetadata();
|
||||
tensorflow::mutex_lock ml(*ctx->context->MetadataMu());
|
||||
status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf);
|
||||
ctx->context->ClearRunMetadata();
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -973,9 +992,9 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); }
|
||||
void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }
|
||||
|
||||
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); }
|
||||
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
|
||||
|
||||
namespace tensorflow {
|
||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||
|
@ -98,8 +98,7 @@ TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
|
||||
|
||||
// Clears the internal caches in the TFE context. Useful when reseeding random
|
||||
// ops.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextClearCaches(TFE_Context* ctx,
|
||||
TF_Status* status);
|
||||
TF_CAPI_EXPORT extern void TFE_ContextClearCaches(TFE_Context* ctx);
|
||||
|
||||
// Sets a thread-local device placement policy. After this call, other calls to
|
||||
// TFE_Execute in the same thread will use the device policy specified here
|
||||
@ -411,6 +410,13 @@ TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx,
|
||||
TF_Function* function,
|
||||
TF_Status* status);
|
||||
|
||||
// Removes a function from the context. Once removed, you can no longer
|
||||
// TFE_Execute it or TFE_Execute any TFE_Op which has it as an attribute or any
|
||||
// other function which calls it as an attribute.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextRemoveFunction(TFE_Context* ctx,
|
||||
const char* name,
|
||||
TF_Status* status);
|
||||
|
||||
// Checks whether a function is registered under `name`.
|
||||
TF_CAPI_EXPORT unsigned char TFE_ContextHasFunction(TFE_Context* ctx,
|
||||
const char* name);
|
||||
|
@ -63,7 +63,7 @@ TFE_ProfilerContext* TFE_NewProfilerContext() {
|
||||
|
||||
void TFE_ProfilerContextSetEagerContext(TFE_ProfilerContext* profiler_context,
|
||||
TFE_Context* eager_context) {
|
||||
profiler_context->profiler_context.eager_context = &eager_context->context;
|
||||
profiler_context->profiler_context.eager_context = eager_context->context;
|
||||
}
|
||||
|
||||
void TFE_DeleteProfilerContext(TFE_ProfilerContext* profiler_context) {
|
||||
@ -77,11 +77,11 @@ void TFE_StartProfilerServer(TFE_ProfilerContext* context, int port) {
|
||||
}
|
||||
|
||||
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
||||
ctx->context.SetShouldStoreGraphs(true);
|
||||
ctx->context->SetShouldStoreGraphs(true);
|
||||
}
|
||||
|
||||
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
|
||||
ctx->context.SetShouldStoreGraphs(false);
|
||||
ctx->context->SetShouldStoreGraphs(false);
|
||||
}
|
||||
|
||||
bool TFE_ProfilerClientStartTracing(const char* service_addr,
|
||||
@ -99,59 +99,6 @@ bool TFE_ProfilerClientStartTracing(const char* service_addr,
|
||||
return s.ok();
|
||||
}
|
||||
|
||||
static tensorflow::mutex gauges_map_lock(tensorflow::LINKER_INITIALIZED);
|
||||
|
||||
static std::unordered_map<string,
|
||||
tensorflow::monitoring::Gauge<tensorflow::int64, 1>*>*
|
||||
get_gauges_map() EXCLUSIVE_LOCKS_REQUIRED(gauges_map_lock) {
|
||||
static std::unordered_map<
|
||||
string, tensorflow::monitoring::Gauge<tensorflow::int64, 1>*>*
|
||||
gauges_map = new std::unordered_map<
|
||||
string, tensorflow::monitoring::Gauge<tensorflow::int64, 1>*>;
|
||||
return gauges_map;
|
||||
}
|
||||
|
||||
static tensorflow::mutex samplers_map_lock(tensorflow::LINKER_INITIALIZED);
|
||||
|
||||
static std::unordered_map<string, tensorflow::monitoring::Sampler<1>*>*
|
||||
get_samplers_map() EXCLUSIVE_LOCKS_REQUIRED(samplers_map_lock) {
|
||||
static std::unordered_map<string, tensorflow::monitoring::Sampler<1>*>*
|
||||
samplers_map =
|
||||
new std::unordered_map<string, tensorflow::monitoring::Sampler<1>*>;
|
||||
return samplers_map;
|
||||
}
|
||||
|
||||
void TFE_MonitoringSetGauge(const char* name, const char* label,
|
||||
int64_t value) {
|
||||
tensorflow::mutex_lock l(gauges_map_lock);
|
||||
auto gauges_map = get_gauges_map();
|
||||
if (gauges_map->find(name) == gauges_map->end()) {
|
||||
gauges_map->emplace(
|
||||
name, tensorflow::monitoring::Gauge<tensorflow::int64, 1>::New(
|
||||
name,
|
||||
tensorflow::strings::StrCat(
|
||||
name, " :Gauge metric collected from Python API."),
|
||||
"metric_descriptor"));
|
||||
}
|
||||
gauges_map->at(name)->GetCell(label)->Set(value);
|
||||
}
|
||||
|
||||
void TFE_MonitoringAddSampler(const char* name, const char* label,
|
||||
double value) {
|
||||
tensorflow::mutex_lock l(samplers_map_lock);
|
||||
auto samplers_map = get_samplers_map();
|
||||
if (samplers_map->find(name) == samplers_map->end()) {
|
||||
samplers_map->emplace(
|
||||
name, tensorflow::monitoring::Sampler<1>::New(
|
||||
{name,
|
||||
tensorflow::strings::StrCat(
|
||||
name, " :Counter metric collected from Python API."),
|
||||
"metric_descriptor"},
|
||||
{tensorflow::monitoring::Buckets::Exponential(1, 2, 30)}));
|
||||
}
|
||||
samplers_map->at(name)->GetCell(label)->Add(value);
|
||||
}
|
||||
|
||||
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
|
||||
int64_t value) {
|
||||
cell->cell.IncrementBy(value);
|
||||
@ -166,6 +113,10 @@ TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(const char* name,
|
||||
const char* description) {
|
||||
auto* result = new TFE_MonitoringCounter0({name, description});
|
||||
Set_TF_Status_from_Status(status, result->counter->GetStatus());
|
||||
if (!result->counter->GetStatus().ok()) {
|
||||
delete result;
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -185,6 +136,10 @@ TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(const char* name,
|
||||
const char* label1) {
|
||||
auto* result = new TFE_MonitoringCounter1({name, description, label1});
|
||||
Set_TF_Status_from_Status(status, result->counter->GetStatus());
|
||||
if (!result->counter->GetStatus().ok()) {
|
||||
delete result;
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -206,6 +161,10 @@ TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(const char* name,
|
||||
auto* result =
|
||||
new TFE_MonitoringCounter2({name, description, label1, label2});
|
||||
Set_TF_Status_from_Status(status, result->counter->GetStatus());
|
||||
if (!result->counter->GetStatus().ok()) {
|
||||
delete result;
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -218,3 +177,344 @@ TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
|
||||
return static_cast<TFE_MonitoringCounterCell*>(
|
||||
static_cast<void*>(counter->counter->GetCell(label1, label2)));
|
||||
}
|
||||
|
||||
void TFE_MonitoringIntGaugeCellSet(TFE_MonitoringIntGaugeCell* cell,
|
||||
int64_t value) {
|
||||
cell->cell.Set(value);
|
||||
}
|
||||
|
||||
int64_t TFE_MonitoringIntGaugeCellValue(TFE_MonitoringIntGaugeCell* cell) {
|
||||
return cell->cell.value();
|
||||
}
|
||||
|
||||
TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(const char* name,
|
||||
TF_Status* status,
|
||||
const char* description) {
|
||||
auto* result = new TFE_MonitoringIntGauge0({name, description});
|
||||
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
|
||||
if (!result->gauge->GetStatus().ok()) {
|
||||
delete result;
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void TFE_MonitoringDeleteIntGauge0(TFE_MonitoringIntGauge0* gauge) {
|
||||
delete gauge;
|
||||
}
|
||||
|
||||
TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge0(
|
||||
TFE_MonitoringIntGauge0* gauge) {
|
||||
return static_cast<TFE_MonitoringIntGaugeCell*>(
|
||||
static_cast<void*>(gauge->gauge->GetCell()));
|
||||
}
|
||||
|
||||
TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(const char* name,
|
||||
TF_Status* status,
|
||||
const char* description,
|
||||
const char* label1) {
|
||||
auto* result = new TFE_MonitoringIntGauge1({name, description, label1});
|
||||
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
|
||||
if (!result->gauge->GetStatus().ok()) {
|
||||
delete result;
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void TFE_MonitoringDeleteIntGauge1(TFE_MonitoringIntGauge1* gauge) {
|
||||
delete gauge;
|
||||
}
|
||||
|
||||
TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge1(
|
||||
TFE_MonitoringIntGauge1* gauge, const char* label1) {
|
||||
return static_cast<TFE_MonitoringIntGaugeCell*>(
|
||||
static_cast<void*>(gauge->gauge->GetCell(label1)));
|
||||
}
|
||||
|
||||
TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(const char* name,
|
||||
TF_Status* status,
|
||||
const char* description,
|
||||
const char* label1,
|
||||
const char* label2) {
|
||||
auto* result =
|
||||
new TFE_MonitoringIntGauge2({name, description, label1, label2});
|
||||
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
|
||||
if (!result->gauge->GetStatus().ok()) {
|
||||
delete result;
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void TFE_MonitoringDeleteIntGauge2(TFE_MonitoringIntGauge2* gauge) {
|
||||
delete gauge;
|
||||
}
|
||||
|
||||
TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge2(
|
||||
TFE_MonitoringIntGauge2* gauge, const char* label1, const char* label2) {
|
||||
return static_cast<TFE_MonitoringIntGaugeCell*>(
|
||||
static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
|
||||
}
|
||||
|
||||
void TFE_MonitoringStringGaugeCellSet(TFE_MonitoringStringGaugeCell* cell,
|
||||
const char* value) {
|
||||
cell->cell.Set({value});
|
||||
}
|
||||
|
||||
const void TFE_MonitoringStringGaugeCellValue(
|
||||
TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf) {
|
||||
tensorflow::string value = cell->cell.value();
|
||||
void* data = tensorflow::port::Malloc(value.length());
|
||||
value.copy(static_cast<char*>(data), value.length(), 0);
|
||||
buf->data = data;
|
||||
buf->length = value.length();
|
||||
buf->data_deallocator = [](void* data, size_t length) {
|
||||
tensorflow::port::Free(data);
|
||||
};
|
||||
}
|
||||
|
||||
TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0(
|
||||
const char* name, TF_Status* status, const char* description) {
|
||||
auto* result = new TFE_MonitoringStringGauge0({name, description});
|
||||
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
|
||||
if (!result->gauge->GetStatus().ok()) {
|
||||
delete result;
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void TFE_MonitoringDeleteStringGauge0(TFE_MonitoringStringGauge0* gauge) {
|
||||
delete gauge;
|
||||
}
|
||||
|
||||
TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge0(
|
||||
TFE_MonitoringStringGauge0* gauge) {
|
||||
return static_cast<TFE_MonitoringStringGaugeCell*>(
|
||||
static_cast<void*>(gauge->gauge->GetCell()));
|
||||
}
|
||||
|
||||
TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1(
|
||||
const char* name, TF_Status* status, const char* description,
|
||||
const char* label1) {
|
||||
auto* result = new TFE_MonitoringStringGauge1({name, description, label1});
|
||||
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
|
||||
if (!result->gauge->GetStatus().ok()) {
|
||||
delete result;
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void TFE_MonitoringDeleteStringGauge1(TFE_MonitoringStringGauge1* gauge) {
|
||||
delete gauge;
|
||||
}
|
||||
|
||||
TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge1(
|
||||
TFE_MonitoringStringGauge1* gauge, const char* label1) {
|
||||
return static_cast<TFE_MonitoringStringGaugeCell*>(
|
||||
static_cast<void*>(gauge->gauge->GetCell(label1)));
|
||||
}
|
||||
|
||||
TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2(
|
||||
const char* name, TF_Status* status, const char* description,
|
||||
const char* label1, const char* label2) {
|
||||
auto* result =
|
||||
new TFE_MonitoringStringGauge2({name, description, label1, label2});
|
||||
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
|
||||
if (!result->gauge->GetStatus().ok()) {
|
||||
delete result;
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void TFE_MonitoringDeleteStringGauge2(TFE_MonitoringStringGauge2* gauge) {
|
||||
delete gauge;
|
||||
}
|
||||
|
||||
TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge2(
|
||||
TFE_MonitoringStringGauge2* gauge, const char* label1, const char* label2) {
|
||||
return static_cast<TFE_MonitoringStringGaugeCell*>(
|
||||
static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
|
||||
}
|
||||
|
||||
void TFE_MonitoringBoolGaugeCellSet(TFE_MonitoringBoolGaugeCell* cell,
|
||||
bool value) {
|
||||
cell->cell.Set(value);
|
||||
}
|
||||
|
||||
bool TFE_MonitoringBoolGaugeCellValue(TFE_MonitoringBoolGaugeCell* cell) {
|
||||
return cell->cell.value();
|
||||
}
|
||||
|
||||
TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(const char* name,
|
||||
TF_Status* status,
|
||||
const char* description) {
|
||||
auto* result = new TFE_MonitoringBoolGauge0({name, description});
|
||||
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
|
||||
if (!result->gauge->GetStatus().ok()) {
|
||||
delete result;
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void TFE_MonitoringDeleteBoolGauge0(TFE_MonitoringBoolGauge0* gauge) {
|
||||
delete gauge;
|
||||
}
|
||||
|
||||
TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge0(
|
||||
TFE_MonitoringBoolGauge0* gauge) {
|
||||
return static_cast<TFE_MonitoringBoolGaugeCell*>(
|
||||
static_cast<void*>(gauge->gauge->GetCell()));
|
||||
}
|
||||
|
||||
TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(const char* name,
|
||||
TF_Status* status,
|
||||
const char* description,
|
||||
const char* label1) {
|
||||
auto* result = new TFE_MonitoringBoolGauge1({name, description, label1});
|
||||
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
|
||||
if (!result->gauge->GetStatus().ok()) {
|
||||
delete result;
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void TFE_MonitoringDeleteBoolGauge1(TFE_MonitoringBoolGauge1* gauge) {
|
||||
delete gauge;
|
||||
}
|
||||
|
||||
TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge1(
|
||||
TFE_MonitoringBoolGauge1* gauge, const char* label1) {
|
||||
return static_cast<TFE_MonitoringBoolGaugeCell*>(
|
||||
static_cast<void*>(gauge->gauge->GetCell(label1)));
|
||||
}
|
||||
|
||||
TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(const char* name,
|
||||
TF_Status* status,
|
||||
const char* description,
|
||||
const char* label1,
|
||||
const char* label2) {
|
||||
auto* result =
|
||||
new TFE_MonitoringBoolGauge2({name, description, label1, label2});
|
||||
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
|
||||
if (!result->gauge->GetStatus().ok()) {
|
||||
delete result;
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void TFE_MonitoringDeleteBoolGauge2(TFE_MonitoringBoolGauge2* gauge) {
|
||||
delete gauge;
|
||||
}
|
||||
|
||||
TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge2(
|
||||
TFE_MonitoringBoolGauge2* gauge, const char* label1, const char* label2) {
|
||||
return static_cast<TFE_MonitoringBoolGaugeCell*>(
|
||||
static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
|
||||
}
|
||||
|
||||
void TFE_MonitoringSamplerCellAdd(TFE_MonitoringSamplerCell* cell,
|
||||
double value) {
|
||||
cell->cell.Add(value);
|
||||
}
|
||||
|
||||
void TFE_MonitoringSamplerCellValue(TFE_MonitoringSamplerCell* cell,
|
||||
TF_Buffer* buf) {
|
||||
string content;
|
||||
cell->cell.value().SerializeToString(&content);
|
||||
void* data = tensorflow::port::Malloc(content.length());
|
||||
content.copy(static_cast<char*>(data), content.length(), 0);
|
||||
buf->data = data;
|
||||
buf->length = content.length();
|
||||
buf->data_deallocator = [](void* data, size_t length) {
|
||||
tensorflow::port::Free(data);
|
||||
};
|
||||
}
|
||||
|
||||
TFE_MonitoringBuckets* TFE_MonitoringNewExponentialBuckets(double scale,
|
||||
double growth_factor,
|
||||
int bucket_count) {
|
||||
return new TFE_MonitoringBuckets([scale, growth_factor, bucket_count]() {
|
||||
return tensorflow::monitoring::Buckets::Exponential(scale, growth_factor,
|
||||
bucket_count);
|
||||
});
|
||||
}
|
||||
|
||||
void TFE_MonitoringDeleteBuckets(TFE_MonitoringBuckets* buckets) {
|
||||
delete buckets;
|
||||
}
|
||||
|
||||
TFE_MonitoringSampler0* TFE_MonitoringNewSampler0(
|
||||
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
|
||||
const char* description) {
|
||||
auto* result = new TFE_MonitoringSampler0(
|
||||
{name, buckets->create_buckets(), description});
|
||||
Set_TF_Status_from_Status(status, result->sampler->GetStatus());
|
||||
if (!result->sampler->GetStatus().ok()) {
|
||||
delete result;
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void TFE_MonitoringDeleteSampler0(TFE_MonitoringSampler0* sampler) {
|
||||
delete sampler;
|
||||
}
|
||||
|
||||
TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0(
|
||||
TFE_MonitoringSampler0* sampler) {
|
||||
return static_cast<TFE_MonitoringSamplerCell*>(
|
||||
static_cast<void*>(sampler->sampler->GetCell()));
|
||||
}
|
||||
|
||||
TFE_MonitoringSampler1* TFE_MonitoringNewSampler1(
|
||||
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
|
||||
const char* description, const char* label1) {
|
||||
auto* result = new TFE_MonitoringSampler1(
|
||||
{name, buckets->create_buckets(), description, label1});
|
||||
Set_TF_Status_from_Status(status, result->sampler->GetStatus());
|
||||
if (!result->sampler->GetStatus().ok()) {
|
||||
delete result;
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void TFE_MonitoringDeleteSampler1(TFE_MonitoringSampler1* sampler) {
|
||||
delete sampler;
|
||||
}
|
||||
|
||||
TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1(
|
||||
TFE_MonitoringSampler1* sampler, const char* label1) {
|
||||
return static_cast<TFE_MonitoringSamplerCell*>(
|
||||
static_cast<void*>(sampler->sampler->GetCell(label1)));
|
||||
}
|
||||
|
||||
TFE_MonitoringSampler2* TFE_MonitoringNewSampler2(
|
||||
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
|
||||
const char* description, const char* label1, const char* label2) {
|
||||
auto* result = new TFE_MonitoringSampler2(
|
||||
{name, buckets->create_buckets(), description, label1, label2});
|
||||
Set_TF_Status_from_Status(status, result->sampler->GetStatus());
|
||||
if (!result->sampler->GetStatus().ok()) {
|
||||
delete result;
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void TFE_MonitoringDeleteSampler2(TFE_MonitoringSampler2* sampler) {
|
||||
delete sampler;
|
||||
}
|
||||
|
||||
TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
|
||||
TFE_MonitoringSampler2* sampler, const char* label1, const char* label2) {
|
||||
return static_cast<TFE_MonitoringSamplerCell*>(
|
||||
static_cast<void*>(sampler->sampler->GetCell(label1, label2)));
|
||||
}
|
||||
|
@ -87,19 +87,7 @@ TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
|
||||
const char* service_addr, const char* logdir, const char* worker_list,
|
||||
bool include_dataset_ops, int duration_ms, int num_tracing_attempts);
|
||||
|
||||
// Set the value of a Gauge metric. If the metric with given name does not
|
||||
// exist, it will create a new Gauge metric. Right now it only supports type
|
||||
// int64, consider to add more type supports if needed.
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringSetGauge(const char* name,
|
||||
const char* label,
|
||||
int64_t value);
|
||||
|
||||
// Add the given value to a Sampler metric. If the metric with given name
|
||||
// does not exist, it will create a new Sampler metric.
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringAddSampler(const char* name,
|
||||
const char* label,
|
||||
double value);
|
||||
|
||||
// TODO(fishx): Move these monitoring APIs into a separate file.
|
||||
// -----------------------------------------------------------------------------
|
||||
// Monitoring Counter APIs.
|
||||
// These APIs de-templated monitoring Counter for swig.
|
||||
@ -149,6 +137,179 @@ TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter2(
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
|
||||
TFE_MonitoringCounter2* counter, const char* label1, const char* label2);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Monitoring Gauge APIs.
|
||||
// These APIs de-templated monitoring Gauge for swig.
|
||||
|
||||
typedef struct TFE_MonitoringIntGaugeCell TFE_MonitoringIntGaugeCell;
|
||||
|
||||
// Atomically set the value of the cell.
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringIntGaugeCellSet(
|
||||
TFE_MonitoringIntGaugeCell* cell, int64_t value);
|
||||
|
||||
// Retrieves the current value of the cell.
|
||||
TF_CAPI_EXPORT extern int64_t TFE_MonitoringIntGaugeCellValue(
|
||||
TFE_MonitoringIntGaugeCell* cell);
|
||||
|
||||
// APIs for Int Gauge without label.
|
||||
typedef struct TFE_MonitoringIntGauge0 TFE_MonitoringIntGauge0;
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(
|
||||
const char* name, TF_Status* out_status, const char* description);
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge0(
|
||||
TFE_MonitoringIntGauge0* gauge);
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
|
||||
TFE_MonitoringGetCellIntGauge0(TFE_MonitoringIntGauge0* gauge);
|
||||
|
||||
// APIs for Int Gauge with 1 label.
|
||||
typedef struct TFE_MonitoringIntGauge1 TFE_MonitoringIntGauge1;
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(
|
||||
const char* name, TF_Status* out_status, const char* description,
|
||||
const char* label1);
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge1(
|
||||
TFE_MonitoringIntGauge1* gauge);
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
|
||||
TFE_MonitoringGetCellIntGauge1(TFE_MonitoringIntGauge1* gauge,
|
||||
const char* label1);
|
||||
|
||||
// APIs for Int Gauge with 2 label.
|
||||
typedef struct TFE_MonitoringIntGauge2 TFE_MonitoringIntGauge2;
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(
|
||||
const char* name, TF_Status* out_status, const char* description,
|
||||
const char* label1, const char* label2);
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge2(
|
||||
TFE_MonitoringIntGauge2* gauge);
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
|
||||
TFE_MonitoringGetCellIntGauge2(TFE_MonitoringIntGauge2* gauge,
|
||||
const char* label1, const char* label2);
|
||||
|
||||
typedef struct TFE_MonitoringStringGaugeCell TFE_MonitoringStringGaugeCell;
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringStringGaugeCellSet(
|
||||
TFE_MonitoringStringGaugeCell* cell, const char* value);
|
||||
// Retrieves the string value and saves it in buffer.
|
||||
TF_CAPI_EXPORT extern const void TFE_MonitoringStringGaugeCellValue(
|
||||
TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf);
|
||||
|
||||
// APIs for String Gauge without label.
|
||||
typedef struct TFE_MonitoringStringGauge0 TFE_MonitoringStringGauge0;
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0(
|
||||
const char* name, TF_Status* out_status, const char* description);
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge0(
|
||||
TFE_MonitoringStringGauge0* gauge);
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
|
||||
TFE_MonitoringGetCellStringGauge0(TFE_MonitoringStringGauge0* gauge);
|
||||
|
||||
// APIs for String Gauge with 1 label.
|
||||
typedef struct TFE_MonitoringStringGauge1 TFE_MonitoringStringGauge1;
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1(
|
||||
const char* name, TF_Status* out_status, const char* description,
|
||||
const char* label1);
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge1(
|
||||
TFE_MonitoringStringGauge1* gauge);
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
|
||||
TFE_MonitoringGetCellStringGauge1(TFE_MonitoringStringGauge1* gauge,
|
||||
const char* label1);
|
||||
|
||||
// APIs for String Gauge with 2 label.
|
||||
typedef struct TFE_MonitoringStringGauge2 TFE_MonitoringStringGauge2;
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2(
|
||||
const char* name, TF_Status* out_status, const char* description,
|
||||
const char* label1, const char* label2);
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge2(
|
||||
TFE_MonitoringStringGauge2* gauge);
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
|
||||
TFE_MonitoringGetCellStringGauge2(TFE_MonitoringStringGauge2* gauge,
|
||||
const char* label1, const char* label2);
|
||||
|
||||
typedef struct TFE_MonitoringBoolGaugeCell TFE_MonitoringBoolGaugeCell;
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringBoolGaugeCellSet(
|
||||
TFE_MonitoringBoolGaugeCell* cell, bool value);
|
||||
TF_CAPI_EXPORT extern bool TFE_MonitoringBoolGaugeCellValue(
|
||||
TFE_MonitoringBoolGaugeCell* cell);
|
||||
|
||||
// APIs for Bool Gauge without label.
|
||||
typedef struct TFE_MonitoringBoolGauge0 TFE_MonitoringBoolGauge0;
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(
|
||||
const char* name, TF_Status* out_status, const char* description);
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge0(
|
||||
TFE_MonitoringBoolGauge0* gauge);
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
|
||||
TFE_MonitoringGetCellBoolGauge0(TFE_MonitoringBoolGauge0* gauge);
|
||||
|
||||
// APIs for Bool Gauge with 1 label.
|
||||
typedef struct TFE_MonitoringBoolGauge1 TFE_MonitoringBoolGauge1;
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(
|
||||
const char* name, TF_Status* out_status, const char* description,
|
||||
const char* label1);
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge1(
|
||||
TFE_MonitoringBoolGauge1* gauge);
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
|
||||
TFE_MonitoringGetCellBoolGauge1(TFE_MonitoringBoolGauge1* gauge,
|
||||
const char* label1);
|
||||
|
||||
// APIs for Bool Gauge with 2 label.
|
||||
typedef struct TFE_MonitoringBoolGauge2 TFE_MonitoringBoolGauge2;
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(
|
||||
const char* name, TF_Status* out_status, const char* description,
|
||||
const char* label1, const char* label2);
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge2(
|
||||
TFE_MonitoringBoolGauge2* gauge);
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
|
||||
TFE_MonitoringGetCellBoolGauge2(TFE_MonitoringBoolGauge2* gauge,
|
||||
const char* label1, const char* label2);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Monitoring Sampler APIs.
|
||||
// These APIs de-templated monitoring Sampler for swig.
|
||||
|
||||
typedef struct TFE_MonitoringSamplerCell TFE_MonitoringSamplerCell;
|
||||
|
||||
// Atomically add the value of the cell.
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellAdd(
|
||||
TFE_MonitoringSamplerCell* cell, double value);
|
||||
|
||||
// Retrieves the current value of the cell. The return value is a HistogramProto
|
||||
// saved in buffer.
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellValue(
|
||||
TFE_MonitoringSamplerCell* cell, TF_Buffer* buf);
|
||||
|
||||
// APIs for sampler buckets
|
||||
typedef struct TFE_MonitoringBuckets TFE_MonitoringBuckets;
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringBuckets*
|
||||
TFE_MonitoringNewExponentialBuckets(double scale, double growth_factor,
|
||||
int bucket_count);
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBuckets(
|
||||
TFE_MonitoringBuckets* buckets);
|
||||
|
||||
// APIs for Sampler without label.
|
||||
typedef struct TFE_MonitoringSampler0 TFE_MonitoringSampler0;
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringSampler0* TFE_MonitoringNewSampler0(
|
||||
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
|
||||
const char* description);
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler0(
|
||||
TFE_MonitoringSampler0* sampler);
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0(
|
||||
TFE_MonitoringSampler0* sampler);
|
||||
|
||||
// APIs for Sampler with 1 label.
|
||||
typedef struct TFE_MonitoringSampler1 TFE_MonitoringSampler1;
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringSampler1* TFE_MonitoringNewSampler1(
|
||||
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
|
||||
const char* description, const char* label1);
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler1(
|
||||
TFE_MonitoringSampler1* sampler);
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1(
|
||||
TFE_MonitoringSampler1* sampler, const char* label1);
|
||||
|
||||
// APIs for Sampler with 2 label.
|
||||
typedef struct TFE_MonitoringSampler2 TFE_MonitoringSampler2;
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringSampler2* TFE_MonitoringNewSampler2(
|
||||
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
|
||||
const char* description, const char* label1, const char* label2);
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2(
|
||||
TFE_MonitoringSampler2* sampler);
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
|
||||
TFE_MonitoringSampler2* sampler, const char* label1, const char* label2);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -131,23 +131,6 @@ TEST(CAPI, MultipleProfilerSession) {
|
||||
TFE_DeleteProfilerContext(profiler_context);
|
||||
}
|
||||
|
||||
TEST(CAPI, MonitoringSetGauge) {
|
||||
TFE_MonitoringSetGauge("test/gauge", "label", 1);
|
||||
auto* collection_registry = monitoring::CollectionRegistry::Default();
|
||||
monitoring::CollectionRegistry::CollectMetricsOptions options;
|
||||
std::unique_ptr<monitoring::CollectedMetrics> metrics =
|
||||
collection_registry->CollectMetrics(options);
|
||||
|
||||
EXPECT_EQ("test/gauge", metrics->point_set_map.at("test/gauge")->metric_name);
|
||||
EXPECT_EQ(1,
|
||||
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
|
||||
|
||||
TFE_MonitoringSetGauge("test/gauge", "label", 5);
|
||||
metrics = collection_registry->CollectMetrics(options);
|
||||
EXPECT_EQ(5,
|
||||
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
|
||||
}
|
||||
|
||||
TEST(CAPI, MonitoringCounter0) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
auto* counter =
|
||||
@ -200,8 +183,59 @@ TEST(CAPI, MonitoringCounterMultiple) {
|
||||
TFE_MonitoringDeleteCounter2(counter2);
|
||||
}
|
||||
|
||||
TEST(CAPI, MonitoringAddSampler) {
|
||||
TFE_MonitoringAddSampler("test/sampler", "label", 1.0);
|
||||
TEST(CAPI, MonitoringGauge0) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
auto* gauge = TFE_MonitoringNewIntGauge0("test/gauge", status, "test");
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
auto* cell = TFE_MonitoringGetCellIntGauge0(gauge);
|
||||
TFE_MonitoringIntGaugeCellSet(cell, 1);
|
||||
EXPECT_EQ(TFE_MonitoringIntGaugeCellValue(cell), 1);
|
||||
auto* collection_registry = monitoring::CollectionRegistry::Default();
|
||||
monitoring::CollectionRegistry::CollectMetricsOptions options;
|
||||
std::unique_ptr<monitoring::CollectedMetrics> metrics =
|
||||
collection_registry->CollectMetrics(options);
|
||||
|
||||
EXPECT_EQ("test/gauge", metrics->point_set_map.at("test/gauge")->metric_name);
|
||||
EXPECT_EQ(1,
|
||||
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
|
||||
|
||||
TFE_MonitoringIntGaugeCellSet(cell, 5);
|
||||
metrics = collection_registry->CollectMetrics(options);
|
||||
EXPECT_EQ(5,
|
||||
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
TEST(CAPI, MonitoringMultipleGauge) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
auto* gauge1 =
|
||||
TFE_MonitoringNewBoolGauge1("test/gauge1", status, "test", "label1");
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
auto* cell1 = TFE_MonitoringGetCellBoolGauge1(gauge1, "foo");
|
||||
TFE_MonitoringBoolGaugeCellSet(cell1, true);
|
||||
EXPECT_TRUE(TFE_MonitoringBoolGaugeCellValue(cell1));
|
||||
|
||||
auto* gauge2 = TFE_MonitoringNewStringGauge2("test/gauge2", status, "test",
|
||||
"label1", "label2");
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
auto* cell2 = TFE_MonitoringGetCellStringGauge2(gauge2, "foo", "bar");
|
||||
TFE_MonitoringStringGaugeCellSet(cell2, "str");
|
||||
auto* buf = new TF_Buffer;
|
||||
TFE_MonitoringStringGaugeCellValue(cell2, buf);
|
||||
string data(static_cast<const char*>(buf->data), buf->length);
|
||||
delete buf;
|
||||
EXPECT_EQ(data, "str");
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
TEST(CAPI, MonitoringSampler0) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
auto* buckets = TFE_MonitoringNewExponentialBuckets(1.0, 2.0, 2);
|
||||
auto* sampler =
|
||||
TFE_MonitoringNewSampler0("test/sampler", buckets, status, "test");
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
auto* cell = TFE_MonitoringGetCellSampler0(sampler);
|
||||
TFE_MonitoringSamplerCellAdd(cell, 1.0);
|
||||
auto* collection_registry = monitoring::CollectionRegistry::Default();
|
||||
monitoring::CollectionRegistry::CollectMetricsOptions options;
|
||||
std::unique_ptr<monitoring::CollectedMetrics> metrics =
|
||||
@ -213,11 +247,48 @@ TEST(CAPI, MonitoringAddSampler) {
|
||||
->points.at(0)
|
||||
->histogram_value.sum());
|
||||
|
||||
TFE_MonitoringAddSampler("test/sampler", "label", 5.0);
|
||||
TFE_MonitoringSamplerCellAdd(cell, 5.0);
|
||||
metrics = collection_registry->CollectMetrics(options);
|
||||
EXPECT_EQ(6.0, metrics->point_set_map.at("test/sampler")
|
||||
->points.at(0)
|
||||
->histogram_value.sum());
|
||||
TFE_MonitoringDeleteBuckets(buckets);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
TEST(CAPI, MonitoringMultipleSampler) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
auto* buckets = TFE_MonitoringNewExponentialBuckets(1.0, 2.0, 2);
|
||||
auto* sampler1 = TFE_MonitoringNewSampler1("test/sampler1", buckets, status,
|
||||
"test", "label1");
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
auto* cell1 = TFE_MonitoringGetCellSampler1(sampler1, "foo");
|
||||
TFE_MonitoringSamplerCellAdd(cell1, 1.0);
|
||||
TFE_MonitoringSamplerCellAdd(cell1, 2.0);
|
||||
TF_Buffer* result1 = TF_NewBuffer();
|
||||
TFE_MonitoringSamplerCellValue(cell1, result1);
|
||||
tensorflow::HistogramProto hitogram1;
|
||||
EXPECT_TRUE(hitogram1.ParseFromString(
|
||||
{reinterpret_cast<const char*>(result1->data), result1->length}));
|
||||
EXPECT_EQ(hitogram1.sum(), 3.0);
|
||||
delete result1;
|
||||
|
||||
auto* sampler2 = TFE_MonitoringNewSampler2("test/sampler2", buckets, status,
|
||||
"test", "label1", "label2");
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
auto* cell2 = TFE_MonitoringGetCellSampler2(sampler2, "foo", "bar");
|
||||
TFE_MonitoringSamplerCellAdd(cell2, 2.0);
|
||||
TFE_MonitoringSamplerCellAdd(cell2, 3.0);
|
||||
TF_Buffer* result2 = TF_NewBuffer();
|
||||
TFE_MonitoringSamplerCellValue(cell2, result2);
|
||||
tensorflow::HistogramProto hitogram2;
|
||||
EXPECT_TRUE(hitogram2.ParseFromString(
|
||||
{reinterpret_cast<const char*>(result2->data), result2->length}));
|
||||
EXPECT_EQ(hitogram2.sum(), 5.0);
|
||||
delete result2;
|
||||
|
||||
TFE_MonitoringDeleteBuckets(buckets);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -36,20 +36,14 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
|
||||
#include "tensorflow/core/distributed_runtime/remote_device.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/gtl/stl_util.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/profiler/lib/profiler_session.h"
|
||||
@ -68,13 +62,16 @@ struct TFE_Context {
|
||||
const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
|
||||
tensorflow::Rendezvous* rendezvous,
|
||||
const tensorflow::CustomKernelCreator* custom_kernel_creator)
|
||||
: context(opts,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
default_policy),
|
||||
async, device_mgr, device_mgr_owned, rendezvous,
|
||||
custom_kernel_creator) {}
|
||||
: context(new tensorflow::EagerContext(
|
||||
opts,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
default_policy),
|
||||
async, device_mgr, device_mgr_owned, rendezvous,
|
||||
custom_kernel_creator)) {}
|
||||
|
||||
tensorflow::EagerContext context;
|
||||
~TFE_Context() { context->Unref(); }
|
||||
|
||||
tensorflow::EagerContext* context;
|
||||
};
|
||||
|
||||
struct TFE_TensorHandle {
|
||||
@ -114,7 +111,7 @@ struct TFE_Op {
|
||||
TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
|
||||
const tensorflow::AttrTypeMap* t,
|
||||
TFE_OpInferenceContext* inference_ctx)
|
||||
: operation(&ctx->context, op, is_function, t),
|
||||
: operation(ctx->context, op, is_function, t),
|
||||
inference_ctx(inference_ctx) {}
|
||||
|
||||
tensorflow::EagerOperation operation;
|
||||
@ -159,6 +156,98 @@ struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringIntGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<tensorflow::int64> cell;
|
||||
};
|
||||
struct TFE_MonitoringStringGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<tensorflow::string> cell;
|
||||
};
|
||||
struct TFE_MonitoringBoolGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<bool> cell;
|
||||
};
|
||||
|
||||
template <typename ValueType, int NumLabels>
|
||||
struct TFE_MonitoringGauge {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringGauge(const char* name, const char* description,
|
||||
LabelDesc&&... label) {
|
||||
gauge = absl::WrapUnique(
|
||||
tensorflow::monitoring::Gauge<ValueType, NumLabels>::New(
|
||||
name, description, label...));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Gauge<ValueType, NumLabels>> gauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge<tensorflow::int64, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge<tensorflow::int64, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge<tensorflow::int64, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge<tensorflow::string, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge<tensorflow::string, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge<tensorflow::string, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge<bool, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge<bool, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringBuckets {
|
||||
TFE_MonitoringBuckets(
|
||||
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
|
||||
fn) {
|
||||
create_buckets = fn;
|
||||
}
|
||||
|
||||
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
|
||||
create_buckets;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringSamplerCell {
|
||||
tensorflow::monitoring::SamplerCell cell;
|
||||
};
|
||||
|
||||
template <int NumLabels>
|
||||
struct TFE_MonitoringSampler {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringSampler(
|
||||
const char* name,
|
||||
std::unique_ptr<tensorflow::monitoring::Buckets> buckets,
|
||||
const char* description, LabelDesc&&... label) {
|
||||
sampler = absl::WrapUnique(tensorflow::monitoring::Sampler<NumLabels>::New(
|
||||
{name, description, label...}, std::move(buckets)));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Sampler<NumLabels>> sampler;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
// Set an AttrValue on the op. Doesn't handle the list types.
|
||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||
|
@ -14,10 +14,11 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
@ -297,6 +298,61 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
|
||||
TestRemoteExecuteSilentCopies(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteDeleteTensorAfterContext(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
|
||||
TFE_DEVICE_PLACEMENT_EXPLICIT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
auto* h0_task1 =
|
||||
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_DeleteTensorHandle(h0_task0);
|
||||
|
||||
TFE_ContextAsyncWait(ctx, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
// Delete tensors after context is deleted.
|
||||
TFE_DeleteTensorHandle(h0_task1);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(nareshmodi): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteDeleteTensorAfterContext) {
|
||||
TestRemoteExecuteDeleteTensorAfterContext(false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteDeleteTensorAfterContextAsync) {
|
||||
TestRemoteExecuteDeleteTensorAfterContext(true);
|
||||
}
|
||||
|
||||
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
|
||||
const std::vector<float>& expected_values) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
@ -1225,6 +1281,8 @@ TEST(CAPI, Function_ident_CPU) {
|
||||
TF_DeleteTensor(r);
|
||||
TFE_DeleteTensorHandle(result[0]);
|
||||
}
|
||||
TFE_ContextRemoveFunction(ctx, "ident", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
@ -1295,6 +1353,8 @@ TEST(CAPI, Function_ident_XLA_CPU) {
|
||||
TF_DeleteTensor(r);
|
||||
TFE_DeleteTensorHandle(result[0]);
|
||||
}
|
||||
TFE_ContextRemoveFunction(ctx, "ident", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
@ -1371,6 +1431,8 @@ void FunctionDefAndExecute(bool async) {
|
||||
EXPECT_EQ(10, product[1]);
|
||||
EXPECT_EQ(15, product[2]);
|
||||
EXPECT_EQ(22, product[3]);
|
||||
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
@ -1412,6 +1474,8 @@ void BM_ExecuteFunction(int iters, int async) {
|
||||
tensorflow::testing::StopTiming();
|
||||
TFE_DeleteTensorHandle(m);
|
||||
TFE_DeleteTensorHandle(retval[0]);
|
||||
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
|
122
tensorflow/c/experimental/BUILD
Normal file
122
tensorflow/c/experimental/BUILD
Normal file
@ -0,0 +1,122 @@
|
||||
# Description:
|
||||
# Experimental C APIs for TensorFlow.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_copts",
|
||||
"tf_cuda_library",
|
||||
)
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
|
||||
|
||||
tf_cuda_library(
|
||||
name = "rendezvous_internal",
|
||||
srcs = [
|
||||
"rendezvous.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"rendezvous.h",
|
||||
"rendezvous_internal.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//tensorflow/c:__subpackages__"],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "rendezvous",
|
||||
hdrs = [
|
||||
"rendezvous.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":rendezvous_internal",
|
||||
"//tensorflow/c:c_api",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "network_internal",
|
||||
srcs = [
|
||||
"network.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"network.h",
|
||||
"network_internal.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//tensorflow/c:__subpackages__"],
|
||||
deps = [
|
||||
":rendezvous_internal",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "network",
|
||||
hdrs = [
|
||||
"network.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":network_internal",
|
||||
":rendezvous",
|
||||
"//tensorflow/c:c_api",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tests
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "network_test",
|
||||
size = "medium",
|
||||
srcs = ["network_test.cc"],
|
||||
tags = ["noasan"],
|
||||
# We must ensure that the dependencies can be dynamically linked since
|
||||
# the shared library must be able to use core:framework.
|
||||
# linkstatic = tf_kernel_tests_linkstatic(),
|
||||
deps = [
|
||||
":network",
|
||||
":network_internal",
|
||||
":rendezvous",
|
||||
":rendezvous_internal",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:env",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
"//tensorflow/core/distributed_runtime:session_mgr",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core/distributed_runtime:worker_session",
|
||||
"//tensorflow/core/distributed_runtime/rpc:async_service_interface",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_absl//absl/time",
|
||||
],
|
||||
)
|
166
tensorflow/c/experimental/network.cc
Normal file
166
tensorflow/c/experimental/network.cc
Normal file
@ -0,0 +1,166 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/network.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/experimental/network_internal.h"
|
||||
#include "tensorflow/c/experimental/rendezvous_internal.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
using tensorflow::ServerFactory;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
/* static */ Status CGrpcServer::Create(
|
||||
const ServerDef& server_def,
|
||||
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
|
||||
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||
void (*delete_function)(void*),
|
||||
TF_RemoteRendezvousBuilder* rendezvous_builder,
|
||||
std::unique_ptr<ServerInterface>* out_server) {
|
||||
auto* grpc_server = new CGrpcServer(server_def, start_function, stop_function,
|
||||
join_function, delete_function);
|
||||
|
||||
GrpcServerOptions options;
|
||||
options.rendezvous_mgr_func = [rendezvous_builder](const WorkerEnv* env) {
|
||||
return new CRendezvousMgr(env, rendezvous_builder);
|
||||
};
|
||||
TF_RETURN_IF_ERROR(grpc_server->Init(options));
|
||||
TF_Status* tf_status = TF_NewStatus();
|
||||
grpc_server->SetContext(init_function(
|
||||
reinterpret_cast<const TF_GrpcServer*>(grpc_server), tf_status));
|
||||
TF_RETURN_IF_ERROR(tf_status->status);
|
||||
TF_DeleteStatus(tf_status);
|
||||
|
||||
out_server->reset(grpc_server);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CGrpcServer::Start() {
|
||||
Status status = GrpcServer::Start();
|
||||
TF_Status* tf_status = TF_NewStatus();
|
||||
(*start_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
|
||||
tf_status);
|
||||
status.Update(tf_status->status);
|
||||
TF_DeleteStatus(tf_status);
|
||||
return status;
|
||||
}
|
||||
|
||||
Status CGrpcServer::Stop() {
|
||||
Status status = GrpcServer::Stop();
|
||||
TF_Status* tf_status = TF_NewStatus();
|
||||
(*stop_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
|
||||
tf_status);
|
||||
status.Update(tf_status->status);
|
||||
TF_DeleteStatus(tf_status);
|
||||
return status;
|
||||
}
|
||||
|
||||
Status CGrpcServer::Join() {
|
||||
Status status = GrpcServer::Join();
|
||||
TF_Status* tf_status = TF_NewStatus();
|
||||
(*join_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
|
||||
tf_status);
|
||||
status.Update(tf_status->status);
|
||||
TF_DeleteStatus(tf_status);
|
||||
return status;
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Factory that creates CGrpcServer instances.
|
||||
class CServerFactory : public ServerFactory {
|
||||
public:
|
||||
CServerFactory(bool (*accept_function)(const char*),
|
||||
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
|
||||
void (*start_function)(const TF_GrpcServer*, void*,
|
||||
TF_Status*),
|
||||
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||
void (*delete_function)(void*),
|
||||
TF_RemoteRendezvousBuilder* rendezvous_builder)
|
||||
: accept_function_(accept_function),
|
||||
init_function_(init_function),
|
||||
start_function_(start_function),
|
||||
stop_function_(stop_function),
|
||||
join_function_(join_function),
|
||||
delete_function_(delete_function),
|
||||
rendezvous_builder_(rendezvous_builder) {}
|
||||
|
||||
Status NewServer(const ServerDef& server_def,
|
||||
std::unique_ptr<ServerInterface>* out_server) override {
|
||||
TF_RETURN_IF_ERROR(CGrpcServer::Create(
|
||||
server_def, init_function_, start_function_, stop_function_,
|
||||
join_function_, delete_function_, rendezvous_builder_, out_server));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns true if and only if this factory can create a server
|
||||
// based on the given `server_def`.
|
||||
bool AcceptsOptions(const ServerDef& server_def) override {
|
||||
return (*accept_function_)(server_def.protocol().c_str());
|
||||
}
|
||||
|
||||
private:
|
||||
bool (*accept_function_)(const char* protocol);
|
||||
void* (*init_function_)(const TF_GrpcServer*, TF_Status*);
|
||||
void (*start_function_)(const TF_GrpcServer*, void*, TF_Status*);
|
||||
void (*stop_function_)(const TF_GrpcServer*, void*, TF_Status*);
|
||||
void (*join_function_)(const TF_GrpcServer*, void*, TF_Status*);
|
||||
void (*delete_function_)(void*);
|
||||
TF_RemoteRendezvousBuilder* rendezvous_builder_;
|
||||
};
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
// Server factory representation to use in C API.
|
||||
// Holds CServerFactory pointer.
|
||||
struct TF_GrpcServerFactory {
|
||||
::tensorflow::CServerFactory* factory;
|
||||
};
|
||||
|
||||
TF_GrpcServerFactory* TF_NewGrpcServerFactory(
|
||||
bool (*accept_function)(const char*),
|
||||
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
|
||||
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||
void (*delete_function)(void*),
|
||||
TF_RemoteRendezvousBuilder* rendezvous_builder) {
|
||||
TF_GrpcServerFactory* server_factory = new TF_GrpcServerFactory;
|
||||
server_factory->factory = new ::tensorflow::CServerFactory(
|
||||
accept_function, init_function, start_function, stop_function,
|
||||
join_function, delete_function, rendezvous_builder);
|
||||
return server_factory;
|
||||
}
|
||||
|
||||
void TF_DeleteGrpcServerFactory(TF_GrpcServerFactory* server_factory) {
|
||||
DCHECK_NE(server_factory, nullptr);
|
||||
delete server_factory;
|
||||
}
|
||||
|
||||
void TF_RegisterGrpcServerFactory(const char* server_type,
|
||||
TF_GrpcServerFactory* server_factory) {
|
||||
ServerFactory::Register(server_type, server_factory->factory);
|
||||
}
|
97
tensorflow/c/experimental/network.h
Normal file
97
tensorflow/c/experimental/network.h
Normal file
@ -0,0 +1,97 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/experimental/rendezvous.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// C API for TensorFlow Networking.
|
||||
// NOTE: This API is unstable and almost certainly will change in the near
|
||||
// future.
|
||||
//
|
||||
// Users wishing to register a custom GrpcServer should call
|
||||
// TF_NewServerFactory and then TF_RegisterGrpcServerFactory.
|
||||
//
|
||||
// Example:
|
||||
// ```c++
|
||||
// auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder(
|
||||
// rendezvous_init_function,
|
||||
// receive_from_remote_async_function,
|
||||
// rendezvous_delete_function);
|
||||
//
|
||||
// TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory(
|
||||
// accept_function,
|
||||
// init_function,
|
||||
// start_function,
|
||||
// stop_function,
|
||||
// join_function,
|
||||
// delete_function,
|
||||
// rendezvous_builder);
|
||||
// TF_RegisterGrpcServerFactory("customfactory", factory);
|
||||
// ...
|
||||
// TF_DeleteGrpcServerFactory(factory);
|
||||
// ```
|
||||
|
||||
typedef struct TF_GrpcServerFactory TF_GrpcServerFactory;
|
||||
typedef struct TF_GrpcServerOptions TF_GrpcServerOptions;
|
||||
typedef struct TF_GrpcServer TF_GrpcServer;
|
||||
typedef struct TF_ServerContext {
|
||||
TF_GrpcServer* const server;
|
||||
void* context;
|
||||
} TF_ServerContext;
|
||||
|
||||
// Creates a new TF_GrpcServerFactory instance. Caller takes ownership
|
||||
// of TF_GrpcServerFactory instance and should deallocate it by calling
|
||||
// TF_GrpcDeleteServerFactory.
|
||||
// accept_function should return true if this ServerFactory can create
|
||||
// server instances for the given protocol name (for e.g. grpc+verbs).
|
||||
// GRPC servers created by this factory will call provided
|
||||
// init_function, start_function, stop_function, join_function and
|
||||
// delete_function.
|
||||
//
|
||||
// Note that clean shutdown is currently not implemented for GrpcServer.
|
||||
// So, stop_function will never be called now but may be in the future
|
||||
// when stop mechanism is supported.
|
||||
TF_CAPI_EXPORT extern TF_GrpcServerFactory* TF_NewGrpcServerFactory(
|
||||
bool (*accept_function)(const char*),
|
||||
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
|
||||
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||
void (*delete_function)(void*),
|
||||
TF_RemoteRendezvousBuilder* rendezvous_builder);
|
||||
|
||||
// Deletes TF_GrpcServerFactory instances.
|
||||
// Note that this function only deletes TF_GrpcServerFactory wrapper.
|
||||
// Actual underlying server factory would not be deleted and will
|
||||
// remain registered.
|
||||
TF_CAPI_EXPORT extern void TF_DeleteGrpcServerFactory(
|
||||
TF_GrpcServerFactory* server_factory);
|
||||
|
||||
// Registers provided server_factory for the given server_type.
|
||||
// server_type must be unique to the server factory.
|
||||
TF_CAPI_EXPORT extern void TF_RegisterGrpcServerFactory(
|
||||
const char* server_type, TF_GrpcServerFactory* server_factory);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_
|
77
tensorflow/c/experimental/network_internal.h
Normal file
77
tensorflow/c/experimental/network_internal.h
Normal file
@ -0,0 +1,77 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/experimental/network.h"
|
||||
#include "tensorflow/c/experimental/rendezvous.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// GrpcServer implementation that forwards calls to callbacks.
|
||||
class CGrpcServer : public GrpcServer {
|
||||
protected:
|
||||
CGrpcServer(const ServerDef& server_def,
|
||||
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||
void (*delete_function)(void*))
|
||||
: GrpcServer(server_def, ::tensorflow::Env::Default()),
|
||||
start_function_(start_function),
|
||||
stop_function_(stop_function),
|
||||
join_function_(join_function),
|
||||
delete_function_(delete_function),
|
||||
context_(nullptr) {}
|
||||
|
||||
public:
|
||||
static Status Create(
|
||||
const ServerDef& server_def,
|
||||
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
|
||||
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||
void (*delete_function)(void*),
|
||||
TF_RemoteRendezvousBuilder* rendezvous_builder,
|
||||
std::unique_ptr<ServerInterface>* out_server);
|
||||
|
||||
Status Start() override;
|
||||
Status Stop() override;
|
||||
Status Join() override;
|
||||
|
||||
~CGrpcServer() override { delete_function_(context_); }
|
||||
|
||||
protected:
|
||||
void SetContext(void* context) { context_ = context; }
|
||||
|
||||
private:
|
||||
void (*start_function_)(const TF_GrpcServer*, void*, TF_Status*);
|
||||
void (*stop_function_)(const TF_GrpcServer*, void*, TF_Status*);
|
||||
void (*join_function_)(const TF_GrpcServer*, void*, TF_Status*);
|
||||
void (*delete_function_)(void*);
|
||||
void* context_;
|
||||
|
||||
friend class NetworksTest;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_
|
256
tensorflow/c/experimental/network_test.cc
Normal file
256
tensorflow/c/experimental/network_test.cc
Normal file
@ -0,0 +1,256 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/network.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/synchronization/notification.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/experimental/network_internal.h"
|
||||
#include "tensorflow/c/experimental/rendezvous.h"
|
||||
#include "tensorflow/c/experimental/rendezvous_internal.h"
|
||||
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/session_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_session.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
bool accept_functionA(const char* protocol_name) {
|
||||
return strcmp(protocol_name, "grpc+A") == 0;
|
||||
}
|
||||
|
||||
bool accept_functionB(const char* protocol_name) {
|
||||
return strcmp(protocol_name, "grpc+B") == 0;
|
||||
}
|
||||
|
||||
struct SomeServerData {
|
||||
bool server_started = false;
|
||||
};
|
||||
|
||||
struct SomeRendezvousData {
|
||||
int test = 0;
|
||||
};
|
||||
|
||||
void* init_function(const TF_GrpcServer* server, TF_Status* status) {
|
||||
SomeServerData* server_data = new SomeServerData();
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
return server_data;
|
||||
}
|
||||
|
||||
void start_function(const TF_GrpcServer* server, void* context,
|
||||
TF_Status* status) {
|
||||
auto* server_data = static_cast<SomeServerData*>(context);
|
||||
server_data->server_started = true;
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
void stop_function(const TF_GrpcServer* server, void* context,
|
||||
TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
void join_function(const TF_GrpcServer* server, void* context,
|
||||
TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
void delete_function(void* context) {
|
||||
auto* server_data = static_cast<SomeServerData*>(context);
|
||||
delete server_data;
|
||||
}
|
||||
|
||||
void* rendezvous_init_function(void* server_context) {
|
||||
return new SomeRendezvousData();
|
||||
}
|
||||
|
||||
void Deallocator(void* data, size_t, void* arg) {
|
||||
tensorflow::cpu_allocator()->DeallocateRaw(data);
|
||||
*reinterpret_cast<bool*>(arg) = true;
|
||||
}
|
||||
|
||||
void receive_from_remote_async_function(TF_ParsedKey* key,
|
||||
TF_RendezvousArgs* args,
|
||||
TF_RendezvousDoneCallback* callback,
|
||||
void* context) {
|
||||
// Create dummy tensor
|
||||
const int num_bytes = 6 * sizeof(float);
|
||||
float* values =
|
||||
reinterpret_cast<float*>(tensorflow::cpu_allocator()->AllocateRaw(
|
||||
EIGEN_MAX_ALIGN_BYTES, num_bytes));
|
||||
int64_t dims[] = {2, 3};
|
||||
bool deallocator_called = false;
|
||||
auto* tensor = TF_NewTensor(TF_FLOAT, dims, 2, values, num_bytes,
|
||||
&Deallocator, &deallocator_called);
|
||||
callback->tensor = tensor;
|
||||
auto* tf_status = TF_NewStatus();
|
||||
TF_SetStatus(tf_status, TF_OK, "");
|
||||
callback->status = tf_status;
|
||||
TF_RendezvousDone(callback);
|
||||
TF_DeleteStatus(tf_status);
|
||||
TF_DeleteTensor(tensor);
|
||||
}
|
||||
|
||||
void rendezvous_delete_function(void* context) {
|
||||
auto* rendezvous_data = static_cast<SomeRendezvousData*>(context);
|
||||
delete rendezvous_data;
|
||||
}
|
||||
|
||||
tensorflow::ServerDef GetServerDef(const string& protocol,
|
||||
const string& job_name, int num_tasks) {
|
||||
tensorflow::ServerDef server_def;
|
||||
server_def.set_protocol(protocol);
|
||||
server_def.set_job_name(job_name);
|
||||
server_def.set_task_index(0);
|
||||
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
||||
tensorflow::JobDef* job_def = cluster_def->add_job();
|
||||
job_def->set_name(job_name);
|
||||
for (int i = 0; i < num_tasks; i++) {
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{i, tensorflow::strings::StrCat("localhost:", port)});
|
||||
}
|
||||
return server_def;
|
||||
}
|
||||
|
||||
class NetworksTest : public ::testing::Test {
|
||||
public:
|
||||
~NetworksTest() override {}
|
||||
|
||||
SomeServerData* GetServerData(CGrpcServer* server) {
|
||||
EXPECT_NE(server->context_, nullptr);
|
||||
return static_cast<SomeServerData*>(server->context_);
|
||||
}
|
||||
};
|
||||
|
||||
Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
|
||||
const string& receiver, const string& name) {
|
||||
Rendezvous::ParsedKey result;
|
||||
CHECK(
|
||||
Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
|
||||
name, FrameAndIter(0, 0)),
|
||||
&result)
|
||||
.ok());
|
||||
return result;
|
||||
}
|
||||
|
||||
void InitializeRendezvous(GrpcServer* grpc_server, ServerDef* server_def,
|
||||
RemoteRendezvous* remote_rendezvous) {
|
||||
int rendezvous_id = 0;
|
||||
auto session_name = tensorflow::strings::StrCat("test_", rendezvous_id);
|
||||
TF_EXPECT_OK(grpc_server->worker_env()->session_mgr->CreateSession(
|
||||
session_name, *server_def, true));
|
||||
|
||||
std::shared_ptr<tensorflow::WorkerSession> worker_session;
|
||||
TF_EXPECT_OK(grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
|
||||
session_name, &worker_session));
|
||||
|
||||
TF_EXPECT_OK(remote_rendezvous->Initialize(worker_session.get()));
|
||||
}
|
||||
|
||||
TEST_F(NetworksTest, TestStartServer) {
|
||||
auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder(
|
||||
rendezvous_init_function, receive_from_remote_async_function,
|
||||
rendezvous_delete_function);
|
||||
|
||||
TF_Status* tf_status = TF_NewStatus();
|
||||
TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory(
|
||||
accept_functionA, init_function, start_function, stop_function,
|
||||
join_function, delete_function, rendezvous_builder);
|
||||
TF_RegisterGrpcServerFactory("testfactoryA", factory);
|
||||
|
||||
ServerDef server_def = GetServerDef("grpc+A", "localhost", 1);
|
||||
std::unique_ptr<ServerInterface> server;
|
||||
TF_EXPECT_OK(NewServer(server_def, &server));
|
||||
auto* grpc_server = static_cast<CGrpcServer*>(server.get());
|
||||
auto* server_data = GetServerData(grpc_server);
|
||||
ASSERT_FALSE(server_data->server_started);
|
||||
|
||||
TF_EXPECT_OK(server->Start());
|
||||
ASSERT_TRUE(server_data->server_started);
|
||||
|
||||
TF_DeleteStatus(tf_status);
|
||||
TF_DeleteGrpcServerFactory(factory);
|
||||
TF_DeleteRemoteRendezvousBuilder(rendezvous_builder);
|
||||
// TODO(annarev): find a clean way to shutdown server.
|
||||
server.release();
|
||||
}
|
||||
|
||||
TEST_F(NetworksTest, TestReceiveData) {
|
||||
auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder(
|
||||
rendezvous_init_function, receive_from_remote_async_function,
|
||||
rendezvous_delete_function);
|
||||
|
||||
TF_Status* tf_status = TF_NewStatus();
|
||||
TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory(
|
||||
accept_functionB, init_function, start_function, stop_function,
|
||||
join_function, delete_function, rendezvous_builder);
|
||||
TF_RegisterGrpcServerFactory("testfactoryB", factory);
|
||||
|
||||
ServerDef server_def = GetServerDef("grpc+B", "localhost", 1);
|
||||
std::unique_ptr<ServerInterface> server;
|
||||
TF_EXPECT_OK(NewServer(server_def, &server));
|
||||
auto* grpc_server = static_cast<CGrpcServer*>(server.get());
|
||||
|
||||
TF_EXPECT_OK(server->Start());
|
||||
auto* rendezvous_mgr = grpc_server->worker_env()->rendezvous_mgr;
|
||||
auto* remote_rendezvous = rendezvous_mgr->Find(0);
|
||||
|
||||
auto key = Key("/job:localhost/replica:1/task:2/device:CPU:0", 1,
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0", "test");
|
||||
Rendezvous::Args args;
|
||||
bool done_callback_called = false;
|
||||
auto* done_callback_called_ptr = &done_callback_called;
|
||||
absl::Notification notification;
|
||||
auto* notification_ptr = ¬ification;
|
||||
|
||||
InitializeRendezvous(grpc_server, &server_def, remote_rendezvous);
|
||||
remote_rendezvous->RecvAsync(
|
||||
key, args,
|
||||
[done_callback_called_ptr, notification_ptr](
|
||||
const Status&, const Rendezvous::Args&, const Rendezvous::Args&,
|
||||
const Tensor&, const bool) mutable {
|
||||
*done_callback_called_ptr = true;
|
||||
notification_ptr->Notify();
|
||||
});
|
||||
notification.WaitForNotificationWithTimeout(absl::Seconds(10));
|
||||
ASSERT_EQ(done_callback_called, true);
|
||||
|
||||
TF_DeleteStatus(tf_status);
|
||||
TF_DeleteGrpcServerFactory(factory);
|
||||
TF_DeleteRemoteRendezvousBuilder(rendezvous_builder);
|
||||
// Server doesn't have a clean shutdown.
|
||||
server.release();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
124
tensorflow/c/experimental/rendezvous.cc
Normal file
124
tensorflow/c/experimental/rendezvous.cc
Normal file
@ -0,0 +1,124 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/rendezvous.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/experimental/rendezvous_internal.h"
|
||||
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
CRemoteRendezvous::CRemoteRendezvous(const WorkerEnv* env, int64 step_id,
|
||||
void (*receive_from_remote_async_function)(
|
||||
TF_ParsedKey*, TF_RendezvousArgs*,
|
||||
TF_RendezvousDoneCallback*,
|
||||
void* context),
|
||||
void (*delete_function)(void* context),
|
||||
void* server_context)
|
||||
: BaseRemoteRendezvous(env, step_id),
|
||||
receive_from_remote_async_function_(receive_from_remote_async_function),
|
||||
delete_function_(delete_function),
|
||||
context_(nullptr) {}
|
||||
|
||||
void CRemoteRendezvous::RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& args,
|
||||
DoneCallback done) {
|
||||
TF_ParsedKey key;
|
||||
key.src_device = parsed.src_device.data();
|
||||
key.src_device_len = parsed.src_device.size();
|
||||
key.dst_device = parsed.dst_device.data();
|
||||
key.dst_device_len = parsed.dst_device.size();
|
||||
key.full_key = parsed.FullKey().data();
|
||||
key.full_key_len = parsed.FullKey().size();
|
||||
|
||||
TF_DeviceContext* device_context = new TF_DeviceContext();
|
||||
device_context->context = args.device_context;
|
||||
|
||||
TF_AllocatorAttributes* alloc_attrs = new TF_AllocatorAttributes();
|
||||
alloc_attrs->value = args.alloc_attrs.value;
|
||||
alloc_attrs->scope_id = args.alloc_attrs.scope_id;
|
||||
alloc_attrs->on_host = args.alloc_attrs.on_host();
|
||||
alloc_attrs->nic_compatible = args.alloc_attrs.nic_compatible();
|
||||
|
||||
TF_RendezvousArgs* cargs = new TF_RendezvousArgs();
|
||||
cargs->device_context = device_context;
|
||||
cargs->alloc_attrs = alloc_attrs;
|
||||
|
||||
TF_RendezvousDoneCallback* done_callback = new TF_RendezvousDoneCallback();
|
||||
done_callback->done_callback = done;
|
||||
done_callback->recv_args = cargs;
|
||||
|
||||
receive_from_remote_async_function_(&key, cargs, done_callback, context_);
|
||||
}
|
||||
|
||||
CRemoteRendezvous::~CRemoteRendezvous() { delete_function_(context_); }
|
||||
} // namespace tensorflow
|
||||
|
||||
TF_RemoteRendezvousBuilder* TF_NewRemoteRendezvousBuilder(
|
||||
void* (*init_function)(void* server_context),
|
||||
void (*receive_from_remote_async_function)(TF_ParsedKey*,
|
||||
TF_RendezvousArgs*,
|
||||
TF_RendezvousDoneCallback*,
|
||||
void* context),
|
||||
void (*delete_function)(void* context)) {
|
||||
TF_RemoteRendezvousBuilder* builder = new TF_RemoteRendezvousBuilder();
|
||||
builder->init_function = init_function;
|
||||
builder->delete_function = delete_function;
|
||||
builder->receive_from_remote_async_function =
|
||||
receive_from_remote_async_function;
|
||||
return builder;
|
||||
}
|
||||
|
||||
void TF_DeleteRemoteRendezvousBuilder(
|
||||
TF_RemoteRendezvousBuilder* rendezvous_builder) {
|
||||
DCHECK_NE(rendezvous_builder, nullptr);
|
||||
delete rendezvous_builder;
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern void TF_RendezvousDone(
|
||||
TF_RendezvousDoneCallback* callback) {
|
||||
DCHECK_NE(callback, nullptr);
|
||||
::tensorflow::Tensor tensor;
|
||||
TF_CHECK_OK(TF_TensorToTensor(callback->tensor, &tensor));
|
||||
::tensorflow::Rendezvous::Args recv_args;
|
||||
recv_args.alloc_attrs.value = callback->recv_args->alloc_attrs->value;
|
||||
recv_args.alloc_attrs.scope_id = callback->recv_args->alloc_attrs->scope_id;
|
||||
recv_args.device_context = callback->recv_args->device_context->context;
|
||||
::tensorflow::Rendezvous::Args sent_args;
|
||||
|
||||
callback->done_callback(callback->status->status, sent_args, recv_args,
|
||||
tensor, callback->dead);
|
||||
|
||||
if (callback->recv_args) {
|
||||
DCHECK_NE(callback->recv_args, nullptr);
|
||||
DCHECK_NE(callback->recv_args->alloc_attrs, nullptr);
|
||||
DCHECK_NE(callback->recv_args->device_context, nullptr);
|
||||
delete callback->recv_args->alloc_attrs;
|
||||
delete callback->recv_args->device_context;
|
||||
delete callback->recv_args;
|
||||
}
|
||||
delete callback;
|
||||
callback = nullptr;
|
||||
}
|
67
tensorflow/c/experimental/rendezvous.h
Normal file
67
tensorflow/c/experimental/rendezvous.h
Normal file
@ -0,0 +1,67 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// C API for Rendezvous.
|
||||
// NOTE: This API is unstable and almost certainly will change in the near
|
||||
// future.
|
||||
//
|
||||
// Custom rendezvous allows for custom implementations of Recv call.
|
||||
//
|
||||
// Users wishing to create custom rendezvous objects should call
|
||||
// TF_NewRemoteRendezvousBuilder and pass returned TF_RemoteRendezvousBuilder
|
||||
// to to TF_NewServerFactory.
|
||||
|
||||
typedef struct TF_RemoteRendezvousBuilder TF_RemoteRendezvousBuilder;
|
||||
typedef struct TF_ParsedKey TF_ParsedKey;
|
||||
typedef struct TF_RendezvousArgs TF_RendezvousArgs;
|
||||
typedef struct TF_RendezvousDoneCallback TF_RendezvousDoneCallback;
|
||||
|
||||
// Creates a new TF_RemoteRendezvousBuilder instance.
|
||||
// Rendezvous instances will forward calls to init_function,
|
||||
// receive_from_remote_async_function and delete_function passed here.
|
||||
//
|
||||
// Note that receive_from_remote_async_function implementation must call
|
||||
// TF_Done with the TF_DoneCallback passed as an argument.
|
||||
TF_CAPI_EXPORT extern TF_RemoteRendezvousBuilder* TF_NewRemoteRendezvousBuilder(
|
||||
void* (*init_function)(void* server_context),
|
||||
void (*receive_from_remote_async_function)(TF_ParsedKey*,
|
||||
TF_RendezvousArgs*,
|
||||
TF_RendezvousDoneCallback*,
|
||||
void* context),
|
||||
void (*delete_function)(void* context));
|
||||
|
||||
// Deletes TF_RemoteRendezvousBuilder instances.
|
||||
TF_CAPI_EXPORT extern void TF_DeleteRemoteRendezvousBuilder(
|
||||
TF_RemoteRendezvousBuilder* rendezvous_builder);
|
||||
|
||||
// Calls TF_DoneCallback and destroys callback instance and
|
||||
// TF_DoneCallback members except `tensor` and `status`. Caller is
|
||||
// responsible for deleting `tensor` and `status` after TF_Done returns.
|
||||
TF_CAPI_EXPORT extern void TF_RendezvousDone(
|
||||
TF_RendezvousDoneCallback* callback);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_
|
135
tensorflow/c/experimental/rendezvous_internal.h
Normal file
135
tensorflow/c/experimental/rendezvous_internal.h
Normal file
@ -0,0 +1,135 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/experimental/rendezvous.h"
|
||||
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
struct TF_ParsedKey {
|
||||
// char* members might not be null-terminated.
|
||||
const char* src_device;
|
||||
size_t src_device_len;
|
||||
const char* dst_device;
|
||||
size_t dst_device_len;
|
||||
const char* full_key;
|
||||
size_t full_key_len;
|
||||
};
|
||||
|
||||
struct TF_AllocatorAttributes {
|
||||
bool on_host;
|
||||
bool nic_compatible;
|
||||
// NOTE: The upper 8 bits of the value are reserved for
|
||||
// device-specific uses. Implementors of a device can interpret these
|
||||
// upper 8 bits in device-specific ways, and ops implemented for those
|
||||
// devices are responsible for setting those 8 bits appropriately.
|
||||
tensorflow::uint32 value = 0;
|
||||
// EXPERIMENTAL: If this is greater than zero, then allocation is delegated to
|
||||
// a named special-purpose allocator on the same device.
|
||||
tensorflow::int32 scope_id = 0;
|
||||
};
|
||||
|
||||
struct TF_DeviceContext {
|
||||
::tensorflow::DeviceContext* context;
|
||||
};
|
||||
|
||||
struct TF_RendezvousArgs {
|
||||
const TF_DeviceContext* device_context;
|
||||
const TF_AllocatorAttributes* alloc_attrs;
|
||||
};
|
||||
|
||||
struct TF_RendezvousDoneCallback {
|
||||
::tensorflow::Rendezvous::DoneCallback done_callback;
|
||||
|
||||
// TODO(annarev): figure out if we should also support sent_args.
|
||||
const TF_RendezvousArgs* recv_args;
|
||||
TF_Tensor* tensor = nullptr;
|
||||
TF_Status* status;
|
||||
bool dead;
|
||||
};
|
||||
|
||||
struct TF_RemoteRendezvousBuilder {
|
||||
void* (*init_function)(void* server_context);
|
||||
void (*receive_from_remote_async_function)(TF_ParsedKey*, TF_RendezvousArgs*,
|
||||
TF_RendezvousDoneCallback*,
|
||||
void* context);
|
||||
void (*delete_function)(void* context);
|
||||
void* server_context;
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class CRemoteRendezvous : public BaseRemoteRendezvous {
|
||||
public:
|
||||
CRemoteRendezvous(const WorkerEnv* env, int64 step_id,
|
||||
void (*receive_from_remote_async_function)(
|
||||
TF_ParsedKey*, TF_RendezvousArgs*,
|
||||
TF_RendezvousDoneCallback*, void* context),
|
||||
void (*delete_function)(void* context),
|
||||
void* server_context);
|
||||
|
||||
void SetContext(void* context) { context_ = context; }
|
||||
|
||||
protected:
|
||||
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& args,
|
||||
DoneCallback done) override;
|
||||
|
||||
private:
|
||||
~CRemoteRendezvous() override;
|
||||
|
||||
void (*receive_from_remote_async_function_)(TF_ParsedKey*, TF_RendezvousArgs*,
|
||||
TF_RendezvousDoneCallback*,
|
||||
void* context);
|
||||
void (*delete_function_)(void* context);
|
||||
void* context_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CRemoteRendezvous);
|
||||
};
|
||||
|
||||
class CRendezvousMgr : public BaseRendezvousMgr {
|
||||
public:
|
||||
CRendezvousMgr(const WorkerEnv* env,
|
||||
const TF_RemoteRendezvousBuilder* rendezvous_builder)
|
||||
: BaseRendezvousMgr(env), rendezvous_builder_(rendezvous_builder) {}
|
||||
|
||||
protected:
|
||||
BaseRemoteRendezvous* Create(int64 step_id,
|
||||
const WorkerEnv* worker_env) override {
|
||||
auto* rendezvous = new CRemoteRendezvous(
|
||||
worker_env, step_id,
|
||||
rendezvous_builder_->receive_from_remote_async_function,
|
||||
rendezvous_builder_->delete_function,
|
||||
rendezvous_builder_->server_context);
|
||||
|
||||
rendezvous->SetContext(rendezvous_builder_->init_function(
|
||||
rendezvous_builder_->server_context));
|
||||
return rendezvous;
|
||||
}
|
||||
|
||||
private:
|
||||
const TF_RemoteRendezvousBuilder* rendezvous_builder_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CRendezvousMgr);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_
|
326
tensorflow/c/ops.cc
Normal file
326
tensorflow/c/ops.cc
Normal file
@ -0,0 +1,326 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/ops.h"
|
||||
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def_builder.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
using ::tensorflow::DataType;
|
||||
using ::tensorflow::OpDef;
|
||||
using ::tensorflow::OpDeprecation;
|
||||
using ::tensorflow::OpShapeInferenceFn;
|
||||
using ::tensorflow::Set_TF_Status_from_Status;
|
||||
using ::tensorflow::Status;
|
||||
using ::tensorflow::shape_inference::DimensionHandle;
|
||||
using ::tensorflow::shape_inference::InferenceContext;
|
||||
using ::tensorflow::shape_inference::ShapeHandle;
|
||||
|
||||
typedef struct TF_OpDefinitionBuilder {
|
||||
// The op definition proto representing the op.
|
||||
tensorflow::OpDef op_def;
|
||||
|
||||
// The shape inference function, or nullptr if none is provided for this op.
|
||||
OpShapeInferenceFn shape_inference_func;
|
||||
} TF_OpDefinitionBuilder;
|
||||
|
||||
TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder(const char* op_name) {
|
||||
auto* result = new TF_OpDefinitionBuilder;
|
||||
result->op_def.set_name(op_name);
|
||||
return result;
|
||||
}
|
||||
|
||||
void TF_DeleteOpDefinitionBuilder(TF_OpDefinitionBuilder* builder) {
|
||||
delete builder;
|
||||
}
|
||||
|
||||
static void PopulateArg(OpDef::ArgDef* arg, const char* name,
|
||||
TF_DataType type) {
|
||||
arg->set_name(name);
|
||||
arg->set_type(static_cast<DataType>(type));
|
||||
}
|
||||
|
||||
void TF_OpDefinitionBuilderAddInput(TF_OpDefinitionBuilder* builder,
|
||||
const char* name, TF_DataType type) {
|
||||
PopulateArg(builder->op_def.add_input_arg(), name, type);
|
||||
}
|
||||
|
||||
void TF_OpDefinitionBuilderAddOutput(TF_OpDefinitionBuilder* builder,
|
||||
const char* name, TF_DataType type) {
|
||||
PopulateArg(builder->op_def.add_output_arg(), name, type);
|
||||
}
|
||||
|
||||
#define DEFINE_BUILDER_BOOL_SETTER(func_name, builder_setter_name, arg_name) \
|
||||
void TF_OpDefinitionBuilder##func_name(TF_OpDefinitionBuilder* builder, \
|
||||
bool arg_name) { \
|
||||
builder->op_def.builder_setter_name(arg_name); \
|
||||
}
|
||||
|
||||
DEFINE_BUILDER_BOOL_SETTER(SetIsCommutative, set_is_commutative, is_commutative)
|
||||
DEFINE_BUILDER_BOOL_SETTER(SetIsAggregate, set_is_aggregate, is_aggregate)
|
||||
DEFINE_BUILDER_BOOL_SETTER(SetIsStateful, set_is_stateful, is_stateful)
|
||||
DEFINE_BUILDER_BOOL_SETTER(SetAllowsUninitializedInput,
|
||||
set_allows_uninitialized_input,
|
||||
allows_unintialized_input)
|
||||
|
||||
static OpDef::AttrDef* AddAttribute(TF_OpDefinitionBuilder* builder,
|
||||
const char* name, const char* type_name) {
|
||||
OpDef::AttrDef* attr = builder->op_def.add_attr();
|
||||
attr->set_name(name);
|
||||
attr->set_type(type_name);
|
||||
return attr;
|
||||
}
|
||||
|
||||
#define DEFINE_ATTR_SETTER(attr_type, type_name, field_c_type, field_name) \
|
||||
void TF_OpDefinitionBuilderAdd##attr_type##Attr( \
|
||||
TF_OpDefinitionBuilder* builder, const char* name) { \
|
||||
AddAttribute(builder, name, type_name); \
|
||||
} \
|
||||
\
|
||||
void TF_OpDefinitionBuilderAdd##attr_type##AttrWithDefaultValue( \
|
||||
TF_OpDefinitionBuilder* builder, const char* name, \
|
||||
field_c_type field_name) { \
|
||||
OpDef::AttrDef* attr = AddAttribute(builder, name, type_name); \
|
||||
attr->mutable_default_value()->set_##field_name(field_name); \
|
||||
} \
|
||||
\
|
||||
void TF_OpDefinitionBuilderAdd##attr_type##ListAttrWithDefaultValues( \
|
||||
TF_OpDefinitionBuilder* builder, const char* name, \
|
||||
field_c_type field_name[], size_t n) { \
|
||||
OpDef::AttrDef* attr = AddAttribute(builder, name, "list(" type_name ")"); \
|
||||
for (int _i = 0; _i < n; ++_i) { \
|
||||
attr->mutable_default_value()->mutable_list()->add_##field_name( \
|
||||
field_name[_i]); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
void TF_OpDefinitionBuilderAdd##attr_type##ListAttr( \
|
||||
TF_OpDefinitionBuilder* builder, const char* name) { \
|
||||
TF_OpDefinitionBuilderAdd##attr_type##ListAttrWithDefaultValues( \
|
||||
builder, name, NULL, 0); \
|
||||
}
|
||||
|
||||
DEFINE_ATTR_SETTER(String, "string", const char*, s)
|
||||
DEFINE_ATTR_SETTER(Int, "int", int64_t, i)
|
||||
DEFINE_ATTR_SETTER(Float, "float", float, f)
|
||||
DEFINE_ATTR_SETTER(Bool, "bool", bool, b)
|
||||
|
||||
void TF_OpDefinitionBuilderDeprecated(TF_OpDefinitionBuilder* builder,
|
||||
int version, const char* explanation) {
|
||||
OpDeprecation* dep = builder->op_def.mutable_deprecation();
|
||||
dep->set_version(version);
|
||||
dep->set_explanation(explanation);
|
||||
}
|
||||
|
||||
void TF_RegisterOpDefinition(TF_OpDefinitionBuilder* builder,
|
||||
TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
::tensorflow::OpRegistry::Global()->Register(
|
||||
[builder](::tensorflow::OpRegistrationData* op_reg_data) -> Status {
|
||||
op_reg_data->op_def.Clear();
|
||||
op_reg_data->op_def.MergeFrom(builder->op_def);
|
||||
op_reg_data->shape_inference_fn = builder->shape_inference_func;
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
// Calling ProcessRegistrations ensures that the cc_builder's finalize method
|
||||
// is called and that the builder can be deleted.
|
||||
Set_TF_Status_from_Status(
|
||||
status, ::tensorflow::OpRegistry::Global()->ProcessRegistrations());
|
||||
|
||||
delete builder;
|
||||
}
|
||||
|
||||
void TF_OpDefinitionBuilderSetShapeInferenceFunction(
|
||||
TF_OpDefinitionBuilder* builder,
|
||||
void (*shape_inference_func)(TF_ShapeInferenceContext* ctx,
|
||||
TF_Status* status)) {
|
||||
builder->shape_inference_func =
|
||||
[shape_inference_func](InferenceContext* ctx) -> tensorflow::Status {
|
||||
TF_Status* c_status = TF_NewStatus();
|
||||
auto c_ctx = reinterpret_cast<TF_ShapeInferenceContext*>(ctx);
|
||||
shape_inference_func(c_ctx, c_status);
|
||||
tensorflow::Status result = ::tensorflow::StatusFromTF_Status(c_status);
|
||||
TF_DeleteStatus(c_status);
|
||||
return result;
|
||||
};
|
||||
}
|
||||
|
||||
TF_ShapeHandle* TF_NewShapeHandle() {
|
||||
return reinterpret_cast<TF_ShapeHandle*>(new ShapeHandle);
|
||||
}
|
||||
|
||||
TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
|
||||
TF_ShapeInferenceContext* ctx, size_t size) {
|
||||
auto* handle = new ShapeHandle;
|
||||
*handle = reinterpret_cast<InferenceContext*>(ctx)->Vector(size);
|
||||
return reinterpret_cast<TF_ShapeHandle*>(handle);
|
||||
}
|
||||
|
||||
void TF_ShapeInferenceContextConcatenateShapes(TF_ShapeInferenceContext* ctx,
|
||||
TF_ShapeHandle* first,
|
||||
TF_ShapeHandle* second,
|
||||
TF_ShapeHandle* result,
|
||||
TF_Status* status) {
|
||||
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
|
||||
Status s = cc_ctx->Concatenate(*reinterpret_cast<ShapeHandle*>(first),
|
||||
*reinterpret_cast<ShapeHandle*>(second),
|
||||
reinterpret_cast<ShapeHandle*>(result));
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
|
||||
TF_DimensionHandle* TF_NewDimensionHandle() {
|
||||
return reinterpret_cast<TF_DimensionHandle*>(new DimensionHandle);
|
||||
}
|
||||
|
||||
int64_t TF_ShapeInferenceContextNumInputs(TF_ShapeInferenceContext* ctx) {
|
||||
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
|
||||
return cc_ctx->num_inputs();
|
||||
}
|
||||
|
||||
void TF_ShapeInferenceContextGetInput(TF_ShapeInferenceContext* ctx, int i,
|
||||
TF_ShapeHandle* handle,
|
||||
TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
|
||||
if (0 < i || i >= cc_ctx->num_inputs()) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "input index out of range");
|
||||
}
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
auto* cc_result = reinterpret_cast<ShapeHandle*>(handle);
|
||||
*cc_result = cc_ctx->input(i);
|
||||
}
|
||||
}
|
||||
|
||||
int TF_ShapeInferenceContextRankKnown(TF_ShapeInferenceContext* ctx,
|
||||
TF_ShapeHandle* handle) {
|
||||
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
|
||||
return cc_ctx->RankKnown(*reinterpret_cast<ShapeHandle*>(handle));
|
||||
}
|
||||
|
||||
void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx, int i,
|
||||
TF_ShapeHandle* handle,
|
||||
TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
|
||||
if (0 < i || i >= cc_ctx->num_outputs()) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "output index out of range");
|
||||
}
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
cc_ctx->set_output(i, *(reinterpret_cast<ShapeHandle*>(handle)));
|
||||
}
|
||||
}
|
||||
|
||||
void TF_DeleteShapeHandle(TF_ShapeHandle* handle) {
|
||||
if (handle == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
delete reinterpret_cast<ShapeHandle*>(handle);
|
||||
}
|
||||
|
||||
void TF_DeleteDimensionHandle(TF_DimensionHandle* handle) {
|
||||
if (handle == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
delete reinterpret_cast<DimensionHandle*>(handle);
|
||||
}
|
||||
|
||||
#define DEFINE_TF_GETATTR(func, c_type, cc_type) \
|
||||
void TF_ShapeInferenceContext_GetAttr##func( \
|
||||
TF_ShapeInferenceContext* ctx, const char* attr_name, c_type* val, \
|
||||
TF_Status* status) { \
|
||||
TF_SetStatus(status, TF_OK, ""); \
|
||||
cc_type v; \
|
||||
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); \
|
||||
Status s = cc_ctx->GetAttr(attr_name, &v); \
|
||||
Set_TF_Status_from_Status(status, s); \
|
||||
if (s.ok()) { \
|
||||
*val = static_cast<c_type>(v); \
|
||||
} \
|
||||
}
|
||||
|
||||
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
|
||||
|
||||
#define DEFINE_RANK_FUNC(func_name) \
|
||||
void TF_ShapeInferenceContext##func_name( \
|
||||
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank, \
|
||||
TF_ShapeHandle* result, TF_Status* status) { \
|
||||
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); \
|
||||
auto* cc_handle = reinterpret_cast<ShapeHandle*>(handle); \
|
||||
auto* cc_result = reinterpret_cast<ShapeHandle*>(result); \
|
||||
Status s = cc_ctx->func_name(*cc_handle, rank, cc_result); \
|
||||
Set_TF_Status_from_Status(status, s); \
|
||||
}
|
||||
|
||||
DEFINE_RANK_FUNC(WithRank)
|
||||
DEFINE_RANK_FUNC(WithRankAtLeast)
|
||||
DEFINE_RANK_FUNC(WithRankAtMost)
|
||||
|
||||
int64_t TF_ShapeInferenceContextRank(TF_ShapeInferenceContext* ctx,
|
||||
TF_ShapeHandle* handle) {
|
||||
return reinterpret_cast<InferenceContext*>(ctx)->Rank(
|
||||
*reinterpret_cast<ShapeHandle*>(handle));
|
||||
}
|
||||
|
||||
void TF_ShapeInferenceContextDim(TF_ShapeInferenceContext* ctx,
|
||||
TF_ShapeHandle* shape_handle, int64_t i,
|
||||
TF_DimensionHandle* result) {
|
||||
int64_t rank = TF_ShapeInferenceContextRank(ctx, shape_handle);
|
||||
auto* cc_result = reinterpret_cast<DimensionHandle*>(result);
|
||||
|
||||
if (i < -rank || i >= rank) {
|
||||
*cc_result = DimensionHandle();
|
||||
return;
|
||||
}
|
||||
|
||||
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
|
||||
auto* cc_shape_handle = reinterpret_cast<ShapeHandle*>(shape_handle);
|
||||
*cc_result = cc_ctx->Dim(*cc_shape_handle, i);
|
||||
}
|
||||
|
||||
int TF_DimensionHandleValueKnown(TF_DimensionHandle* dim_handle) {
|
||||
return InferenceContext::ValueKnown(
|
||||
*reinterpret_cast<DimensionHandle*>(dim_handle));
|
||||
}
|
||||
|
||||
void TF_ShapeInferenceContextSetUnknownShape(TF_ShapeInferenceContext* ctx,
|
||||
TF_Status* status) {
|
||||
Status s = ::tensorflow::shape_inference::UnknownShape(
|
||||
reinterpret_cast<InferenceContext*>(ctx));
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
|
||||
void TF_ShapeInferenceContextSubshape(TF_ShapeInferenceContext* ctx,
|
||||
TF_ShapeHandle* shape_handle,
|
||||
int64_t start, int64_t end,
|
||||
TF_ShapeHandle* result,
|
||||
TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
|
||||
auto* cc_result = reinterpret_cast<ShapeHandle*>(result);
|
||||
Status s = cc_ctx->Subshape(*reinterpret_cast<ShapeHandle*>(shape_handle),
|
||||
start, end, cc_result);
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
|
||||
int64_t TF_DimensionHandleValue(TF_DimensionHandle* dim_handle) {
|
||||
return InferenceContext::Value(
|
||||
*reinterpret_cast<DimensionHandle*>(dim_handle));
|
||||
}
|
407
tensorflow/c/ops.h
Normal file
407
tensorflow/c/ops.h
Normal file
@ -0,0 +1,407 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Routines for registering new ops and for implementing op shape inference
|
||||
// functions.
|
||||
//
|
||||
// This API is alpha software and is subject to change.
|
||||
//
|
||||
// REGISTRATION
|
||||
// ------------
|
||||
//
|
||||
// In order to register a new op, create a new TF_OpDefinitionBuilder:
|
||||
//
|
||||
// TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("OpName");
|
||||
//
|
||||
// Inputs, outputs and attributes can be added to the builder with the
|
||||
// corresponding functions, e.g.
|
||||
//
|
||||
// TF_OpDefinitionBuilderAddInput(builder, "input1: int32");
|
||||
// TF_OpDefinitionBuilderAddOutput(builder, "output1: int64");
|
||||
// TF_OpDefinitionBuilderAddAttr(builder, "attr: int32");
|
||||
//
|
||||
// The builder may then be registered with TensorFlow using the
|
||||
// TF_RegisterOpDefinition function. E.g.
|
||||
//
|
||||
// TF_Status* status = TF_NewStatus();
|
||||
// TF_RegisterOpDefinition(builder, &status);
|
||||
// if (TF_GetCode(status) != TF_OK) {
|
||||
// // handle error
|
||||
// }
|
||||
//
|
||||
// SHAPE INFERENCE
|
||||
// ---------------
|
||||
//
|
||||
// You can provide a shape inference function that TensorFlow will call when it
|
||||
// wants to understand the shape of outputs that the op will produce. Use the
|
||||
// TF_OpDefinitionBuilderSetShapeInferenceFunction function to register a shape
|
||||
// inference function pointer with TensorFlow. The following is an example of a
|
||||
// very simple shape inference function:
|
||||
//
|
||||
// void identity_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
|
||||
// TF_ShapeHandle* input = TF_NewShapeHandle();
|
||||
// TF_ShapeInferenceContextGetInput(ctx, 0, input, status);
|
||||
// if (TF_GetCode(status) == TF_OK) {
|
||||
// TF_ShapeInferenceContextSetOutput(ctx, 0, input, status);
|
||||
// }
|
||||
// TF_DeleteShapeHandle(input);
|
||||
// }
|
||||
//
|
||||
// The following code registers the inference function with TensorFlow:
|
||||
//
|
||||
// TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &identity_shape_fn);
|
||||
//
|
||||
// For more details about shape inference, see the documentation for
|
||||
// TF_OpDefinitionBuilderSetShapeInferenceFunction.
|
||||
|
||||
#ifndef TENSORFLOW_C_OPS_H_
|
||||
#define TENSORFLOW_C_OPS_H_
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
|
||||
#ifdef SWIG
|
||||
#define TF_CAPI_EXPORT
|
||||
#else
|
||||
#if defined(_WIN32)
|
||||
#ifdef TF_COMPILE_LIBRARY
|
||||
#define TF_CAPI_EXPORT __declspec(dllexport)
|
||||
#else
|
||||
#define TF_CAPI_EXPORT __declspec(dllimport)
|
||||
#endif // TF_COMPILE_LIBRARY
|
||||
#else
|
||||
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
|
||||
#endif // _WIN32
|
||||
#endif // SWIG
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct TF_DimensionHandle;
|
||||
struct TF_OpDefinitionBuilder;
|
||||
struct TF_ShapeHandle;
|
||||
struct TF_ShapeInferenceContext;
|
||||
|
||||
// Returns a newly allocated op definition builder for the given op name. The
|
||||
// returned builder may be customized with the `TF_OpDefinitionBuilder...`
|
||||
// functions and then registered with TensorFlow with TF_RegisterOpDefinition.
|
||||
//
|
||||
// The returned pointer is either freed by a call to TF_RegisterOpDefinition, or
|
||||
// can be manually deleted by TF_DeleteOpDefinitionBuilder if it is never
|
||||
// registered.
|
||||
TF_CAPI_EXPORT extern TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder(
|
||||
const char* op_name);
|
||||
|
||||
// Registers the given op builder with TensorFlow. Indicates success or
|
||||
// otherwise in the given status.
|
||||
//
|
||||
// `builder` is freed whether the op was successfully registered or not. You
|
||||
// must call either this function or TF_DeleteOpDefinitionBuilder to free the
|
||||
// builder, but never both.
|
||||
TF_CAPI_EXPORT extern void TF_RegisterOpDefinition(
|
||||
TF_OpDefinitionBuilder* builder, TF_Status* status);
|
||||
|
||||
// Frees the given op definition builder. You must call either this function or
|
||||
// TF_RegisterOpDefinition to free the builder, but never both.
|
||||
TF_CAPI_EXPORT extern void TF_DeleteOpDefinitionBuilder(
|
||||
TF_OpDefinitionBuilder* builder);
|
||||
|
||||
//----------------------------------------------------
|
||||
// Attribute functions.
|
||||
|
||||
// Adds a string attribute with the given name to the builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringAttr(
|
||||
TF_OpDefinitionBuilder* builder, const char* name);
|
||||
|
||||
// Adds a string attribute with the given name and default value to the builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringAttrWithDefaultValue(
|
||||
TF_OpDefinitionBuilder* builder, const char* name, const char* value);
|
||||
|
||||
// Adds a string list attribute with the given name and no default value to the
|
||||
// builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringListAttr(
|
||||
TF_OpDefinitionBuilder* builder, const char* name);
|
||||
|
||||
// Adds a string list attribute with the given default values to the builder.
|
||||
// `values` must contain at least `n` elements.
|
||||
TF_CAPI_EXPORT extern void
|
||||
TF_OpDefinitionBuilderAddStringListAttrWithDefaultValues(
|
||||
TF_OpDefinitionBuilder* builder, const char* name, const char* values[],
|
||||
size_t n);
|
||||
|
||||
// Adds an integer attribute with the given name and no default value to the
|
||||
// builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntAttr(
|
||||
TF_OpDefinitionBuilder* builder, const char* name);
|
||||
|
||||
// Adds an integer attribute with the given name and default value to the
|
||||
// builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntAttrWithDefaultValue(
|
||||
TF_OpDefinitionBuilder* builder, const char* name, int64_t value);
|
||||
|
||||
// Adds an integer list attribute with the given name and no default value to
|
||||
// the builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntListAttr(
|
||||
TF_OpDefinitionBuilder* builder, const char* name);
|
||||
|
||||
// Adds an integer list attribute with the given name and default values to the
|
||||
// builder. `values` must contain at least `n` elements.
|
||||
TF_CAPI_EXPORT extern void
|
||||
TF_OpDefinitionBuilderAddIntListAttrWithDefaultValues(
|
||||
TF_OpDefinitionBuilder* builder, const char* name, int64_t values[],
|
||||
size_t n);
|
||||
|
||||
// Adds a float attribute with the given name and no default value to the
|
||||
// builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatAttr(
|
||||
TF_OpDefinitionBuilder* builder, const char* name);
|
||||
|
||||
// Adds a float attribute with the given name and default value to the builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatAttrWithDefaultValue(
|
||||
TF_OpDefinitionBuilder* builder, const char* name, float value);
|
||||
|
||||
// Adds a float list attribute with the given name and no default value to the
|
||||
// builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatListAttr(
|
||||
TF_OpDefinitionBuilder* builder, const char* name);
|
||||
|
||||
// Adds a float list attribute with the given name and default values to the
|
||||
// builder. `values` must contain at least `n` elements.
|
||||
TF_CAPI_EXPORT extern void
|
||||
TF_OpDefinitionBuilderAddFloatListAttrWithDefaultValues(
|
||||
TF_OpDefinitionBuilder* builder, const char* name, float values[],
|
||||
size_t n);
|
||||
|
||||
// Adds a boolean attribute with the given name and no default value to the
|
||||
// builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolAttr(
|
||||
TF_OpDefinitionBuilder* builder, const char* name);
|
||||
|
||||
// Adds a boolean attribute with the given name and default value to the
|
||||
// builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolAttrWithDefaultValue(
|
||||
TF_OpDefinitionBuilder* builder, const char* name, bool value);
|
||||
|
||||
// Adds a boolean list attribute with the given name and no default value to the
|
||||
// builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolListAttr(
|
||||
TF_OpDefinitionBuilder* builder, const char* name);
|
||||
|
||||
// Adds a boolean list attribute with the given name and default values to the
|
||||
// builder. `values` must contain at least `n` elements.
|
||||
TF_CAPI_EXPORT extern void
|
||||
TF_OpDefinitionBuilderAddBoolListAttrWithDefaultValues(
|
||||
TF_OpDefinitionBuilder* builder, const char* name, bool values[], size_t n);
|
||||
|
||||
// Adds the input with the given name and type to the op.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddInput(
|
||||
TF_OpDefinitionBuilder* builder, const char* name, TF_DataType type);
|
||||
|
||||
// Adds the output with the given name and type to the op.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddOutput(
|
||||
TF_OpDefinitionBuilder* builder, const char* output, TF_DataType type);
|
||||
|
||||
// Sets the commutative property for the op built by the given builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsCommutative(
|
||||
TF_OpDefinitionBuilder* builder, bool is_commutative);
|
||||
|
||||
// Sets the is_aggregate property of the builder to the given value.
|
||||
//
|
||||
// If is_aggregate is true, then the operation produced by this builder accepts
|
||||
// N >= 2 inputs and produces 1 output all of the same type. Should be
|
||||
// associative and commutative, and produce output with the same shape as the
|
||||
// input. The optimizer may replace an aggregate op taking input from multiple
|
||||
// devices with a tree of aggregate ops that aggregate locally within each
|
||||
// device (and possibly within groups of nearby devices) before communicating.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsAggregate(
|
||||
TF_OpDefinitionBuilder* builder, bool is_aggregate);
|
||||
|
||||
// Sets the is_stateful property of the builder to the given value.
|
||||
//
|
||||
// The op built by this builder is stateful if its behavior depends on some
|
||||
// state beyond its input tensors (e.g. variable reading op) or if it has a
|
||||
// side-effect (e.g. printing or asserting ops). Equivalently, stateless ops
|
||||
// must always produce the same output for the same input and have no
|
||||
// side-effects.
|
||||
//
|
||||
// By default Ops may be moved between devices. Stateful ops should either not
|
||||
// be moved, or should only be moved if that state can also be moved (e.g. via
|
||||
// some sort of save / restore). Stateful ops are guaranteed to never be
|
||||
// optimized away by Common Subexpression Elimination (CSE).
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsStateful(
|
||||
TF_OpDefinitionBuilder* builder, bool is_stateful);
|
||||
|
||||
// Sets the allows_uninitialized_input property of the operation built by this
|
||||
// builder.
|
||||
//
|
||||
// By default, all inputs to an Op must be initialized Tensors. Ops that may
|
||||
// initialize tensors for the first time should set this field to true, to allow
|
||||
// the Op to take an uninitialized Tensor as input.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetAllowsUninitializedInput(
|
||||
TF_OpDefinitionBuilder* builder, bool allows_uninitialized_input);
|
||||
|
||||
// Adds a deprecation warning for the given op. This indicates to the user that
|
||||
// `version` is the first TensorFlow GraphDef version for which the operation is
|
||||
// deprecated. `explanation` should contain the reason for the deprecation and
|
||||
// what to use instead.
|
||||
//
|
||||
// This function is only an indicator that the operation may disappear in a
|
||||
// version of TensorFlow after `version`. It does not affect op registration.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderDeprecated(
|
||||
TF_OpDefinitionBuilder* builder, int version, const char* explanation);
|
||||
|
||||
// Sets the shape inference function for the op.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetShapeInferenceFunction(
|
||||
TF_OpDefinitionBuilder* builder,
|
||||
void (*shape_inference_func)(TF_ShapeInferenceContext* ctx,
|
||||
TF_Status* status));
|
||||
|
||||
//----------------------------------------------------
|
||||
// Functions for TF_ShapeInferenceContext.
|
||||
//
|
||||
// Functions for implementing shape inference functions. TensorFlow uses these
|
||||
// functions to determine the shape of tensors produced by an operation without
|
||||
// having to actually run the operation. If an operation chooses to provide a
|
||||
// shape inference function, it will be invoked by TensorFlow as needed.
|
||||
//
|
||||
// When invoked by TensorFlow, the shape inference function is provided with a
|
||||
// TF_ShapeInferenceContext pointer. The function's implementation will use the
|
||||
// accessor and mutator functions with names beginning with
|
||||
// TF_ShapeInferenceContext to examine the input state and determine the output
|
||||
// shape.
|
||||
|
||||
// Returns the number of inputs in the given shape inference context.
|
||||
TF_CAPI_EXPORT extern int64_t TF_ShapeInferenceContextNumInputs(
|
||||
TF_ShapeInferenceContext* ctx);
|
||||
|
||||
// Returns a newly allocated shape handle. The shapes represented by these
|
||||
// handles may be queried or mutated with the corresponding
|
||||
// TF_ShapeInferenceContext... functions.
|
||||
TF_CAPI_EXPORT extern TF_ShapeHandle* TF_NewShapeHandle();
|
||||
|
||||
// Places the ith input of the given shape inference context into the given
|
||||
// shape handle, or returns a status other than TF_OK indicating why the input
|
||||
// could not be retrieved
|
||||
// (for example, if i < 0 || i >= TF_ShapeInferenceContextNumInputs(ctx)).
|
||||
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextGetInput(
|
||||
TF_ShapeInferenceContext* ctx, int i, TF_ShapeHandle* handle,
|
||||
TF_Status* status);
|
||||
|
||||
// Places the given shape handle into the `i`th output position of the given
|
||||
// context. Internally, the shape handle is copied; the caller may subsequently
|
||||
// delete `handle`.
|
||||
TF_CAPI_EXPORT
|
||||
extern void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx,
|
||||
int i, TF_ShapeHandle* handle,
|
||||
TF_Status* status);
|
||||
|
||||
// Returns a newly-allocate shape handle representing a vector of the given
|
||||
// size. The returned handle should be freed with TF_DeleteShapeHandle.
|
||||
TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
|
||||
TF_ShapeInferenceContext* ctx, size_t size);
|
||||
|
||||
// Returns a newly allocated dimension handle. It must be freed with
|
||||
// TF_DeleteDimensionHandle.
|
||||
TF_CAPI_EXPORT extern TF_DimensionHandle* TF_NewDimensionHandle();
|
||||
|
||||
// Interprets the named shape inference context attribute as a TF_DataType and
|
||||
// places it into *val. *status is set to TF_OK.
|
||||
//
|
||||
// If the attribute could not be found or could not be interpreted as
|
||||
// TF_DataType, *status is populated with an error.
|
||||
TF_CAPI_EXPORT extern void TF_ShapeInferenceContext_GetAttrType(
|
||||
TF_ShapeInferenceContext* ctx, const char* attr_name, TF_DataType* val,
|
||||
TF_Status* status);
|
||||
|
||||
// Returns the rank of the shape represented by the given handle.
|
||||
TF_CAPI_EXPORT extern int64_t TF_ShapeInferenceContextRank(
|
||||
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle);
|
||||
|
||||
// Returns 1 if `handle` has a known rank, 0 otherwise.
|
||||
TF_CAPI_EXPORT extern int TF_ShapeInferenceContextRankKnown(
|
||||
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle);
|
||||
|
||||
// If <handle> has rank <rank>, or its rank is unknown, return OK and return the
|
||||
// shape with asserted rank in <*result>. Otherwise an error is placed into
|
||||
// `status`.
|
||||
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextWithRank(
|
||||
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank,
|
||||
TF_ShapeHandle* result, TF_Status* status);
|
||||
|
||||
// If <handle> has rank at least <rank>, or its rank is unknown, return OK and
|
||||
// return the shape with asserted rank in <*result>. Otherwise an error is
|
||||
// placed into `status`.
|
||||
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextWithRankAtLeast(
|
||||
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank,
|
||||
TF_ShapeHandle* result, TF_Status* status);
|
||||
|
||||
// If <handle> has rank at most <rank>, or its rank is unknown, return OK and
|
||||
// return the shape with asserted rank in <*result>. Otherwise an error is
|
||||
// placed into `status`.
|
||||
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextWithRankAtMost(
|
||||
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank,
|
||||
TF_ShapeHandle* result, TF_Status* status);
|
||||
|
||||
// Places a handle to the ith dimension of the given shape into *result.
|
||||
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextDim(
|
||||
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* shape_handle, int64_t i,
|
||||
TF_DimensionHandle* result);
|
||||
|
||||
// Returns 1 if the given handle represents a known dimension.
|
||||
TF_CAPI_EXPORT extern int TF_ShapeInferenceContextDimValueKnown(
|
||||
TF_ShapeInferenceContext* ctx, TF_DimensionHandle* handle);
|
||||
|
||||
// Returns in <*result> a sub-shape of <shape_handle>, with dimensions
|
||||
// [start:end]. <start> and <end> can be negative, to index from the end of the
|
||||
// shape. <start> and <end> are set to the rank of <shape_handle> if > rank of
|
||||
// <shape_handle>.
|
||||
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextSubshape(
|
||||
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* shape_handle, int64_t start,
|
||||
int64_t end, TF_ShapeHandle* result, TF_Status* status);
|
||||
|
||||
// Places an unknown shape in all outputs for the given inference context. Used
|
||||
// for shape inference functions with ops whose output shapes are unknown.
|
||||
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextSetUnknownShape(
|
||||
TF_ShapeInferenceContext* ctx, TF_Status* status);
|
||||
|
||||
// Returns whether the given handle represents a known dimension.
|
||||
TF_CAPI_EXPORT extern int TF_DimensionHandleValueKnown(
|
||||
TF_DimensionHandle* dim_handle);
|
||||
|
||||
// Returns the value of the given dimension.
|
||||
TF_CAPI_EXPORT extern int64_t TF_DimensionHandleValue(
|
||||
TF_DimensionHandle* dim_handle);
|
||||
|
||||
// Returns in <*result> the result of appending the dimensions of <second> to
|
||||
// those of <first>.
|
||||
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextConcatenateShapes(
|
||||
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* first,
|
||||
TF_ShapeHandle* second, TF_ShapeHandle* result, TF_Status* status);
|
||||
|
||||
// Frees the given shape handle.
|
||||
TF_CAPI_EXPORT extern void TF_DeleteShapeHandle(TF_ShapeHandle* handle);
|
||||
|
||||
// Frees the given dimension handle.
|
||||
TF_CAPI_EXPORT extern void TF_DeleteDimensionHandle(TF_DimensionHandle* handle);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
||||
#endif // TENSORFLOW_C_OPS_H_
|
159
tensorflow/c/ops_test.cc
Normal file
159
tensorflow/c/ops_test.cc
Normal file
@ -0,0 +1,159 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/ops.h"
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/shape_inference_testutil.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(OpsTest, TestBasicOpRegistration) {
|
||||
TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("SomeOp");
|
||||
TF_OpDefinitionBuilderAddStringAttr(builder, "attr1");
|
||||
TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8);
|
||||
TF_OpDefinitionBuilderAddInput(builder, "input2", TF_UINT16);
|
||||
TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT32);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_RegisterOpDefinition(builder, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_Buffer* op_list_buffer = TF_GetAllOpList();
|
||||
::tensorflow::OpList op_list;
|
||||
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
|
||||
bool found = false;
|
||||
for (const auto& op : op_list.op()) {
|
||||
if (op.name() == "SomeOp") {
|
||||
ASSERT_EQ(2, op.input_arg_size());
|
||||
ASSERT_EQ("input1", op.input_arg(0).name());
|
||||
ASSERT_EQ(::tensorflow::DT_UINT8, op.input_arg(0).type());
|
||||
ASSERT_EQ(1, op.attr_size());
|
||||
ASSERT_EQ("string", op.attr(0).type());
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(found);
|
||||
TF_DeleteStatus(status);
|
||||
TF_DeleteBuffer(op_list_buffer);
|
||||
}
|
||||
|
||||
void identity_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
|
||||
TF_ShapeHandle* handle = TF_NewShapeHandle();
|
||||
TF_ShapeInferenceContextGetInput(ctx, 0, handle, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status));
|
||||
TF_ShapeInferenceContextSetOutput(ctx, 0, handle, status);
|
||||
TF_DeleteShapeHandle(handle);
|
||||
}
|
||||
|
||||
TEST(OpsTest, TestShapeInference_IdentityFunction) {
|
||||
ShapeInferenceTestOp op("SomeTestOp");
|
||||
|
||||
TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("SomeTestOp");
|
||||
TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8);
|
||||
TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT8);
|
||||
TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &identity_shape_fn);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_RegisterOpDefinition(builder, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TF_ASSERT_OK(
|
||||
shape_inference::ShapeInferenceTestutil::InferShapes(op, "[1,2]", "in0"));
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
// Creates an output whose shape is a vector of length
|
||||
// TF_ShapeInferenceContextRank.
|
||||
void vectorize_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
|
||||
TF_ShapeHandle* handle = TF_NewShapeHandle();
|
||||
TF_ShapeInferenceContextGetInput(ctx, 0, handle, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status));
|
||||
TF_ShapeHandle* new_shape = TF_ShapeInferenceContextVectorFromSize(
|
||||
ctx, TF_ShapeInferenceContextRank(ctx, handle));
|
||||
TF_ShapeInferenceContextSetOutput(ctx, 0, new_shape, status);
|
||||
TF_DeleteShapeHandle(handle);
|
||||
TF_DeleteShapeHandle(new_shape);
|
||||
}
|
||||
|
||||
TEST(OpsTest, TestShapeInference_VectorizeFunction) {
|
||||
ShapeInferenceTestOp op("VectorizeTestOp");
|
||||
|
||||
TF_OpDefinitionBuilder* builder =
|
||||
TF_NewOpDefinitionBuilder("VectorizeTestOp");
|
||||
TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8);
|
||||
TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT8);
|
||||
TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &vectorize_shape_fn);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_RegisterOpDefinition(builder, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TF_ASSERT_OK(shape_inference::ShapeInferenceTestutil::InferShapes(
|
||||
op, "[4,5,9]", "[3]"));
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
TEST(OpsTest, AttributeAccessors) {
|
||||
TF_OpDefinitionBuilder* builder =
|
||||
TF_NewOpDefinitionBuilder("AttributeAccesorsOp");
|
||||
float values[] = {1, 2, 3, 4};
|
||||
TF_OpDefinitionBuilderAddFloatListAttrWithDefaultValues(
|
||||
builder, "foo1", values, sizeof(values));
|
||||
TF_OpDefinitionBuilderAddStringAttrWithDefaultValue(builder, "foo2",
|
||||
"my string");
|
||||
TF_OpDefinitionBuilderSetIsCommutative(builder, true);
|
||||
TF_OpDefinitionBuilderSetIsAggregate(builder, true);
|
||||
TF_OpDefinitionBuilderSetAllowsUninitializedInput(builder, true);
|
||||
std::string deprecation_msg = "use something else instead";
|
||||
TF_OpDefinitionBuilderDeprecated(builder, 4, deprecation_msg.c_str());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_RegisterOpDefinition(builder, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status));
|
||||
|
||||
TF_Buffer* op_list_buffer = TF_GetAllOpList();
|
||||
::tensorflow::OpList op_list;
|
||||
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
|
||||
bool found = false;
|
||||
for (const auto& op : op_list.op()) {
|
||||
if (op.name() == "AttributeAccesorsOp") {
|
||||
ASSERT_TRUE(op.is_commutative());
|
||||
ASSERT_TRUE(op.is_aggregate());
|
||||
ASSERT_TRUE(op.allows_uninitialized_input());
|
||||
ASSERT_EQ(4, op.deprecation().version());
|
||||
ASSERT_EQ(deprecation_msg, op.deprecation().explanation());
|
||||
ASSERT_EQ(2, op.attr_size());
|
||||
ASSERT_EQ("list(float)", op.attr(0).type());
|
||||
AttrValue::ListValue l = op.attr(0).default_value().list();
|
||||
ASSERT_EQ(1, l.f(0));
|
||||
ASSERT_EQ(2, l.f(1));
|
||||
ASSERT_EQ(3, l.f(2));
|
||||
ASSERT_EQ(4, l.f(3));
|
||||
|
||||
ASSERT_EQ("string", op.attr(1).type());
|
||||
ASSERT_EQ("my string", op.attr(1).default_value().s());
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
ASSERT_TRUE(found);
|
||||
TF_DeleteStatus(status);
|
||||
TF_DeleteBuffer(op_list_buffer);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -203,6 +203,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":ops",
|
||||
":scope",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
|
@ -42,14 +42,19 @@ namespace {
|
||||
const int kRightMargin = 79;
|
||||
|
||||
// Converts:
|
||||
// bazel-out/.../genfiles/(external/YYY/)?XX
|
||||
// bazel-out/.../(bin|genfiles)/(external/YYY/)?XX
|
||||
// to: XX.
|
||||
string GetPath(const string& dot_h_fname) {
|
||||
auto pos = dot_h_fname.find("/genfiles/");
|
||||
auto pos = dot_h_fname.find("/bin/");
|
||||
string result = dot_h_fname;
|
||||
if (pos != string::npos) {
|
||||
// - 1 account for the terminating null character (\0) in "/genfiles/".
|
||||
result = dot_h_fname.substr(pos + sizeof("/genfiles/") - 1);
|
||||
result = dot_h_fname.substr(pos + sizeof("/bin/") - 1);
|
||||
} else {
|
||||
pos = dot_h_fname.find("/genfiles/");
|
||||
if (pos != string::npos) {
|
||||
result = dot_h_fname.substr(pos + sizeof("/genfiles/") - 1);
|
||||
}
|
||||
}
|
||||
if (result.size() > sizeof("external/") &&
|
||||
result.compare(0, sizeof("external/") - 1, "external/") == 0) {
|
||||
|
@ -531,4 +531,23 @@ Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner) {
|
||||
return InternalScope::NewScope(graph, status, refiner);
|
||||
}
|
||||
|
||||
Status CreateOutputWithScope(string op_name,
|
||||
absl::Span<const ::tensorflow::Input> inputs,
|
||||
const Scope& scope, Output* output) {
|
||||
TF_RETURN_IF_ERROR(scope.status());
|
||||
const auto unique_name = scope.GetUniqueNameForOp(op_name);
|
||||
auto builder = ::tensorflow::NodeBuilder(unique_name, op_name);
|
||||
for (auto input : inputs) {
|
||||
TF_RETURN_IF_ERROR(scope.status());
|
||||
builder = builder.Input(input.node());
|
||||
}
|
||||
::tensorflow::Node* ret;
|
||||
scope.UpdateBuilder(&builder);
|
||||
TF_RETURN_IF_ERROR(scope.status());
|
||||
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
|
||||
TF_RETURN_IF_ERROR(scope.status());
|
||||
*output = Output(ret, 0);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -255,6 +255,12 @@ struct CompositeOpScopes {
|
||||
Scope last;
|
||||
};
|
||||
|
||||
// Creates a node of the given operation, with the given inputs, and assigns the
|
||||
// result to output. This does not support the ability to add additional
|
||||
// attributes.
|
||||
Status CreateOutputWithScope(string op_name,
|
||||
absl::Span<const ::tensorflow::Input> inputs,
|
||||
const Scope& scope, Output* output);
|
||||
/// @}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -145,4 +147,14 @@ TEST(ScopeTest, ControlDeps) {
|
||||
EXPECT_EQ(c_c.control_deps().size(), 3);
|
||||
}
|
||||
|
||||
TEST(ScopeTest, CreateOutput) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
|
||||
Output add;
|
||||
ASSERT_TRUE(
|
||||
CreateOutputWithScope("Add", {a, a}, root.WithOpName("add"), &add).ok());
|
||||
EXPECT_EQ(add.node()->name(), "add");
|
||||
EXPECT_EQ(add.node()->type_string(), "Add");
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -18,27 +18,41 @@ from __future__ import absolute_import as _absolute_import
|
||||
from __future__ import division as _division
|
||||
from __future__ import print_function as _print_function
|
||||
|
||||
import logging as _logging
|
||||
import os as _os
|
||||
import sys as _sys
|
||||
|
||||
from tensorflow.python.tools import module_util as _module_util
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
|
||||
# API IMPORTS PLACEHOLDER
|
||||
|
||||
from tensorflow.python.tools import component_api_helper as _component_api_helper
|
||||
_component_api_helper.package_hook(
|
||||
parent_package_str=__name__,
|
||||
child_package_str=('tensorboard.summary._tf.summary'),
|
||||
error_msg=(
|
||||
"Limited tf.compat.v2.summary API due to missing TensorBoard "
|
||||
"installation"))
|
||||
_component_api_helper.package_hook(
|
||||
parent_package_str=__name__,
|
||||
child_package_str=(
|
||||
'tensorflow_estimator.python.estimator.api._v2.estimator'))
|
||||
_component_api_helper.package_hook(
|
||||
parent_package_str=__name__,
|
||||
child_package_str=('tensorflow.python.keras.api._v2.keras'))
|
||||
# Hook external TensorFlow modules.
|
||||
_current_module = _sys.modules[__name__]
|
||||
try:
|
||||
from tensorboard.summary._tf import summary
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
|
||||
except ImportError:
|
||||
_logging.warning(
|
||||
"Limited tf.compat.v2.summary API due to missing TensorBoard "
|
||||
"installation.")
|
||||
|
||||
try:
|
||||
from tensorflow_estimator.python.estimator.api._v2 import estimator
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from tensorflow.python.keras.api._v2 import keras
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
# We would like the following to work for fully enabling 2.0 in a 1.0 install:
|
||||
#
|
||||
|
@ -19,18 +19,30 @@ from __future__ import division as _division
|
||||
from __future__ import print_function as _print_function
|
||||
|
||||
import os as _os
|
||||
import sys as _sys
|
||||
|
||||
from tensorflow.python.tools import module_util as _module_util
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
|
||||
# API IMPORTS PLACEHOLDER
|
||||
|
||||
from tensorflow.python.tools import component_api_helper as _component_api_helper
|
||||
_component_api_helper.package_hook(
|
||||
parent_package_str=__name__,
|
||||
child_package_str=(
|
||||
'tensorflow_estimator.python.estimator.api._v1.estimator'))
|
||||
_component_api_helper.package_hook(
|
||||
parent_package_str=__name__,
|
||||
child_package_str=('tensorflow.python.keras.api._v1.keras'))
|
||||
# Hook external TensorFlow modules.
|
||||
_current_module = _sys.modules[__name__]
|
||||
try:
|
||||
from tensorflow_estimator.python.estimator.api._v1 import estimator
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from tensorflow.python.keras.api._v1 import keras
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
|
||||
app.flags = flags # pylint: disable=undefined-variable
|
||||
|
@ -263,38 +263,23 @@ Status GenVariableMethods(const tf2xla::Config& config,
|
||||
void set_var_{{NAME}}_data({{MAYBE_CONST}}{{TYPE}}* data) {
|
||||
set_arg_data({{I}}, data);
|
||||
}
|
||||
)";
|
||||
const tf2xla::Variable& var = config.variable(i - config.feed_size());
|
||||
rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");
|
||||
*methods += RewriteWithName(
|
||||
var.name().empty() ? var.node_name() : var.name(), code, rewrites);
|
||||
{{MAYBE_CONST}}{{TYPE}}* var_{{NAME}}_data() {
|
||||
return static_cast<{{MAYBE_CONST}}{{TYPE}}*>(arg_data({{I}}));
|
||||
}
|
||||
size_t num_results = ps.result().tuple_shapes_size();
|
||||
int variable_num = -1;
|
||||
for (int i = config.fetch_size(); i < num_results; ++i) {
|
||||
std::vector<std::pair<string, string>> rewrites;
|
||||
TF_RETURN_IF_ERROR(AddRewritesForShape(
|
||||
i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites));
|
||||
string code = R"(
|
||||
{{TYPE}}* var_{{NAME}}_data() {
|
||||
return static_cast<{{TYPE}}*>(result_data({{I}}));
|
||||
}
|
||||
{{TYPE}}& var_{{NAME}}({{DIM_VARS}}) {
|
||||
return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
|
||||
result_data({{I}}))){{INDICES}};
|
||||
{{MAYBE_CONST}}{{TYPE}}& var_{{NAME}}({{DIM_VARS}}) {
|
||||
return (*static_cast<{{MAYBE_CONST}}{{TYPE}}(*){{DIM_SIZES}}>(
|
||||
arg_data({{I}}))){{INDICES}};
|
||||
}
|
||||
const {{TYPE}}* var_{{NAME}}_data() const {
|
||||
return static_cast<const {{TYPE}}*>(result_data({{I}}));
|
||||
return static_cast<const {{TYPE}}*>(arg_data({{I}}));
|
||||
}
|
||||
const {{TYPE}}& var_{{NAME}}({{DIM_VARS}}) const {
|
||||
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
|
||||
result_data({{I}}))){{INDICES}};
|
||||
arg_data({{I}}))){{INDICES}};
|
||||
}
|
||||
)";
|
||||
do {
|
||||
++variable_num;
|
||||
} while (config.variable(variable_num).readonly());
|
||||
const tf2xla::Variable& var = config.variable(variable_num);
|
||||
const tf2xla::Variable& var = config.variable(i - config.feed_size());
|
||||
rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");
|
||||
*methods += RewriteWithName(
|
||||
var.name().empty() ? var.node_name() : var.name(), code, rewrites);
|
||||
}
|
||||
@ -549,7 +534,8 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return *kStaticData;
|
||||
}
|
||||
|
||||
{{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS)
|
||||
{{CLASS}}(AllocMode alloc_mode =
|
||||
AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS)
|
||||
: XlaCompiledCpuFunction(StaticData(), alloc_mode) {}
|
||||
|
||||
{{CLASS}}(const {{CLASS}}&) = delete;
|
||||
@ -590,19 +576,29 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
||||
// buffers are managed internally, and may change after each call to Run.
|
||||
{{METHODS_RESULT}}
|
||||
|
||||
// Methods for managing variable buffers. Buffers are in row-major order. The
|
||||
// input and output buffers may or may not be identical.
|
||||
// Methods for managing variable buffers. Buffers are in row-major order.
|
||||
//
|
||||
// For read-write variables we generate the following methods:
|
||||
//
|
||||
// void set_var_X_data(T* data)
|
||||
// Sets the buffer for variable X.
|
||||
// Sets the buffer for variable X. Must be called before Run if the
|
||||
// allocation mode is RESULTS_PROFILES_AND_TEMPS_ONLY.
|
||||
//
|
||||
// T* var_X_data()
|
||||
// Returns the buffer of type T for variable X.
|
||||
// Returns the buffer of type T for variable X. If the allocation mode is
|
||||
// RESULTS_PROFILES_AND_TEMPS_ONLY then this buffer is the same as the
|
||||
// buffer passed to set_var_X_data.
|
||||
//
|
||||
// T& var_X(...dim indices...)
|
||||
// Returns a reference to the value of type T for variable X,
|
||||
// with dim indices specifying which value. No bounds checking is performed
|
||||
// on dim indices.
|
||||
//
|
||||
// For readonly variables we generate the same set of methods, except that we
|
||||
// use `const T` instead of `T`. We use `const T` to avoid erasing the
|
||||
// constness of the buffer passed to `set_var_X_data` but the underlying
|
||||
// buffer is not const (and thus the const can be safely const-cast'ed away)
|
||||
// unless `set_var_X_data` is called with a pointer to constant storage.
|
||||
{{METHODS_VARIABLE}}
|
||||
|
||||
private:
|
||||
|
@ -91,7 +91,8 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return *kStaticData;
|
||||
}
|
||||
|
||||
MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS)
|
||||
MyClass(AllocMode alloc_mode =
|
||||
AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS)
|
||||
: XlaCompiledCpuFunction(StaticData(), alloc_mode) {}
|
||||
|
||||
MyClass(const MyClass&) = delete;
|
||||
@ -214,60 +215,82 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
result_data(0)))[dim0][dim1];
|
||||
}
|
||||
|
||||
// Methods for managing variable buffers. Buffers are in row-major order. The
|
||||
// input and output buffers may or may not be identical.
|
||||
// Methods for managing variable buffers. Buffers are in row-major order.
|
||||
//
|
||||
// For read-write variables we generate the following methods:
|
||||
//
|
||||
// void set_var_X_data(T* data)
|
||||
// Sets the buffer for variable X.
|
||||
// Sets the buffer for variable X. Must be called before Run if the
|
||||
// allocation mode is RESULTS_PROFILES_AND_TEMPS_ONLY.
|
||||
//
|
||||
// T* var_X_data()
|
||||
// Returns the buffer of type T for variable X.
|
||||
// Returns the buffer of type T for variable X. If the allocation mode is
|
||||
// RESULTS_PROFILES_AND_TEMPS_ONLY then this buffer is the same as the
|
||||
// buffer passed to set_var_X_data.
|
||||
//
|
||||
// T& var_X(...dim indices...)
|
||||
// Returns a reference to the value of type T for variable X,
|
||||
// with dim indices specifying which value. No bounds checking is performed
|
||||
// on dim indices.
|
||||
//
|
||||
// For readonly variables we generate the same set of methods, except that we
|
||||
// use `const T` instead of `T`. We use `const T` to avoid erasing the
|
||||
// constness of the buffer passed to `set_var_X_data` but the underlying
|
||||
// buffer is not const (and thus the const can be safely const-cast'ed away)
|
||||
// unless `set_var_X_data` is called with a pointer to constant storage.
|
||||
|
||||
void set_var_myvar_readonly_data(const float* data) {
|
||||
set_arg_data(2, data);
|
||||
}
|
||||
const float* var_myvar_readonly_data() {
|
||||
return static_cast<const float*>(arg_data(2));
|
||||
}
|
||||
const float& var_myvar_readonly() {
|
||||
return (*static_cast<const float(*)[1]>(
|
||||
arg_data(2)))[0];
|
||||
}
|
||||
const float* var_myvar_readonly_data() const {
|
||||
return static_cast<const float*>(arg_data(2));
|
||||
}
|
||||
const float& var_myvar_readonly() const {
|
||||
return (*static_cast<const float(*)[1]>(
|
||||
arg_data(2)))[0];
|
||||
}
|
||||
|
||||
void set_var_myvar_data(float* data) {
|
||||
set_arg_data(3, data);
|
||||
}
|
||||
float* var_myvar_data() {
|
||||
return static_cast<float*>(arg_data(3));
|
||||
}
|
||||
float& var_myvar() {
|
||||
return (*static_cast<float(*)[1]>(
|
||||
arg_data(3)))[0];
|
||||
}
|
||||
const float* var_myvar_data() const {
|
||||
return static_cast<const float*>(arg_data(3));
|
||||
}
|
||||
const float& var_myvar() const {
|
||||
return (*static_cast<const float(*)[1]>(
|
||||
arg_data(3)))[0];
|
||||
}
|
||||
|
||||
void set_var_myvar2_data(tensorflow::int32* data) {
|
||||
set_arg_data(4, data);
|
||||
}
|
||||
|
||||
float* var_myvar_data() {
|
||||
return static_cast<float*>(result_data(1));
|
||||
}
|
||||
float& var_myvar() {
|
||||
return (*static_cast<float(*)[1]>(
|
||||
result_data(1)))[0];
|
||||
}
|
||||
const float* var_myvar_data() const {
|
||||
return static_cast<const float*>(result_data(1));
|
||||
}
|
||||
const float& var_myvar() const {
|
||||
return (*static_cast<const float(*)[1]>(
|
||||
result_data(1)))[0];
|
||||
}
|
||||
|
||||
tensorflow::int32* var_myvar2_data() {
|
||||
return static_cast<tensorflow::int32*>(result_data(2));
|
||||
return static_cast<tensorflow::int32*>(arg_data(4));
|
||||
}
|
||||
tensorflow::int32& var_myvar2(size_t dim0) {
|
||||
return (*static_cast<tensorflow::int32(*)[5]>(
|
||||
result_data(2)))[dim0];
|
||||
arg_data(4)))[dim0];
|
||||
}
|
||||
const tensorflow::int32* var_myvar2_data() const {
|
||||
return static_cast<const tensorflow::int32*>(result_data(2));
|
||||
return static_cast<const tensorflow::int32*>(arg_data(4));
|
||||
}
|
||||
const tensorflow::int32& var_myvar2(size_t dim0) const {
|
||||
return (*static_cast<const tensorflow::int32(*)[5]>(
|
||||
result_data(2)))[dim0];
|
||||
arg_data(4)))[dim0];
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -36,6 +36,7 @@ py_binary(
|
||||
name = "make_test_graphs",
|
||||
testonly = 1,
|
||||
srcs = ["make_test_graphs.py"],
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_py",
|
||||
|
@ -83,7 +83,8 @@ TEST(TFCompileTest, Add) {
|
||||
// Run tests that use set_argN_data separately, to avoid accidentally re-using
|
||||
// non-existent buffers.
|
||||
TEST(TFCompileTest, Add_SetArg) {
|
||||
AddComp add(AddComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
|
||||
AddComp add(
|
||||
XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
|
||||
|
||||
int32 arg_x = 10;
|
||||
int32 arg_y = 32;
|
||||
@ -296,7 +297,7 @@ TEST(TFCompileTest, MatMul2_SetArg) {
|
||||
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
||||
|
||||
foo::bar::MatMulComp matmul(
|
||||
foo::bar::MatMulComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
|
||||
XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
|
||||
matmul.set_thread_pool(&device);
|
||||
|
||||
// Test using the set_argN_data() methods.
|
||||
@ -503,8 +504,36 @@ TEST(TFCompileTest, VariableSequentialUpdates) {
|
||||
|
||||
// This implements the recursion:
|
||||
// x[0] = 2.0
|
||||
// x[n+1] = x[n] - 0.1*(x[n-1] + 1.0)
|
||||
// x[n+1] = x[n] - 0.1*(x[n-1] + y)
|
||||
VariableSequentialUpdatesComp fn;
|
||||
fn.var_x() = 2;
|
||||
*const_cast<float*>(fn.var_y_data()) = 1;
|
||||
|
||||
fn.set_thread_pool(&device);
|
||||
// First calculate x[3]
|
||||
fn.Run();
|
||||
EXPECT_NEAR(fn.var_x(), 1.187f, 1e-6);
|
||||
|
||||
const float y = 1;
|
||||
fn.set_var_y_data(&y);
|
||||
|
||||
// Now const_cast<float*>(fn.var_y_data()) is not longer legal since we've set
|
||||
// the buffer to point to a constant location.
|
||||
|
||||
// Then calculate x[6]
|
||||
fn.Run();
|
||||
EXPECT_NEAR(fn.var_x(), 0.594322f, 1e-6);
|
||||
}
|
||||
|
||||
TEST(TFCompileTest, VariableSequentialUpdatesNoAlloc) {
|
||||
Eigen::ThreadPool tp(1);
|
||||
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
||||
|
||||
// This implements the recursion:
|
||||
// x[0] = 2.0
|
||||
// x[n+1] = x[n] - 0.1*(x[n-1] + 1.0)
|
||||
VariableSequentialUpdatesComp fn(
|
||||
XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
|
||||
float x = 2;
|
||||
float y = 1;
|
||||
fn.set_var_x_data(&x);
|
||||
|
@ -174,6 +174,20 @@ def tf_library(
|
||||
"'" + arg.replace("'", "'\\''") + "'"
|
||||
for arg in (tfcompile_flags or [])
|
||||
])
|
||||
|
||||
# Do this before we append the `select` into `flags`, because doing so
|
||||
# transforms `flags` into a variable of type `select`, and we can't call
|
||||
# `find` on such an object.
|
||||
need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1
|
||||
|
||||
# Pass --target_cpu=haswell to tfcompile if compiling for Haswell (bazel
|
||||
# build --cpu=haswell). We put it at the beginning of the flags list so
|
||||
# that tfcompile_flags can override if if desired.
|
||||
flags = select({
|
||||
"//tools/target_cpu:haswell": "--target_cpu=haswell ",
|
||||
"//conditions:default": "",
|
||||
}) + flags
|
||||
|
||||
if enable_xla_hlo_profiling:
|
||||
profiling_flag = "--xla_hlo_profile"
|
||||
else:
|
||||
@ -251,7 +265,6 @@ def tf_library(
|
||||
|
||||
# The cc_library rule packaging up the header and object file, and needed
|
||||
# kernel implementations.
|
||||
need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1)
|
||||
native.cc_library(
|
||||
name = name,
|
||||
srcs = [function_object_file, metadata_object_file],
|
||||
|
@ -17,15 +17,14 @@ package_group(
|
||||
package(
|
||||
default_visibility = [
|
||||
":internal",
|
||||
# BEGIN-GOOGLE-INTERNAL
|
||||
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
|
||||
# END-GOOGLE-INTERNAL
|
||||
],
|
||||
)
|
||||
|
||||
# NB! Removing the cc_header_only_library import breaks the OSS build since
|
||||
# copybara injects some build rules that use it.
|
||||
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
|
||||
|
||||
# Target that bundles up the XLA CPU and GPU JIT devices.
|
||||
@ -212,6 +211,7 @@ cc_library(
|
||||
"//tensorflow/core/kernels/data:iterator_ops",
|
||||
"//tensorflow/core/kernels/data:optional_ops",
|
||||
"//tensorflow/core/kernels/data:prefetch_dataset_op",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//tensorflow/stream_executor/platform",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
@ -223,6 +223,7 @@ cc_library(
|
||||
name = "shape_inference_helpers",
|
||||
srcs = ["shape_inference_helpers.cc"],
|
||||
hdrs = ["shape_inference_helpers.h"],
|
||||
visibility = [":friends"],
|
||||
deps = ["//tensorflow/core:graph"],
|
||||
)
|
||||
|
||||
@ -256,6 +257,11 @@ cc_library(
|
||||
name = "xla_launch_util",
|
||||
srcs = ["xla_launch_util.cc"],
|
||||
hdrs = ["xla_launch_util.h"],
|
||||
# TODO(skyewm): remove this once XlaAllocator is factored out.
|
||||
visibility = [
|
||||
":internal",
|
||||
"//tensorflow/compiler/xla/python:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":common",
|
||||
":xla_compilation_cache",
|
||||
@ -265,7 +271,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/service:device_memory_allocator",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
@ -273,6 +278,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/memory",
|
||||
@ -468,6 +474,9 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
@ -518,8 +527,9 @@ cc_library(
|
||||
"partially_decluster_pass.h",
|
||||
],
|
||||
deps = [
|
||||
"compilability_check_util",
|
||||
":common",
|
||||
":device_info_cache",
|
||||
":device_util",
|
||||
":encapsulate_util",
|
||||
":flags",
|
||||
":resource_operation_safety_analysis",
|
||||
@ -581,21 +591,35 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "device_info_cache",
|
||||
srcs = ["device_info_cache.cc"],
|
||||
hdrs = ["device_info_cache.h"],
|
||||
name = "device_util",
|
||||
srcs = ["device_util.cc"],
|
||||
hdrs = ["device_util.h"],
|
||||
deps = [
|
||||
":xla_cluster_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:framework",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "device_util_test",
|
||||
srcs = ["device_util_test.cc"],
|
||||
deps = [
|
||||
":device_util",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
@ -661,6 +685,7 @@ tf_cc_test(
|
||||
"introduce_floating_point_jitter_pass_test.cc",
|
||||
"mark_for_compilation_pass_test.cc",
|
||||
"partially_decluster_pass_test.cc",
|
||||
"rearrange_function_argument_pass_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":common",
|
||||
@ -681,6 +706,7 @@ tf_cc_test(
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/cc:sendrecv_ops",
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla:rearrange_function_argument",
|
||||
"//tensorflow/compiler/tf2xla:side_effect_util",
|
||||
"//tensorflow/compiler/tf2xla:test_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
@ -764,6 +790,34 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "compilability_check_util",
|
||||
srcs = ["compilability_check_util.cc"],
|
||||
hdrs = ["compilability_check_util.h"],
|
||||
deps = [
|
||||
":common",
|
||||
":device_util",
|
||||
":flags",
|
||||
":resource_operation_safety_analysis",
|
||||
":union_find",
|
||||
":xla_cluster_util",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/tf2xla:resource_operation_table",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
tf_custom_op_py_library(
|
||||
name = "xla_ops_py",
|
||||
kernels = ["//tensorflow/compiler/jit/ops:xla_ops"],
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
@ -25,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/ops/functional_ops.h"
|
||||
#include "tensorflow/cc/ops/logging_ops.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/device_util.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
@ -231,14 +233,10 @@ void RemoveAllIncomingControlEdges(Graph* g, Node* n) {
|
||||
}
|
||||
|
||||
// Returns true (into `result`) if a node placed on `device` must be compiled.
|
||||
Status DeviceRequiresCompilation(const string& device, bool* result) {
|
||||
DeviceType device_type("");
|
||||
TF_RETURN_IF_ERROR(DeviceToDeviceType(device, &device_type));
|
||||
const XlaOpRegistry::DeviceRegistration* registration = nullptr;
|
||||
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) {
|
||||
return errors::Internal("Could not find compilation device ",
|
||||
device_type.type());
|
||||
}
|
||||
Status DeviceRequiresCompilation(const jit::DeviceInfoCache& device_info_cache,
|
||||
jit::DeviceId device, bool* result) {
|
||||
const XlaOpRegistry::DeviceRegistration* registration =
|
||||
device_info_cache.GetCompilationDevice(device);
|
||||
*result = registration->autoclustering_policy ==
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
||||
return Status::OK();
|
||||
@ -291,17 +289,20 @@ Status ReplaceFunctionCallWithPartionedCall(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status InferDeviceForCluster(Node* n, const string& function_name,
|
||||
const FunctionLibraryDefinition& flib_def,
|
||||
string* result) {
|
||||
xla::StatusOr<jit::DeviceId> InferDeviceForCluster(
|
||||
jit::DeviceInfoCache* device_info_cache, Node* n,
|
||||
const string& function_name, const FunctionLibraryDefinition& flib_def) {
|
||||
const FunctionDef* func_def = flib_def.Find(function_name);
|
||||
TF_RET_CHECK(func_def) << "Could not find " << function_name;
|
||||
|
||||
std::set<string> device_names;
|
||||
jit::DeviceSet device_set;
|
||||
|
||||
for (const NodeDef& ndef : func_def->node_def()) {
|
||||
VLOG(3) << ndef.DebugString();
|
||||
if (!ndef.device().empty()) {
|
||||
device_names.insert(ndef.device());
|
||||
TF_ASSIGN_OR_RETURN(jit::DeviceId device_id,
|
||||
device_info_cache->GetIdFor(ndef.device()));
|
||||
device_set.Insert(device_id);
|
||||
}
|
||||
}
|
||||
|
||||
@ -309,41 +310,47 @@ Status InferDeviceForCluster(Node* n, const string& function_name,
|
||||
// TODO(sanjoy): We need this because EncapsulateSubgraphsPass drops device
|
||||
// assignment when constant folding. We should fix EncapsulateSubgraphsPass
|
||||
// instead.
|
||||
device_names.insert(n->assigned_device_name());
|
||||
TF_ASSIGN_OR_RETURN(jit::DeviceId device_id,
|
||||
device_info_cache->GetIdFor(n->assigned_device_name()));
|
||||
device_set.Insert(device_id);
|
||||
}
|
||||
|
||||
std::vector<string> device_names_vector;
|
||||
absl::c_copy(device_names, std::back_inserter(device_names_vector));
|
||||
|
||||
Status s = PickDeviceForXla(device_names_vector, true, result);
|
||||
if (s.ok()) {
|
||||
VLOG(2) << "For " << function_name << " PickDeviceForXla("
|
||||
<< absl::StrJoin(device_names_vector, ", ") << ") -> " << *result;
|
||||
}
|
||||
return s;
|
||||
TF_ASSIGN_OR_RETURN(jit::DeviceId result,
|
||||
PickDeviceForXla(*device_info_cache, device_set,
|
||||
/*allow_mixing_unknown_and_cpu=*/true));
|
||||
VLOG(2) << "For " << function_name << " PickDeviceForXla("
|
||||
<< device_info_cache->DebugString(device_set) << ") -> "
|
||||
<< device_info_cache->GetNameFor(result);
|
||||
return result;
|
||||
}
|
||||
|
||||
Status ReplaceNodeWithXlaCompileAndXlaRun(
|
||||
jit::DeviceInfoCache* device_info_cache,
|
||||
const GraphOptimizationPassOptions& options,
|
||||
const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled,
|
||||
bool insert_print_nodes, Graph* g, Node* n) {
|
||||
XlaClusterInfo cluster_info;
|
||||
TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
|
||||
|
||||
string device;
|
||||
TF_RETURN_IF_ERROR(InferDeviceForCluster(n, cluster_info.function.name(),
|
||||
flib_def, &device));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
jit::DeviceId device,
|
||||
InferDeviceForCluster(device_info_cache, n, cluster_info.function.name(),
|
||||
flib_def));
|
||||
|
||||
bool requires_compilation;
|
||||
TF_RETURN_IF_ERROR(DeviceRequiresCompilation(device, &requires_compilation));
|
||||
TF_RETURN_IF_ERROR(DeviceRequiresCompilation(*device_info_cache, device,
|
||||
&requires_compilation));
|
||||
if (!lazy_compilation_enabled) {
|
||||
requires_compilation = true;
|
||||
}
|
||||
|
||||
string device_name_str = string(device_info_cache->GetNameFor(device));
|
||||
|
||||
Status status;
|
||||
Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
|
||||
.NewSubScope(n->name())
|
||||
.WithDevice(n->requested_device())
|
||||
.WithAssignedDevice(device);
|
||||
.WithAssignedDevice(device_name_str);
|
||||
|
||||
ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"),
|
||||
/*constants=*/cluster_info.constant_inputs,
|
||||
@ -435,14 +442,16 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
|
||||
bool lazy_compilation_enabled =
|
||||
enable_lazy_compilation_
|
||||
? *enable_lazy_compilation_
|
||||
: GetBuildXlaOpsPassFlags().tf_xla_enable_lazy_compilation;
|
||||
: GetBuildXlaOpsPassFlags()->tf_xla_enable_lazy_compilation;
|
||||
bool insert_print_nodes =
|
||||
GetBuildXlaOpsPassFlags().tf_xla_print_cluster_outputs;
|
||||
GetBuildXlaOpsPassFlags()->tf_xla_print_cluster_outputs;
|
||||
|
||||
jit::DeviceInfoCache device_info_cache;
|
||||
|
||||
for (Node* n : xla_compiled_kernels) {
|
||||
TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
|
||||
options, *options.flib_def, lazy_compilation_enabled,
|
||||
insert_print_nodes, graph, n));
|
||||
&device_info_cache, options, *options.flib_def,
|
||||
lazy_compilation_enabled, insert_print_nodes, graph, n));
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(1)) {
|
||||
|
277
tensorflow/compiler/jit/compilability_check_util.cc
Normal file
277
tensorflow/compiler/jit/compilability_check_util.cc
Normal file
@ -0,0 +1,277 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/compilability_check_util.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <deque>
|
||||
#include <limits>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/device_util.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
|
||||
#include "tensorflow/compiler/jit/union_find.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/memory_types.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/control_flow.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
#include "tensorflow/core/util/dump_graph.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
bool HasResourceInput(const Node& node) {
|
||||
return absl::c_count(node.input_types(), DT_RESOURCE) != 0;
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) {
|
||||
// There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
|
||||
// is really a kind of function call and will be handled by
|
||||
// IsCompilableCall().
|
||||
if (node.type_string() == "SymbolicGradient") return false;
|
||||
if (node.type_string() == "Const") {
|
||||
// Skip Const op with type DT_STRING, since XLA doesn't support it, but the
|
||||
// registered Const KernelDef says that it does, to support no-op Assert for
|
||||
// tfcompile.
|
||||
const AttrValue* attr = node.attrs().Find("dtype");
|
||||
if (attr != nullptr && attr->type() == DT_STRING) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// XLA does not offer guaranteed aliasing between the input and output of the
|
||||
// XLA cluster so it can't implement the forward-tensor-ref semantic. Leave
|
||||
// such nodes out of XLA clusters.
|
||||
if (HasForwardedRefInput(node)) {
|
||||
VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr).ok();
|
||||
}
|
||||
|
||||
// Tests whether 'while_node' is a completely compilable loop.
|
||||
// Every operator in the condition and body functions must be compilable for a
|
||||
// while loop to be compilable.
|
||||
bool RecursiveCompilabilityChecker::IsCompilableWhile(
|
||||
const Node& while_node, int depth, FunctionLibraryRuntime* lib_runtime) {
|
||||
const NameAttrList* name_attr;
|
||||
NodeDef call;
|
||||
Status status;
|
||||
status = GetNodeAttr(while_node.attrs(), "cond", &name_attr);
|
||||
if (!status.ok()) {
|
||||
VLOG(2) << "Rejecting While " << while_node.name()
|
||||
<< ": missing 'cond' attribute on While node.";
|
||||
return false;
|
||||
}
|
||||
const string cond_func = name_attr->name();
|
||||
call.set_name("while_cond");
|
||||
call.set_op(cond_func);
|
||||
*call.mutable_attr() = name_attr->attr();
|
||||
if (!IsCompilableCall(call, depth + 1, lib_runtime)) {
|
||||
VLOG(2) << "Rejecting While " << while_node.name()
|
||||
<< ": can't compile loop condition: " << cond_func;
|
||||
return false;
|
||||
}
|
||||
status = GetNodeAttr(while_node.attrs(), "body", &name_attr);
|
||||
if (!status.ok()) {
|
||||
VLOG(2) << "Rejecting While " << while_node.name()
|
||||
<< ": missing 'body' attribute on While node.";
|
||||
return false;
|
||||
}
|
||||
const string body_func = name_attr->name();
|
||||
call.set_name("while_body");
|
||||
call.set_op(body_func);
|
||||
*call.mutable_attr() = name_attr->attr();
|
||||
if (!IsCompilableCall(call, depth + 1, lib_runtime)) {
|
||||
VLOG(2) << "Rejecting While " << while_node.name()
|
||||
<< ": can't compile loop body: " << body_func;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Tests whether 'call_def' is a call to a completely compilable function.
|
||||
// Every operator in the function must be compilable for a function to be
|
||||
// compilable.
|
||||
bool RecursiveCompilabilityChecker::IsCompilableCall(
|
||||
const NodeDef& call_def, int depth, FunctionLibraryRuntime* lib_runtime) {
|
||||
if (depth > kMaxRecursionDepth) {
|
||||
VLOG(2) << "Rejecting " << call_def.op()
|
||||
<< ": function depth limit exceeded.";
|
||||
return false;
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
Status status = InstantiateFunctionCall(call_def, lib_runtime, &handle);
|
||||
if (!status.ok()) {
|
||||
VLOG(2) << "Rejecting " << call_def.DebugString()
|
||||
<< ": could not instantiate: " << status;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto release_handle_on_return = gtl::MakeCleanup(
|
||||
[&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
|
||||
|
||||
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
|
||||
for (Node* node : fbody->graph->op_nodes()) {
|
||||
if (!IsCompilableNode(*node, depth + 1, lib_runtime)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LogNotCompilableAndReturn(const Node& node,
|
||||
absl::string_view reason = "") {
|
||||
VLOG(3) << "Not clustering " << node.name() << " (op " << node.type_string()
|
||||
<< ")" << (reason.empty() ? "" : ": ") << reason;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool RecursiveCompilabilityChecker::OpIsInaccurate(const Node& node) {
|
||||
// b/127344411: SelfAdjointEigV2 and Svd precision issues.
|
||||
return node.type_string() == "SelfAdjointEigV2" ||
|
||||
node.type_string() == "Svd";
|
||||
}
|
||||
|
||||
bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) {
|
||||
// b/128001705: SelfAdjointEigV2 and Svd performance issues.
|
||||
return node.type_string() == "SelfAdjointEigV2" ||
|
||||
node.type_string() == "Svd" || node.type_string() == "Qr";
|
||||
}
|
||||
|
||||
bool RecursiveCompilabilityChecker::IsCompilableNode(
|
||||
const Node& node, int depth, FunctionLibraryRuntime* lib_runtime) {
|
||||
if (node.IsSource() || node.IsSink()) {
|
||||
return LogNotCompilableAndReturn(node, "source or sink node");
|
||||
}
|
||||
|
||||
// _Arg nodes in a top-level function represent feeds and _Retval nodes in a
|
||||
// top-level function represent fetches.
|
||||
if (depth == 0 &&
|
||||
(node.type_string() == "_Arg" || node.type_string() == "_Retval")) {
|
||||
return LogNotCompilableAndReturn(node, "depth is 0");
|
||||
}
|
||||
|
||||
if (node.attrs().Find("_scoped_allocator") ||
|
||||
node.attrs().Find("_forward_from")) {
|
||||
// TODO(b/128858118): XLA does not support _scoped_allocator and
|
||||
// _forward_from.
|
||||
return LogNotCompilableAndReturn(
|
||||
node, "_scoped_allocator or _forward_from attribute");
|
||||
}
|
||||
|
||||
if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), node)) {
|
||||
if (!IsCompilableCall(node.def(), depth + 1, lib_runtime)) {
|
||||
return LogNotCompilableAndReturn(node, "unsupported function");
|
||||
}
|
||||
} else if (!HasXLAKernel(node)) {
|
||||
return LogNotCompilableAndReturn(node, "unsupported op");
|
||||
}
|
||||
|
||||
if (node.type_string() == "While" &&
|
||||
!IsCompilableWhile(node, depth + 1, lib_runtime)) {
|
||||
return LogNotCompilableAndReturn(node, "unsupported while");
|
||||
}
|
||||
|
||||
if (!op_filter_.allow_stateful_rng_ops &&
|
||||
IsStatefulRandomOp(node.type_string())) {
|
||||
return LogNotCompilableAndReturn(node, "stateful random op");
|
||||
}
|
||||
|
||||
if (!op_filter_.allow_control_trigger && node.IsControlTrigger()) {
|
||||
return LogNotCompilableAndReturn(node);
|
||||
}
|
||||
|
||||
if (!op_filter_.allow_eliding_assert_and_checknumerics_ops &&
|
||||
IsAssertOrCheckNumerics(node.type_string())) {
|
||||
return LogNotCompilableAndReturn(node, "Assert or CheckNumerics");
|
||||
}
|
||||
|
||||
if (!op_filter_.allow_ops_producing_or_consuming_variant &&
|
||||
OpProducesOrConsumesVariant(node)) {
|
||||
return LogNotCompilableAndReturn(node, "DT_VARIANT producer/consumer");
|
||||
}
|
||||
|
||||
if (!op_filter_.allow_stack_ops && IsStackOp(node)) {
|
||||
return LogNotCompilableAndReturn(node, "Stack op");
|
||||
}
|
||||
|
||||
if (!op_filter_.allow_tensor_array_ops && IsTensorArrayOp(node)) {
|
||||
return LogNotCompilableAndReturn(node, "TensorArray op");
|
||||
}
|
||||
|
||||
if (!op_filter_.allow_resource_ops_in_called_functions && depth > 0 &&
|
||||
HasResourceInput(node)) {
|
||||
return LogNotCompilableAndReturn(node,
|
||||
"resource variable op in called function");
|
||||
}
|
||||
|
||||
if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsInaccurate(node)) {
|
||||
return LogNotCompilableAndReturn(node, "operation with correctness issues");
|
||||
}
|
||||
|
||||
if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsSlow(node)) {
|
||||
return LogNotCompilableAndReturn(node, "slow operation");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
|
||||
const XlaOpRegistry::DeviceRegistration& registration) {
|
||||
RecursiveCompilabilityChecker::OperationFilter op_filter;
|
||||
op_filter.allow_resource_ops_in_called_functions =
|
||||
registration.cluster_resource_variable_ops_unsafely;
|
||||
op_filter.allow_stack_ops = registration.cluster_stack_ops;
|
||||
op_filter.allow_tensor_array_ops = registration.cluster_tensor_array_ops;
|
||||
op_filter.allow_stateful_rng_ops = registration.cluster_stateful_rng_ops;
|
||||
op_filter.allow_control_trigger = registration.cluster_control_trigger;
|
||||
op_filter.allow_eliding_assert_and_checknumerics_ops =
|
||||
registration.elide_assert_and_checknumerics;
|
||||
op_filter.allow_ops_producing_or_consuming_variant =
|
||||
registration.cluster_variant_ops;
|
||||
op_filter.allow_slow_and_inaccurate_ops =
|
||||
registration.cluster_slow_and_inaccurate_ops;
|
||||
return op_filter;
|
||||
}
|
||||
|
||||
|
||||
} // namespace tensorflow
|
175
tensorflow/compiler/jit/compilability_check_util.h
Normal file
175
tensorflow/compiler/jit/compilability_check_util.h
Normal file
@ -0,0 +1,175 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/device_util.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
|
||||
#include "tensorflow/compiler/jit/union_find.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/memory_types.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/control_flow.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
#include "tensorflow/core/util/dump_graph.h"
|
||||
|
||||
namespace tensorflow {
|
||||
// Checks whether a TF node can be compiled or not. "Recursive" as in for call
|
||||
// and functional while nodes it recursively checks whether the callee functions
|
||||
// can be compiled.
|
||||
class RecursiveCompilabilityChecker {
|
||||
public:
|
||||
// Aggregates information about what kinds of ops are allowed.
|
||||
struct OperationFilter { // TODO(lzr): Add AllowEverything() helper.
|
||||
// Whether resource variable ops are allowed are allowed in callees. We do
|
||||
// not allow resource variable ops in called functions (either as direct TF
|
||||
// calls or as higher order control flow ops) because we do not yet model
|
||||
// their memory effects in jit/resource_variable_safety_analysis.
|
||||
bool allow_resource_ops_in_called_functions;
|
||||
|
||||
// Whether Stack operations are allowed. We avoid auto-clustering Stack
|
||||
// operations in general because we do not support snapshotting them.
|
||||
//
|
||||
// TODO(b/112837194): This restriction can be lifted with some work.
|
||||
bool allow_stack_ops;
|
||||
|
||||
// Whether TensorArray operations are allowed. We avoid auto-clustering
|
||||
// TensorArray operations in general because we do not support snapshotting
|
||||
// them.
|
||||
//
|
||||
// TODO(b/112837194): This restriction can be lifted with some work.
|
||||
bool allow_tensor_array_ops;
|
||||
|
||||
// Whether stateful RNG ops are allowed. XLA's RNG does not have the same
|
||||
// seeding behavior as TensorFlow's RNG (b/34749654). So we avoid
|
||||
// auto-clustering stateful RNG ops.
|
||||
bool allow_stateful_rng_ops;
|
||||
|
||||
// TODO(b/118970344): Whether ControlTrigger ops are allowed. It is unsound
|
||||
// to cluster ControlTrigger because of how we use deadness analysis.
|
||||
bool allow_control_trigger;
|
||||
|
||||
// Whether it is okay to "cluster" Assert and CheckNumerics by simply
|
||||
// removing them (they're not removed during clustering, but their
|
||||
// XlaOpKernel is a no-op kernel). We avoid auto-clustering these ops so
|
||||
// that the user is not surprised when XLA is implicitly enabled. If the
|
||||
// user explicitly specifies to use XLA, it is fine to resort to a dummy
|
||||
// implementation. Currently Assert and CheckNumerics ops have dummy XLA
|
||||
// implementations.
|
||||
bool allow_eliding_assert_and_checknumerics_ops;
|
||||
|
||||
// Whether ops that produce or consume DT_VARIANT values are allowed. We
|
||||
// don't auto-cluster these ops because we don't yet support live-in or
|
||||
// live-out DT_VARIANT values.
|
||||
bool allow_ops_producing_or_consuming_variant;
|
||||
|
||||
// Whether ops known to be slow or to have correctness issues should be
|
||||
// auto-clustered.
|
||||
bool allow_slow_and_inaccurate_ops;
|
||||
};
|
||||
|
||||
RecursiveCompilabilityChecker(const OperationFilter* op_filter,
|
||||
const DeviceType* jit_device_type)
|
||||
: op_filter_(*op_filter), jit_device_type_(*jit_device_type) {}
|
||||
|
||||
// Returns true if `node` can be compiled by XLA.
|
||||
bool IsCompilableNode(const Node& node, FunctionLibraryRuntime* lib_runtime) {
|
||||
return IsCompilableNode(node, /*depth=*/0, lib_runtime);
|
||||
}
|
||||
|
||||
// Returns true if `call_def` can be compiled by XLA. It is assumed that
|
||||
// `call_def` is a call operation.
|
||||
bool IsCompilableCall(const NodeDef& call_def,
|
||||
FunctionLibraryRuntime* lib_runtime) {
|
||||
return IsCompilableCall(call_def, /*depth=*/0, lib_runtime);
|
||||
}
|
||||
|
||||
// Returns true if XLA supports this Op, but we don't want to cluster it (ie:
|
||||
// due to performance or correctness concerns).
|
||||
bool OpIsInaccurate(const Node& node);
|
||||
bool OpIsSlow(const Node& node);
|
||||
|
||||
private:
|
||||
bool IsCompilableNode(const Node& node, int depth,
|
||||
FunctionLibraryRuntime* lib_runtime);
|
||||
bool IsCompilableCall(const NodeDef& call_def, int depth,
|
||||
FunctionLibraryRuntime* lib_runtime);
|
||||
bool IsCompilableWhile(const Node& while_node, int depth,
|
||||
FunctionLibraryRuntime* lib_runtime);
|
||||
|
||||
bool IsStackOp(const Node& node) {
|
||||
const XlaResourceOpInfo* op_info =
|
||||
GetResourceOpInfoForOp(node.type_string());
|
||||
return op_info && op_info->resource_kind() == XlaResourceKind::kStack;
|
||||
}
|
||||
|
||||
bool IsTensorArrayOp(const Node& node) {
|
||||
const XlaResourceOpInfo* op_info =
|
||||
GetResourceOpInfoForOp(node.type_string());
|
||||
return op_info && op_info->resource_kind() == XlaResourceKind::kTensorArray;
|
||||
}
|
||||
|
||||
bool IsAssertOrCheckNumerics(absl::string_view op_name) {
|
||||
return op_name == "Assert" || op_name == "CheckNumerics";
|
||||
}
|
||||
|
||||
bool IsStatefulRandomOp(absl::string_view op_name) {
|
||||
return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
|
||||
op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" ||
|
||||
op_name == "TruncatedNormal" || op_name == "Multinomial";
|
||||
}
|
||||
|
||||
bool OpProducesOrConsumesVariant(const Node& node) {
|
||||
auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; };
|
||||
return absl::c_any_of(node.input_types(), is_variant) ||
|
||||
absl::c_any_of(node.output_types(), is_variant);
|
||||
}
|
||||
|
||||
bool HasXLAKernel(const Node& node);
|
||||
|
||||
// Make sure we don't recurse infinitely on recursive functions.
|
||||
const int kMaxRecursionDepth = 10;
|
||||
|
||||
const OperationFilter& op_filter_;
|
||||
const DeviceType& jit_device_type_;
|
||||
};
|
||||
|
||||
RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
|
||||
const XlaOpRegistry::DeviceRegistration& registration);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
|
@ -371,7 +371,8 @@ class PredicateFactory {
|
||||
Predicate** predicate) {
|
||||
TensorId tensor_id(node->name(), output_idx);
|
||||
|
||||
bool is_boolean_tensor = node->output_type(tensor_id.index()) == DT_BOOL;
|
||||
bool is_boolean_tensor =
|
||||
BaseType(node->output_type(tensor_id.index())) == DT_BOOL;
|
||||
TF_RET_CHECK(!must_be_true || is_boolean_tensor);
|
||||
|
||||
if (node->type_string() == "Const" && must_be_true) {
|
||||
|
@ -1067,5 +1067,25 @@ TEST(DeadnessAnalysisTest, ConstantFalseSwitchCondition) {
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "#false");
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, RefBoolSwitchCondition) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
Output condition_ref_var =
|
||||
ops::Variable(root.WithOpName("cond_ref"), TensorShape({}), DT_BOOL);
|
||||
Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
|
||||
ops::Switch sw(root.WithOpName("switch"), value, condition_ref_var);
|
||||
|
||||
Output id_false = ops::Identity(root.WithOpName("id_false"), sw.output_false);
|
||||
Output id_true = ops::Identity(root.WithOpName("id_true"), sw.output_true);
|
||||
|
||||
FixupSourceAndSinkEdges(root.graph());
|
||||
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(id_false)], "~*cond_ref:0");
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "*cond_ref:0");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -1,61 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/device_info_cache.h"
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
using xla::StatusOr;
|
||||
|
||||
StatusOr<const XlaOpRegistry::DeviceRegistration*>
|
||||
DeviceInfoCache::GetCompilationDevice(absl::string_view device_name) {
|
||||
auto it = device_to_device_registration_.find(device_name);
|
||||
if (it != device_to_device_registration_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
string device_name_str = string(device_name);
|
||||
TF_ASSIGN_OR_RETURN(const DeviceType& device_type,
|
||||
GetDeviceTypeFor(device_name_str));
|
||||
const XlaOpRegistry::DeviceRegistration* registration;
|
||||
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) {
|
||||
registration = nullptr;
|
||||
}
|
||||
|
||||
device_to_device_registration_.insert(
|
||||
{std::move(device_name_str), registration});
|
||||
|
||||
return registration;
|
||||
}
|
||||
|
||||
StatusOr<std::reference_wrapper<const DeviceType>>
|
||||
DeviceInfoCache::GetDeviceTypeFor(absl::string_view device_name) {
|
||||
auto it = device_to_device_type_.find(device_name);
|
||||
if (it != device_to_device_type_.end()) {
|
||||
return std::cref(*it->second);
|
||||
}
|
||||
|
||||
string device_name_str = string(device_name);
|
||||
auto device_type = absl::make_unique<DeviceType>("");
|
||||
TF_RETURN_IF_ERROR(DeviceToDeviceType(device_name_str, device_type.get()));
|
||||
|
||||
it = device_to_device_type_
|
||||
.insert({std::move(device_name_str), std::move(device_type)})
|
||||
.first;
|
||||
return std::cref(*it->second);
|
||||
}
|
||||
} // namespace tensorflow
|
@ -1,45 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
// Caches some miscellaneous information about TF devices. Thread compatible.
|
||||
class DeviceInfoCache {
|
||||
public:
|
||||
xla::StatusOr<const XlaOpRegistry::DeviceRegistration*> GetCompilationDevice(
|
||||
absl::string_view device_name);
|
||||
xla::StatusOr<std::reference_wrapper<const DeviceType>> GetDeviceTypeFor(
|
||||
absl::string_view device_name);
|
||||
|
||||
private:
|
||||
absl::flat_hash_map<string, const XlaOpRegistry::DeviceRegistration*>
|
||||
device_to_device_registration_;
|
||||
absl::flat_hash_map<string, std::unique_ptr<DeviceType>>
|
||||
device_to_device_type_;
|
||||
};
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
|
206
tensorflow/compiler/jit/device_util.cc
Normal file
206
tensorflow/compiler/jit/device_util.cc
Normal file
@ -0,0 +1,206 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/device_util.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace jit {
|
||||
using xla::StatusOr;
|
||||
|
||||
void DeviceSet::Insert(DeviceId device_id) {
|
||||
int word_index = device_id.id() / kWordSize;
|
||||
int bit_index = device_id.id() % kWordSize;
|
||||
|
||||
if (word_index >= storage_.size()) {
|
||||
storage_.resize(word_index + 1, 0);
|
||||
}
|
||||
|
||||
storage_[word_index] |= (1ull << bit_index);
|
||||
}
|
||||
|
||||
void DeviceSet::UnionWith(const DeviceSet& other) {
|
||||
if (other.storage_.size() > storage_.size()) {
|
||||
storage_.resize(other.storage_.size(), 0);
|
||||
}
|
||||
|
||||
for (int i = 0; i < other.storage_.size(); i++) {
|
||||
storage_[i] |= other.storage_[i];
|
||||
}
|
||||
}
|
||||
|
||||
bool DeviceSet::IsEmpty() const {
|
||||
return absl::c_all_of(storage_, [&](uint64 val) { return val == 0; });
|
||||
}
|
||||
|
||||
xla::StatusOr<DeviceId> DeviceInfoCache::GetIdFor(absl::string_view name) {
|
||||
TF_RET_CHECK(!name.empty());
|
||||
|
||||
auto it = name_to_id_.find(name);
|
||||
if (it != name_to_id_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
int new_id = names_.size();
|
||||
names_.push_back(string(name));
|
||||
id_to_device_type_.push_back(absl::make_unique<DeviceType>(""));
|
||||
DeviceType* device_type = id_to_device_type_.back().get();
|
||||
TF_RETURN_IF_ERROR(DeviceNameToDeviceType(names_.back(), device_type));
|
||||
|
||||
is_cpu_.push_back(device_type->type_string() == DEVICE_CPU);
|
||||
is_gpu_.push_back(device_type->type_string() == DEVICE_GPU);
|
||||
|
||||
name_to_id_.emplace(string(name), DeviceId(new_id));
|
||||
|
||||
const XlaOpRegistry::DeviceRegistration* compilation_device;
|
||||
if (!XlaOpRegistry::GetCompilationDevice(device_type->type(),
|
||||
&compilation_device)) {
|
||||
compilation_device = nullptr;
|
||||
}
|
||||
id_to_compilation_device_.push_back(compilation_device);
|
||||
|
||||
return DeviceId(new_id);
|
||||
}
|
||||
|
||||
string DeviceInfoCache::DebugString(const DeviceSet& device_set) const {
|
||||
std::vector<string> names;
|
||||
device_set.ForEach([&](DeviceId device_id) {
|
||||
names.push_back(string(GetNameFor(device_id)));
|
||||
return false;
|
||||
});
|
||||
|
||||
return absl::StrCat("[", absl::StrJoin(names, ","), "]");
|
||||
}
|
||||
} // namespace jit
|
||||
|
||||
Status DeviceNameToDeviceType(const string& device, DeviceType* device_type) {
|
||||
DeviceNameUtils::ParsedName parsed;
|
||||
if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
|
||||
return errors::Internal("Malformed assigned device '", device, "'");
|
||||
}
|
||||
*device_type = DeviceType(parsed.type);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
xla::StatusOr<absl::optional<jit::DeviceId>> PickDeviceForXlaImpl(
|
||||
const jit::DeviceInfoCache& device_info_cache,
|
||||
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu,
|
||||
bool failure_to_pick_is_error) {
|
||||
#define FAILED_TO_PICK_DEVICE(failing_status) \
|
||||
do { \
|
||||
if (failure_to_pick_is_error) { \
|
||||
return failing_status; \
|
||||
} else { \
|
||||
return {absl::nullopt}; \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
absl::optional<jit::DeviceId> maybe_gpu_device;
|
||||
absl::optional<jit::DeviceId> maybe_cpu_device;
|
||||
absl::optional<jit::DeviceId> maybe_unknown_device;
|
||||
|
||||
bool multiple_cpu_devices = false;
|
||||
bool multiple_gpu_devices = false;
|
||||
bool multiple_unknown_devices = false;
|
||||
|
||||
devices.ForEach([&](jit::DeviceId device) {
|
||||
if (device_info_cache.IsGpu(device)) {
|
||||
if (maybe_gpu_device) {
|
||||
multiple_gpu_devices = true;
|
||||
return false;
|
||||
}
|
||||
maybe_gpu_device = device;
|
||||
} else if (device_info_cache.IsCpu(device)) {
|
||||
if (maybe_cpu_device) {
|
||||
multiple_cpu_devices = true;
|
||||
return false;
|
||||
}
|
||||
maybe_cpu_device = device;
|
||||
} else {
|
||||
if (maybe_unknown_device) {
|
||||
multiple_unknown_devices = true;
|
||||
return false;
|
||||
}
|
||||
maybe_unknown_device = device;
|
||||
}
|
||||
|
||||
return true;
|
||||
});
|
||||
|
||||
if (multiple_cpu_devices) {
|
||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||
"Multiple CPU devices ", device_info_cache.DebugString(devices)));
|
||||
}
|
||||
|
||||
if (multiple_gpu_devices) {
|
||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||
"Multiple GPU devices ", device_info_cache.DebugString(devices)));
|
||||
}
|
||||
|
||||
if (multiple_unknown_devices) {
|
||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||
"Multiple unknown devices ", device_info_cache.DebugString(devices)));
|
||||
}
|
||||
|
||||
if (maybe_unknown_device && maybe_gpu_device) {
|
||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||
"Found both unknown and GPU devices: ",
|
||||
device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
|
||||
device_info_cache.GetNameFor(*maybe_gpu_device)));
|
||||
}
|
||||
|
||||
if (!allow_mixing_unknown_and_cpu) {
|
||||
if (maybe_unknown_device && maybe_cpu_device) {
|
||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||
"Found both unknown and CPU devices: ",
|
||||
device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
|
||||
device_info_cache.GetNameFor(*maybe_cpu_device)));
|
||||
}
|
||||
}
|
||||
|
||||
if (maybe_gpu_device) {
|
||||
return {*maybe_gpu_device};
|
||||
} else if (maybe_unknown_device) {
|
||||
return {*maybe_unknown_device};
|
||||
} else if (maybe_cpu_device) {
|
||||
return {*maybe_cpu_device};
|
||||
}
|
||||
|
||||
FAILED_TO_PICK_DEVICE(errors::Internal("Empty device set!"));
|
||||
|
||||
#undef FAILED_TO_PICK_DEVICE
|
||||
}
|
||||
|
||||
xla::StatusOr<jit::DeviceId> PickDeviceForXla(
|
||||
const jit::DeviceInfoCache& device_info_cache,
|
||||
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
|
||||
TF_ASSIGN_OR_RETURN(absl::optional<jit::DeviceId> device_id,
|
||||
PickDeviceForXlaImpl(device_info_cache, devices,
|
||||
allow_mixing_unknown_and_cpu,
|
||||
/*failure_to_pick_is_error=*/true));
|
||||
return *device_id;
|
||||
}
|
||||
|
||||
xla::StatusOr<absl::optional<jit::DeviceId>> MaybePickDeviceForXla(
|
||||
const jit::DeviceInfoCache& device_info_cache,
|
||||
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
|
||||
return PickDeviceForXlaImpl(device_info_cache, devices,
|
||||
allow_mixing_unknown_and_cpu,
|
||||
/*failure_to_pick_is_error=*/false);
|
||||
}
|
||||
} // namespace tensorflow
|
211
tensorflow/compiler/jit/device_util.h
Normal file
211
tensorflow/compiler/jit/device_util.h
Normal file
@ -0,0 +1,211 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace jit {
|
||||
// Instances of DeviceId represent TensorFlow devices as integers.
|
||||
//
|
||||
// This helps avoid having to manipulate device names as strings when
|
||||
// auto-clustering.
|
||||
class DeviceId {
|
||||
public:
|
||||
DeviceId(DeviceId&&) = default;
|
||||
DeviceId(const DeviceId&) = default;
|
||||
DeviceId& operator=(const DeviceId&) = default;
|
||||
|
||||
bool operator==(const DeviceId& other) const { return id() == other.id(); }
|
||||
bool operator!=(const DeviceId& other) const { return !(*this == other); }
|
||||
|
||||
private:
|
||||
int id_;
|
||||
|
||||
explicit DeviceId(int id) : id_(id) {}
|
||||
|
||||
int id() const { return id_; }
|
||||
|
||||
friend class DeviceInfoCache;
|
||||
friend class DeviceSet;
|
||||
};
|
||||
|
||||
// A set of DeviceIds, represented as a bitmap.
|
||||
class DeviceSet {
|
||||
public:
|
||||
void Insert(DeviceId device_id);
|
||||
void UnionWith(const DeviceSet& other);
|
||||
bool IsEmpty() const;
|
||||
|
||||
// Calls `func` on each DeviceId in the set. Stops iterating early if `func`
|
||||
// return false.
|
||||
//
|
||||
// TODO(sanjoy): Change this to take a typed std::function if that's
|
||||
// performance neutral.
|
||||
template <typename FnTy>
|
||||
void ForEach(FnTy func) const {
|
||||
// This is really a poor man's iterator, we should consider writing a proper
|
||||
// iterator if this ends up being used widely.
|
||||
for (int word_index = 0; word_index < storage_.size(); word_index++) {
|
||||
uint64 word = storage_[word_index];
|
||||
while (word != 0) {
|
||||
uint64 only_lowest_bit_set = word & -word;
|
||||
// The number of trailing zeros in a non-zero word is the index of the
|
||||
// least significant 1.
|
||||
int bit_index = ctz_uint64(word);
|
||||
if (!func(DeviceId(word_index * kWordSize + bit_index))) {
|
||||
return;
|
||||
}
|
||||
word ^= only_lowest_bit_set;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static int ctz_uint64(uint64 x) {
|
||||
DCHECK_NE(x, 0);
|
||||
#ifdef __GNUC__
|
||||
return __builtin_ctzl(x);
|
||||
#else
|
||||
int result = 0u;
|
||||
while ((x & 1u) == 0u) {
|
||||
x >>= 1;
|
||||
++result;
|
||||
}
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
absl::InlinedVector<uint64, 1> storage_;
|
||||
|
||||
const int kWordSize = 64;
|
||||
};
|
||||
|
||||
// Caches some miscellaneous information about TF devices. Thread compatible.
|
||||
class DeviceInfoCache {
|
||||
public:
|
||||
bool IsGpu(DeviceId device) const { return is_gpu_[device.id()]; }
|
||||
bool IsCpu(DeviceId device) const { return is_cpu_[device.id()]; }
|
||||
|
||||
absl::string_view GetNameFor(DeviceId device) const {
|
||||
return names_[device.id()];
|
||||
}
|
||||
|
||||
xla::StatusOr<DeviceId> GetIdFor(absl::string_view name);
|
||||
|
||||
using DeviceRegistration = const XlaOpRegistry::DeviceRegistration;
|
||||
|
||||
DeviceRegistration* GetCompilationDevice(DeviceId device) const {
|
||||
return id_to_compilation_device_[device.id()];
|
||||
}
|
||||
|
||||
xla::StatusOr<DeviceRegistration*> GetCompilationDevice(
|
||||
absl::string_view name) {
|
||||
TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(name));
|
||||
return GetCompilationDevice(device_id);
|
||||
}
|
||||
|
||||
const DeviceType& GetDeviceTypeFor(DeviceId device) const {
|
||||
return *id_to_device_type_[device.id()];
|
||||
}
|
||||
|
||||
using DeviceTypeConstRef = std::reference_wrapper<const DeviceType>;
|
||||
|
||||
xla::StatusOr<DeviceTypeConstRef> GetDeviceTypeFor(
|
||||
absl::string_view device_name) {
|
||||
TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(device_name));
|
||||
return std::cref(*id_to_device_type_[device_id.id()]);
|
||||
}
|
||||
|
||||
string DebugString(const DeviceSet& device_set) const;
|
||||
|
||||
private:
|
||||
absl::flat_hash_map<string, DeviceId> name_to_id_;
|
||||
|
||||
// These fields are populated for a device in GetIdFor, *before* we give out a
|
||||
// DeviceId.
|
||||
std::vector<const XlaOpRegistry::DeviceRegistration*>
|
||||
id_to_compilation_device_;
|
||||
std::vector<std::unique_ptr<DeviceType>> id_to_device_type_;
|
||||
std::vector<string> names_;
|
||||
std::vector<bool> is_cpu_;
|
||||
std::vector<bool> is_gpu_;
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
|
||||
// Returns the DeviceType corresponding to 'device'.
|
||||
Status DeviceNameToDeviceType(const string& device, DeviceType* device_type);
|
||||
|
||||
// Picks the device for which XLA should compile a cluster that contains
|
||||
// operations placed in devices in `devices`. For instance a cluster that
|
||||
// contains operations solely placed on the CPU will be compiled into a CPU
|
||||
// executable by XLA, whereas a cluster that contains operations placed on the
|
||||
// CPU and also operations placed on the GPU will be compiled into a GPU
|
||||
// executable.
|
||||
//
|
||||
// Returns a non-OK Status if no unambiguous choice of device exists.
|
||||
//
|
||||
// We choose the device using the following rules:
|
||||
//
|
||||
// - It is an error for `device_names` to contain more than one device of the
|
||||
// same type.
|
||||
// - GPU is preferred over CPU.
|
||||
// - If `allow_mixing_unknown_and_cpu` is true then unknown devices are
|
||||
// preferred over CPU.
|
||||
// - XLA devices count as "unrecognized devices".
|
||||
//
|
||||
// This set of rules above implicitly assume that XLA:GPU can compile all
|
||||
// operations in the cluster that XLA:CPU can compile, and if
|
||||
// `allow_mixing_unknown_and_cpu` then the unrecognized device can also compile
|
||||
// all operations in the cluster that XLA:CPU can compile.
|
||||
//
|
||||
// We provide the `allow_mixing_unknown_and_cpu` knob so that we can do both of
|
||||
// the following things:
|
||||
//
|
||||
// - Let MarkForCompilationPass not inject CPU-placed operations into clusters
|
||||
// that will run on unknown devices (because the unknown XLA backend may not
|
||||
// support every operation supported by CPU).
|
||||
// - Let BuildXlaOpsPass successfully infer a compilation device for a cluster
|
||||
// that contains nodes placed on both the CPU and on unknown devices. In this
|
||||
// case it is the responsibility of the optimization pass that injected the
|
||||
// CPU nodes into the cluster to ensure that these nodes can be compiled by
|
||||
// the unknown XLA backend.
|
||||
xla::StatusOr<jit::DeviceId> PickDeviceForXla(
|
||||
const jit::DeviceInfoCache& device_info_cache,
|
||||
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
|
||||
|
||||
// This is like `PickDeviceForXla` except that it returns nullopt (instead of a
|
||||
// non-OK Status) if no unambiguous choice of device exists.
|
||||
//
|
||||
// We return a failing Status for errors unrelated to the device choice
|
||||
// algorithm itself.
|
||||
xla::StatusOr<absl::optional<jit::DeviceId>> MaybePickDeviceForXla(
|
||||
const jit::DeviceInfoCache& device_info_cache,
|
||||
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
|
132
tensorflow/compiler/jit/device_util_test.cc
Normal file
132
tensorflow/compiler/jit/device_util_test.cc
Normal file
@ -0,0 +1,132 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/device_util.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Status PickDeviceHelper(bool allow_mixing_unknown_and_cpu,
|
||||
absl::Span<const absl::string_view> device_names,
|
||||
string* result) {
|
||||
jit::DeviceInfoCache cache;
|
||||
jit::DeviceSet device_set;
|
||||
for (absl::string_view name : device_names) {
|
||||
TF_ASSIGN_OR_RETURN(jit::DeviceId device_id, cache.GetIdFor(name));
|
||||
device_set.Insert(device_id);
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
jit::DeviceId result_id,
|
||||
PickDeviceForXla(cache, device_set, allow_mixing_unknown_and_cpu));
|
||||
*result = string(cache.GetNameFor(result_id));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void CheckPickDeviceResult(absl::string_view expected_result,
|
||||
bool allow_mixing_unknown_and_cpu,
|
||||
absl::Span<const absl::string_view> inputs) {
|
||||
string result;
|
||||
TF_ASSERT_OK(PickDeviceHelper(allow_mixing_unknown_and_cpu, inputs, &result))
|
||||
<< "inputs = [" << absl::StrJoin(inputs, ", ")
|
||||
<< "], allow_mixing_unknown_and_cpu=" << allow_mixing_unknown_and_cpu
|
||||
<< ", expected_result=" << expected_result;
|
||||
EXPECT_EQ(result, expected_result);
|
||||
}
|
||||
|
||||
void CheckPickDeviceHasError(bool allow_mixing_unknown_and_cpu,
|
||||
absl::Span<const absl::string_view> inputs) {
|
||||
string result;
|
||||
EXPECT_FALSE(
|
||||
PickDeviceHelper(allow_mixing_unknown_and_cpu, inputs, &result).ok());
|
||||
}
|
||||
|
||||
const char* kCPU0 = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char* kGPU0 = "/job:localhost/replica:0/task:0/device:GPU:0";
|
||||
const char* kXPU0 = "/job:localhost/replica:0/task:0/device:XPU:0";
|
||||
const char* kYPU0 = "/job:localhost/replica:0/task:0/device:YPU:0";
|
||||
|
||||
const char* kCPU1 = "/job:localhost/replica:0/task:0/device:CPU:1";
|
||||
const char* kGPU1 = "/job:localhost/replica:0/task:0/device:GPU:1";
|
||||
const char* kXPU1 = "/job:localhost/replica:0/task:0/device:XPU:1";
|
||||
|
||||
TEST(PickDeviceForXla, UniqueDevice) {
|
||||
CheckPickDeviceResult(kGPU0, false, {kGPU0, kGPU0});
|
||||
}
|
||||
|
||||
TEST(PickDeviceForXla, DeviceOrder) {
|
||||
CheckPickDeviceResult(kGPU0, false, {kGPU0, kCPU0});
|
||||
CheckPickDeviceResult(kGPU0, false, {kCPU0, kGPU0});
|
||||
CheckPickDeviceResult(kXPU0, true, {kXPU0, kCPU0});
|
||||
}
|
||||
|
||||
TEST(PickDeviceForXla, MultipleUnknownDevices) {
|
||||
CheckPickDeviceHasError(false, {kXPU0, kYPU0});
|
||||
}
|
||||
|
||||
TEST(PickDeviceForXla, GpuAndUnknown) {
|
||||
CheckPickDeviceHasError(false, {kGPU0, kXPU1});
|
||||
}
|
||||
|
||||
TEST(PickDeviceForXla, UnknownAndCpu) {
|
||||
CheckPickDeviceHasError(false, {kXPU0, kCPU1});
|
||||
}
|
||||
|
||||
TEST(PickDeviceForXla, MultipleDevicesOfSameType) {
|
||||
CheckPickDeviceHasError(true, {kCPU0, kCPU1});
|
||||
CheckPickDeviceHasError(false, {kCPU0, kCPU1});
|
||||
CheckPickDeviceHasError(false, {kGPU0, kGPU1});
|
||||
CheckPickDeviceHasError(false, {kXPU0, kXPU1});
|
||||
CheckPickDeviceHasError(false, {kCPU0, kCPU1, kGPU0});
|
||||
}
|
||||
|
||||
void SimpleRoundTripTestForDeviceSet(int num_devices) {
|
||||
jit::DeviceSet device_set;
|
||||
jit::DeviceInfoCache device_info_cache;
|
||||
|
||||
std::vector<string> expected_devices, actual_devices;
|
||||
|
||||
for (int i = 0; i < num_devices; i++) {
|
||||
string device_name =
|
||||
absl::StrCat("/job:localhost/replica:0/task:0/device:XPU:", i);
|
||||
TF_ASSERT_OK_AND_ASSIGN(jit::DeviceId device_id,
|
||||
device_info_cache.GetIdFor(device_name));
|
||||
device_set.Insert(device_id);
|
||||
expected_devices.push_back(device_name);
|
||||
}
|
||||
|
||||
device_set.ForEach([&](jit::DeviceId device_id) {
|
||||
actual_devices.push_back(string(device_info_cache.GetNameFor(device_id)));
|
||||
return true;
|
||||
});
|
||||
|
||||
EXPECT_EQ(expected_devices, actual_devices);
|
||||
}
|
||||
|
||||
TEST(DeviceSetTest, SimpleRoundTrip_One) { SimpleRoundTripTestForDeviceSet(1); }
|
||||
|
||||
TEST(DeviceSetTest, SimpleRoundTrip_Small) {
|
||||
SimpleRoundTripTestForDeviceSet(8);
|
||||
}
|
||||
|
||||
TEST(DeviceSetTest, SimpleRoundTrip_Large) {
|
||||
SimpleRoundTripTestForDeviceSet(800);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -2497,8 +2497,6 @@ Status EncapsulateSubgraphsInFunctions(
|
||||
const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
|
||||
bool reuse_existing_functions, std::unique_ptr<Graph>* graph_out,
|
||||
FunctionLibraryDefinition* library) {
|
||||
Status s;
|
||||
|
||||
Encapsulator encapsulator(std::move(group_attribute),
|
||||
std::move(outside_compilation_attribute),
|
||||
&graph_in);
|
||||
|
@ -537,8 +537,9 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
|
||||
XlaClusterInfo{func, func_name_attrs, xla_computation_node,
|
||||
std::map<string, int>{}});
|
||||
}
|
||||
bool modified;
|
||||
s = ExtractOutsideCompilation("_encapsulate", "_outside", clusters,
|
||||
graph_out.get(), flr, lib_def.get());
|
||||
graph_out.get(), flr, lib_def.get(), &modified);
|
||||
if (!s.ok()) return s;
|
||||
|
||||
GraphDef graphdef_out;
|
||||
@ -1105,7 +1106,9 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
|
||||
{"shapes", absl::Span<const DataType>({})},
|
||||
{"_outside_compilation_subgraph", "O2"},
|
||||
{"_xla_token_input_nodes",
|
||||
absl::Span<const string>({"_xla_token_arg_node"})}},
|
||||
absl::Span<const string>(
|
||||
{"_xla_token_arg_node",
|
||||
"outside_compilation_O1_host_compute"})}},
|
||||
{"F"}},
|
||||
{{"outside_compilation_O1_host_compute"},
|
||||
"XlaHostCompute",
|
||||
@ -1985,7 +1988,9 @@ TEST(EncapsulateSubgraphsTest,
|
||||
{"shapes", absl::Span<const TensorShapeProto>({})},
|
||||
{"_outside_compilation_subgraph", "O2"},
|
||||
{"_xla_token_input_nodes",
|
||||
absl::Span<const string>({"_xla_token_arg_node"})}}},
|
||||
absl::Span<const string>(
|
||||
{"_xla_token_arg_node",
|
||||
"outside_compilation_O1_host_compute"})}}},
|
||||
},
|
||||
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
|
||||
{"h_0_retval_retval", "H:o:0"}});
|
||||
@ -2110,7 +2115,9 @@ TEST(EncapsulateSubgraphsTest,
|
||||
{"shapes", absl::Span<const TensorShapeProto>({})},
|
||||
{"_outside_compilation_subgraph", "O2"},
|
||||
{"_xla_token_input_nodes",
|
||||
absl::Span<const string>({"_xla_token_arg_node"})}}},
|
||||
absl::Span<const string>(
|
||||
{"_xla_token_arg_node",
|
||||
"outside_compilation_O1_host_compute"})}}},
|
||||
{{"outside_compilation_O1_host_compute"},
|
||||
"XlaHostCompute",
|
||||
{"D:o:0"},
|
||||
@ -2258,7 +2265,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
|
||||
{"shapes", absl::Span<const TensorShapeProto>({})},
|
||||
{"_outside_compilation_subgraph", "O2"},
|
||||
{"_xla_token_input_nodes",
|
||||
absl::Span<const string>({"_xla_token_arg_node"})}},
|
||||
absl::Span<const string>(
|
||||
{"_xla_token_arg_node", "outside_compilation_O1_host_compute"})}},
|
||||
{}},
|
||||
{{"outside_compilation_O3_host_compute"},
|
||||
"XlaHostCompute",
|
||||
@ -2271,7 +2279,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
|
||||
{"shapes", absl::Span<const TensorShapeProto>({})},
|
||||
{"_outside_compilation_subgraph", "O3"},
|
||||
{"_xla_token_input_nodes",
|
||||
absl::Span<const string>({"_xla_token_arg_node"})}},
|
||||
absl::Span<const string>({"_xla_token_arg_node",
|
||||
"outside_compilation_O1_host_compute",
|
||||
"outside_compilation_O2_host_compute"})}},
|
||||
{}}},
|
||||
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
|
||||
{"h_0_retval_retval", "H:o:0"}});
|
||||
|
@ -14,9 +14,12 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/encapsulate_util.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/jit/shape_inference.h"
|
||||
@ -24,6 +27,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -333,6 +339,43 @@ Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<absl::flat_hash_map<string, std::vector<string>>>>
|
||||
OutsideCompilationClusterDependencies(
|
||||
const Graph* g, const string& outside_compilation_attr_name) {
|
||||
auto cluster_deps = absl::make_unique<
|
||||
absl::flat_hash_map<string, absl::flat_hash_set<string>>>();
|
||||
|
||||
for (const Edge* e : g->edges()) {
|
||||
auto src_outside_compilation =
|
||||
GetStringAttr(*e->src(), outside_compilation_attr_name);
|
||||
auto dst_outside_compilation =
|
||||
GetStringAttr(*e->dst(), outside_compilation_attr_name);
|
||||
|
||||
if (src_outside_compilation && dst_outside_compilation &&
|
||||
*src_outside_compilation != *dst_outside_compilation) {
|
||||
auto dst_deps_it = cluster_deps->find(*dst_outside_compilation);
|
||||
if (dst_deps_it == cluster_deps->end()) {
|
||||
cluster_deps->insert(std::make_pair(
|
||||
*dst_outside_compilation,
|
||||
absl::flat_hash_set<string>({*src_outside_compilation})));
|
||||
} else {
|
||||
dst_deps_it->second.insert(*src_outside_compilation);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto cluster_deps_ordered =
|
||||
absl::make_unique<absl::flat_hash_map<string, std::vector<string>>>();
|
||||
|
||||
for (auto it = cluster_deps->begin(); it != cluster_deps->end(); it++) {
|
||||
std::vector<string> ordered_deps(it->second.begin(), it->second.end());
|
||||
std::sort(ordered_deps.begin(), ordered_deps.end());
|
||||
cluster_deps_ordered->insert(std::make_pair(it->first, ordered_deps));
|
||||
}
|
||||
|
||||
return std::move(cluster_deps_ordered);
|
||||
}
|
||||
|
||||
Status PreprocessEdgesBetweenOutsideCompilations(
|
||||
Graph* g, const string& outside_compilation_attr_name) {
|
||||
// Remove edges from source node to outside compilation nodes, and edges
|
||||
|
@ -19,7 +19,9 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -89,6 +91,15 @@ struct XlaClusterInfo {
|
||||
const std::map<string, int> host_compute_core;
|
||||
};
|
||||
|
||||
// Finds dependencies between outside compilation clusters, including both data
|
||||
// dependencies and control dependencies. cluster_deps maps the name name of an
|
||||
// outside compilation cluster to a set of names of outside compilation clusters
|
||||
// that it depends on.
|
||||
stream_executor::port::StatusOr<
|
||||
std::unique_ptr<absl::flat_hash_map<string, std::vector<string>>>>
|
||||
OutsideCompilationClusterDependencies(
|
||||
const Graph* g, const string& outside_compilation_attr_name);
|
||||
|
||||
// Preprocesses edges within the same XLA cluster. It will perform the following
|
||||
// operations in order:
|
||||
//
|
||||
|
@ -15,12 +15,14 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
||||
@ -287,15 +289,20 @@ absl::optional<std::vector<PartialTensorShape>> GetInferredInputShapes(
|
||||
return results;
|
||||
}
|
||||
|
||||
string host_compute_node_name(const string& original_oc_name) {
|
||||
return absl::StrCat("outside_compilation_", original_oc_name,
|
||||
"_host_compute");
|
||||
}
|
||||
|
||||
// Builds XlaHostCompute NodeDef from the outside compilation call node.
|
||||
xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
|
||||
const Node* call_node, const std::map<string, int>& host_compute_core) {
|
||||
const Node* call_node, const std::map<string, int>& host_compute_core,
|
||||
const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
|
||||
string original_oc_name;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(
|
||||
call_node->attrs(), "_outside_compilation_subgraph", &original_oc_name));
|
||||
NodeDefBuilder host_compute_builder(
|
||||
absl::StrCat("outside_compilation_", original_oc_name, "_host_compute"),
|
||||
"XlaHostCompute");
|
||||
NodeDefBuilder host_compute_builder(host_compute_node_name(original_oc_name),
|
||||
"XlaHostCompute");
|
||||
|
||||
// Copy all attributes.
|
||||
for (auto attr : call_node->attrs()) {
|
||||
@ -309,9 +316,25 @@ xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
|
||||
host_compute_builder.Attr("tpu_core", core);
|
||||
}
|
||||
|
||||
// Set input tokens.
|
||||
host_compute_builder.Attr(kXlaTokenInputNodesAttrName,
|
||||
std::vector<string>{kXlaTokenArgNodeName});
|
||||
// Set input tokens and other outside compilation clusters that current
|
||||
// cluster depends in `kXlaTokenArgNodeName`. This is needed because when
|
||||
// outside compilation subgraphs are encapsulated and moved to host graph,
|
||||
// control/data edges between them will only be reflected in host graph.
|
||||
// From XLA's perspective, two originally dependent clusters are no longer
|
||||
// connected, which makes them look like they can be scheduled for execution
|
||||
// in arbitrary order even though in fact they must be executed in order
|
||||
// according to their host-side graph dependency. This can cause deadlock.
|
||||
// Therefore, we hint XLA what the correct ordering of these clusters should
|
||||
// be to avoid deadlocks.
|
||||
std::vector<string> xla_token_input_nodes;
|
||||
xla_token_input_nodes.emplace_back(kXlaTokenArgNodeName);
|
||||
auto cluster_deps_it = cluster_deps.find(original_oc_name);
|
||||
if (cluster_deps_it != cluster_deps.end()) {
|
||||
for (auto dep : cluster_deps_it->second) {
|
||||
xla_token_input_nodes.emplace_back(host_compute_node_name(dep));
|
||||
}
|
||||
}
|
||||
host_compute_builder.Attr(kXlaTokenInputNodesAttrName, xla_token_input_nodes);
|
||||
|
||||
// Populate inputs.
|
||||
std::vector<DataType> input_dtypes;
|
||||
@ -371,7 +394,8 @@ Status ValidateOutsideCompilationCallNode(Node* call_node) {
|
||||
// If the function call node has no input/output edges, we will just remove it
|
||||
// and not create a XlaHostCompute node.
|
||||
Status ReplaceOrRemoveOutsideCompilationCallNode(
|
||||
Graph* g, Node* call_node, const std::map<string, int>& host_compute_core) {
|
||||
Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
|
||||
const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
|
||||
// If the function call node has no input/output edges, just remove it.
|
||||
bool has_edge = false;
|
||||
for (auto e : call_node->in_edges()) {
|
||||
@ -393,8 +417,9 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
|
||||
}
|
||||
|
||||
// Build XlaHostCompute NodeDef.
|
||||
TF_ASSIGN_OR_RETURN(NodeDef node_def,
|
||||
BuildXlaHostComputeNodeDef(call_node, host_compute_core));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
NodeDef node_def,
|
||||
BuildXlaHostComputeNodeDef(call_node, host_compute_core, cluster_deps));
|
||||
TF_ASSIGN_OR_RETURN(Node * host_compute_node,
|
||||
ReplaceNode(g, call_node, node_def));
|
||||
VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString();
|
||||
@ -1589,6 +1614,11 @@ Status ExtractOutsideCompilationForFunction(
|
||||
// We cannot early return here, because we might have outside compilation in
|
||||
// If/While function body.
|
||||
|
||||
// Find dependencies between outside compilation clusters.
|
||||
TF_ASSIGN_OR_RETURN(auto cluster_deps,
|
||||
OutsideCompilationClusterDependencies(
|
||||
fbody->graph, outside_compilation_attr_name));
|
||||
|
||||
// Preprocess edges between different outside compilations. They will be
|
||||
// restored in `ConstructHostGraph()`.
|
||||
TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
|
||||
@ -1643,7 +1673,7 @@ Status ExtractOutsideCompilationForFunction(
|
||||
for (Node* n : outside_compilation_nodes) {
|
||||
TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
|
||||
TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode(
|
||||
graph_out.get(), n, host_compute_core));
|
||||
graph_out.get(), n, host_compute_core, *cluster_deps));
|
||||
}
|
||||
|
||||
// Handle nodes with associated functions.
|
||||
@ -1691,11 +1721,13 @@ Status ExtractOutsideCompilation(
|
||||
const string& xla_cluster_attr_name,
|
||||
const string& outside_compilation_attr_name,
|
||||
const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
|
||||
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld) {
|
||||
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
|
||||
bool* modified) {
|
||||
if (VLOG_IS_ON(4)) {
|
||||
DumpGraphToFile("extract_outside_compilation_before", *g, fld);
|
||||
}
|
||||
|
||||
*modified = false;
|
||||
auto node_name_index = g->BuildNodeNameIndex();
|
||||
for (auto& iter : clusters) {
|
||||
string xla_cluster_name = iter.first;
|
||||
@ -1711,6 +1743,7 @@ Status ExtractOutsideCompilation(
|
||||
func_name_attrs, func_name_attrs.name(), host_graph_func_name,
|
||||
host_compute_core, flr, fld, &shape_inference_graphs,
|
||||
&has_outside_compilation));
|
||||
*modified |= has_outside_compilation;
|
||||
|
||||
string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
|
||||
Node* pivot_node = node_name_index[pivot_name];
|
||||
|
@ -101,7 +101,8 @@ Status ExtractOutsideCompilation(
|
||||
const string& xla_cluster_attr_name,
|
||||
const string& outside_compilation_attr_name,
|
||||
const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
|
||||
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld);
|
||||
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
|
||||
bool* modified);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -922,4 +922,145 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ExtractOutsideCompilationForFunctionTest,
|
||||
OutsideCompilationClusterDataDependency) {
|
||||
// Build the XLA computation func.
|
||||
// "const0"
|
||||
// "identity0" = "const0" (outside compilation cluster "0")
|
||||
// "identity1" = "identity0" (outside compilation cluster "1")
|
||||
// "identity2" = "identity1"
|
||||
FunctionDefLibrary fdl;
|
||||
{
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
|
||||
Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
|
||||
Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
|
||||
Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
|
||||
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
|
||||
TF_CHECK_OK(s.ToGraph(g.get()));
|
||||
std::cout << "Graph is " << (*g).ToGraphDefDebug().DebugString()
|
||||
<< std::endl;
|
||||
auto node_name_image = g->BuildNodeNameIndex();
|
||||
node_name_image["identity0"]->AddAttr("_oc", "0");
|
||||
node_name_image["identity1"]->AddAttr("_oc", "1");
|
||||
|
||||
PartialTensorShape shape({2});
|
||||
node_name_image["identity1"]->AddAttr(
|
||||
kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});
|
||||
|
||||
FunctionDef *xla_fdef = fdl.add_function();
|
||||
TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
|
||||
}
|
||||
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
|
||||
|
||||
protobuf::Map<string, tensorflow::AttrValue> attrs;
|
||||
std::map<string, int> host_compute_core = {{"0", 1}, {"1", 0}};
|
||||
std::vector<string> shape_inference_graphs;
|
||||
bool has_outside_compilation;
|
||||
NameAttrList name_attrs;
|
||||
name_attrs.set_name("cluster");
|
||||
*name_attrs.mutable_attr() = attrs;
|
||||
TF_CHECK_OK(ExtractOutsideCompilationTest(
|
||||
"_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
|
||||
host_compute_core, &fld, &shape_inference_graphs,
|
||||
&has_outside_compilation));
|
||||
|
||||
// Get rewritten XLA computation function.
|
||||
std::unique_ptr<FunctionBody> xla_fbody;
|
||||
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
|
||||
AttrSlice(), &fld, &xla_fbody));
|
||||
auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();
|
||||
|
||||
// Check XlaHostCompute nodes.
|
||||
Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"];
|
||||
EXPECT_NE(host_compute_0, nullptr);
|
||||
Node *host_compute_1 = node_name_index["outside_compilation_1_host_compute"];
|
||||
EXPECT_NE(host_compute_1, nullptr);
|
||||
|
||||
// Check XlaHostCompute nodes' "_xla_token_input_nodes" attr.
|
||||
std::vector<string> token_input_nodes;
|
||||
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()),
|
||||
"_xla_token_input_nodes", &token_input_nodes));
|
||||
|
||||
std::vector<string> expected_token_input_nodes_0({"_xla_token_arg_node"});
|
||||
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0);
|
||||
token_input_nodes.clear();
|
||||
std::vector<string> expected_token_input_nodes_1(
|
||||
{"_xla_token_arg_node", "outside_compilation_0_host_compute"});
|
||||
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
|
||||
"_xla_token_input_nodes", &token_input_nodes));
|
||||
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
|
||||
}
|
||||
|
||||
TEST_F(ExtractOutsideCompilationForFunctionTest,
|
||||
OutsideCompilationClusterControlDependency) {
|
||||
// Build the XLA computation func.
|
||||
// "const0"
|
||||
// "identity0" = "const0" (outside compilation cluster "0")
|
||||
// "identity1" = "const0" "^identity0" (outside compilation cluster "1",
|
||||
// control depdent on cluster "0")
|
||||
// "identity2" = "identity1"
|
||||
FunctionDefLibrary fdl;
|
||||
{
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
|
||||
Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
|
||||
Output identity1 = ops::Identity(
|
||||
s.WithOpName("identity1").WithControlDependencies(identity0), const0);
|
||||
Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
|
||||
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
|
||||
TF_CHECK_OK(s.ToGraph(g.get()));
|
||||
std::cout << "Graph is " << (*g).ToGraphDefDebug().DebugString()
|
||||
<< std::endl;
|
||||
auto node_name_image = g->BuildNodeNameIndex();
|
||||
node_name_image["identity0"]->AddAttr("_oc", "0");
|
||||
node_name_image["identity1"]->AddAttr("_oc", "1");
|
||||
|
||||
PartialTensorShape shape({2});
|
||||
node_name_image["identity1"]->AddAttr(
|
||||
kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});
|
||||
|
||||
FunctionDef *xla_fdef = fdl.add_function();
|
||||
TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
|
||||
}
|
||||
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
|
||||
|
||||
protobuf::Map<string, tensorflow::AttrValue> attrs;
|
||||
std::map<string, int> host_compute_core = {{"0", 1}, {"1", 0}};
|
||||
std::vector<string> shape_inference_graphs;
|
||||
bool has_outside_compilation;
|
||||
NameAttrList name_attrs;
|
||||
name_attrs.set_name("cluster");
|
||||
*name_attrs.mutable_attr() = attrs;
|
||||
TF_CHECK_OK(ExtractOutsideCompilationTest(
|
||||
"_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
|
||||
host_compute_core, &fld, &shape_inference_graphs,
|
||||
&has_outside_compilation));
|
||||
|
||||
// Get rewritten XLA computation function.
|
||||
std::unique_ptr<FunctionBody> xla_fbody;
|
||||
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
|
||||
AttrSlice(), &fld, &xla_fbody));
|
||||
auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();
|
||||
|
||||
// Check XlaHostCompute nodes.
|
||||
Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"];
|
||||
EXPECT_NE(host_compute_0, nullptr);
|
||||
Node *host_compute_1 = node_name_index["outside_compilation_1_host_compute"];
|
||||
EXPECT_NE(host_compute_1, nullptr);
|
||||
|
||||
// Check XlaHostCompute nodes' "_xla_token_input_nodes" attr.
|
||||
std::vector<string> token_input_nodes;
|
||||
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()),
|
||||
"_xla_token_input_nodes", &token_input_nodes));
|
||||
|
||||
std::vector<string> expected_token_input_nodes_0({"_xla_token_arg_node"});
|
||||
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0);
|
||||
token_input_nodes.clear();
|
||||
std::vector<string> expected_token_input_nodes_1(
|
||||
{"_xla_token_arg_node", "outside_compilation_0_host_compute"});
|
||||
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
|
||||
"_xla_token_input_nodes", &token_input_nodes));
|
||||
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
@ -36,6 +36,10 @@ std::once_flag flags_init;
|
||||
|
||||
bool SetterForXlaAutoJitFlag(const string& value) {
|
||||
int32 opt_level;
|
||||
// We need to use the mark_for_compilation_flags directly here instead of
|
||||
// going via GetMarkForCompilationPassFlags() to avoid infinite recursion. The
|
||||
// latter will try to setup and parse flags, which would bring us back to this
|
||||
// setter.
|
||||
if (absl::SimpleAtoi(value, &opt_level)) {
|
||||
mark_for_compilation_flags->xla_auto_jit_flag
|
||||
.optimization_level_single_gpu = opt_level;
|
||||
@ -155,9 +159,14 @@ void AllocateAndParseFlags() {
|
||||
|
||||
} // namespace
|
||||
|
||||
const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() {
|
||||
bool SetXlaAutoJitFlagFromFlagString(const string& value) {
|
||||
std::call_once(flags_init, &AllocateAndParseFlags);
|
||||
return *build_ops_flags;
|
||||
return SetterForXlaAutoJitFlag(value);
|
||||
}
|
||||
|
||||
BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags() {
|
||||
std::call_once(flags_init, &AllocateAndParseFlags);
|
||||
return build_ops_flags;
|
||||
}
|
||||
|
||||
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
|
||||
|
@ -38,6 +38,12 @@ struct XlaAutoJitFlag {
|
||||
int32 optimization_level_general;
|
||||
};
|
||||
|
||||
// Sets the xla_auto_jit_flag based on the given flag sting. Supported syntax
|
||||
// is:
|
||||
// <number>: sets general and single_gpu setting to the provided number.
|
||||
// single-gpu(<number>): sets the single_gpu setting to the provided number.
|
||||
bool SetXlaAutoJitFlagFromFlagString(const string& value);
|
||||
|
||||
// Flags associated with the XLA bridge's mark_for_compilation_pass module.
|
||||
struct MarkForCompilationPassFlags {
|
||||
XlaAutoJitFlag xla_auto_jit_flag;
|
||||
@ -111,7 +117,7 @@ struct IntroduceFloatingPointJitterPassFlags {
|
||||
// parses TF_XLA_FLAGS for all of them. Those functions which return a pointer
|
||||
// always return the same pointer.
|
||||
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags();
|
||||
const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags();
|
||||
BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags();
|
||||
XlaDeviceFlags* GetXlaDeviceFlags();
|
||||
const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
|
||||
|
||||
|
@ -13,8 +13,23 @@ cc_library(
|
||||
srcs = ["graphcycles.cc"],
|
||||
hdrs = ["graphcycles.h"],
|
||||
deps = [
|
||||
":ordered_set",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ordered_set",
|
||||
hdrs = ["ordered_set.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
@ -28,3 +43,14 @@ tf_cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "ordered_set_test",
|
||||
srcs = ["ordered_set_test.cc"],
|
||||
deps = [
|
||||
":ordered_set",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
@ -34,14 +34,20 @@ limitations under the License.
|
||||
#include <algorithm>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/ordered_set.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
typedef std::unordered_set<int32> NodeSet;
|
||||
using NodeSet = absl::flat_hash_set<int32>;
|
||||
using OrderedNodeSet = OrderedSet<int32>;
|
||||
|
||||
template <typename T>
|
||||
struct VecStruct {
|
||||
typedef absl::InlinedVector<T, 4> type;
|
||||
@ -50,13 +56,11 @@ template <typename T>
|
||||
using Vec = typename VecStruct<T>::type;
|
||||
|
||||
struct Node {
|
||||
Node() : in(4), out(4) {} // Small hashtables for in/out edges
|
||||
|
||||
int32 rank; // rank number assigned by Pearce-Kelly algorithm
|
||||
bool visited; // Temporary marker used by depth-first-search
|
||||
void* data; // User-supplied data
|
||||
NodeSet in; // List of immediate predecessor nodes in graph
|
||||
NodeSet out; // List of immediate successor nodes in graph
|
||||
OrderedNodeSet in; // List of immediate predecessor nodes in graph
|
||||
OrderedNodeSet out; // List of immediate successor nodes in graph
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@ -93,7 +97,7 @@ bool GraphCycles::CheckInvariants() const {
|
||||
if (!ranks.insert(nx->rank).second) {
|
||||
LOG(FATAL) << "Duplicate occurrence of rank " << nx->rank;
|
||||
}
|
||||
for (auto y : nx->out) {
|
||||
for (int32 y : nx->out.GetSequence()) {
|
||||
Node* ny = r->nodes_[y];
|
||||
if (nx->rank >= ny->rank) {
|
||||
LOG(FATAL) << "Edge " << x << "->" << y << " has bad rank assignment "
|
||||
@ -124,14 +128,14 @@ int32 GraphCycles::NewNode() {
|
||||
|
||||
void GraphCycles::RemoveNode(int32 node) {
|
||||
Node* x = rep_->nodes_[node];
|
||||
for (auto y : x->out) {
|
||||
rep_->nodes_[y]->in.erase(node);
|
||||
for (int32 y : x->out.GetSequence()) {
|
||||
rep_->nodes_[y]->in.Erase(node);
|
||||
}
|
||||
for (auto y : x->in) {
|
||||
rep_->nodes_[y]->out.erase(node);
|
||||
for (int32 y : x->in.GetSequence()) {
|
||||
rep_->nodes_[y]->out.Erase(node);
|
||||
}
|
||||
x->in.clear();
|
||||
x->out.clear();
|
||||
x->in.Clear();
|
||||
x->out.Clear();
|
||||
rep_->free_nodes_.push_back(node);
|
||||
}
|
||||
|
||||
@ -144,12 +148,12 @@ void GraphCycles::SetNodeData(int32 node, void* data) {
|
||||
}
|
||||
|
||||
bool GraphCycles::HasEdge(int32 x, int32 y) const {
|
||||
return rep_->nodes_[x]->out.find(y) != rep_->nodes_[x]->out.end();
|
||||
return rep_->nodes_[x]->out.Contains(y);
|
||||
}
|
||||
|
||||
void GraphCycles::RemoveEdge(int32 x, int32 y) {
|
||||
rep_->nodes_[x]->out.erase(y);
|
||||
rep_->nodes_[y]->in.erase(x);
|
||||
rep_->nodes_[x]->out.Erase(y);
|
||||
rep_->nodes_[y]->in.Erase(x);
|
||||
// No need to update the rank assignment since a previous valid
|
||||
// rank assignment remains valid after an edge deletion.
|
||||
}
|
||||
@ -165,13 +169,13 @@ bool GraphCycles::InsertEdge(int32 x, int32 y) {
|
||||
if (x == y) return false;
|
||||
Rep* r = rep_;
|
||||
Node* nx = r->nodes_[x];
|
||||
if (!nx->out.insert(y).second) {
|
||||
if (!nx->out.Insert(y)) {
|
||||
// Edge already exists.
|
||||
return true;
|
||||
}
|
||||
|
||||
Node* ny = r->nodes_[y];
|
||||
ny->in.insert(x);
|
||||
ny->in.Insert(x);
|
||||
|
||||
if (nx->rank <= ny->rank) {
|
||||
// New edge is consistent with existing rank assignment.
|
||||
@ -182,8 +186,8 @@ bool GraphCycles::InsertEdge(int32 x, int32 y) {
|
||||
// We only need to consider nodes that fall in the range [ny->rank,nx->rank].
|
||||
if (!ForwardDFS(r, y, nx->rank)) {
|
||||
// Found a cycle. Undo the insertion and tell caller.
|
||||
nx->out.erase(y);
|
||||
ny->in.erase(x);
|
||||
nx->out.Erase(y);
|
||||
ny->in.Erase(x);
|
||||
// Since we do not call Reorder() on this path, clear any visited
|
||||
// markers left by ForwardDFS.
|
||||
ClearVisitedBits(r, r->deltaf_);
|
||||
@ -209,7 +213,7 @@ static bool ForwardDFS(GraphCycles::Rep* r, int32 n, int32 upper_bound) {
|
||||
nn->visited = true;
|
||||
r->deltaf_.push_back(n);
|
||||
|
||||
for (auto w : nn->out) {
|
||||
for (auto w : nn->out.GetSequence()) {
|
||||
Node* nw = r->nodes_[w];
|
||||
if (nw->rank == upper_bound) {
|
||||
return false; // Cycle
|
||||
@ -235,7 +239,7 @@ static void BackwardDFS(GraphCycles::Rep* r, int32 n, int32 lower_bound) {
|
||||
nn->visited = true;
|
||||
r->deltab_.push_back(n);
|
||||
|
||||
for (auto w : nn->in) {
|
||||
for (auto w : nn->in.GetSequence()) {
|
||||
Node* nw = r->nodes_[w];
|
||||
if (!nw->visited && lower_bound < nw->rank) {
|
||||
r->stack_.push_back(w);
|
||||
@ -321,7 +325,7 @@ int GraphCycles::FindPath(int32 x, int32 y, int max_path_len,
|
||||
return path_len;
|
||||
}
|
||||
|
||||
for (auto w : r->nodes_[n]->out) {
|
||||
for (auto w : r->nodes_[n]->out.GetSequence()) {
|
||||
if (seen.insert(w).second) {
|
||||
r->stack_.push_back(w);
|
||||
}
|
||||
@ -375,31 +379,94 @@ bool GraphCycles::ContractEdge(int32 a, int32 b) {
|
||||
}
|
||||
|
||||
Node* nb = rep_->nodes_[b];
|
||||
std::unordered_set<int32> out = std::move(nb->out);
|
||||
std::unordered_set<int32> in = std::move(nb->in);
|
||||
for (auto y : out) {
|
||||
rep_->nodes_[y]->in.erase(b);
|
||||
OrderedNodeSet out = std::move(nb->out);
|
||||
OrderedNodeSet in = std::move(nb->in);
|
||||
for (int32 y : out.GetSequence()) {
|
||||
rep_->nodes_[y]->in.Erase(b);
|
||||
}
|
||||
for (auto y : in) {
|
||||
rep_->nodes_[y]->out.erase(b);
|
||||
for (int32 y : in.GetSequence()) {
|
||||
rep_->nodes_[y]->out.Erase(b);
|
||||
}
|
||||
rep_->free_nodes_.push_back(b);
|
||||
|
||||
for (auto y : out) {
|
||||
rep_->nodes_[a]->out.Reserve(rep_->nodes_[a]->out.Size() + out.Size());
|
||||
for (int32 y : out.GetSequence()) {
|
||||
InsertEdge(a, y);
|
||||
}
|
||||
for (auto y : in) {
|
||||
|
||||
rep_->nodes_[a]->in.Reserve(rep_->nodes_[a]->in.Size() + in.Size());
|
||||
for (int32 y : in.GetSequence()) {
|
||||
InsertEdge(y, a);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::unordered_set<int32> GraphCycles::Successors(int32 node) const {
|
||||
return rep_->nodes_[node]->out;
|
||||
absl::Span<const int32> GraphCycles::Successors(int32 node) const {
|
||||
return rep_->nodes_[node]->out.GetSequence();
|
||||
}
|
||||
|
||||
std::unordered_set<int32> GraphCycles::Predecessors(int32 node) const {
|
||||
return rep_->nodes_[node]->in;
|
||||
absl::Span<const int32> GraphCycles::Predecessors(int32 node) const {
|
||||
return rep_->nodes_[node]->in.GetSequence();
|
||||
}
|
||||
|
||||
std::vector<int32> GraphCycles::SuccessorsCopy(int32 node) const {
|
||||
absl::Span<const int32> successors = Successors(node);
|
||||
return std::vector<int32>(successors.begin(), successors.end());
|
||||
}
|
||||
|
||||
std::vector<int32> GraphCycles::PredecessorsCopy(int32 node) const {
|
||||
absl::Span<const int32> predecessors = Predecessors(node);
|
||||
return std::vector<int32>(predecessors.begin(), predecessors.end());
|
||||
}
|
||||
|
||||
namespace {
|
||||
void SortInPostOrder(absl::Span<Node* const> nodes,
|
||||
std::vector<int32>* to_sort) {
|
||||
absl::c_sort(*to_sort, [&](int32 a, int32 b) {
|
||||
DCHECK(a == b || nodes[a]->rank != nodes[b]->rank);
|
||||
return nodes[a]->rank > nodes[b]->rank;
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::vector<int32> GraphCycles::AllNodesInPostOrder() const {
|
||||
absl::flat_hash_set<int32> free_nodes_set;
|
||||
absl::c_copy(rep_->free_nodes_,
|
||||
std::inserter(free_nodes_set, free_nodes_set.begin()));
|
||||
|
||||
std::vector<int32> all_nodes;
|
||||
all_nodes.reserve(rep_->nodes_.size() - free_nodes_set.size());
|
||||
for (int64 i = 0, e = rep_->nodes_.size(); i < e; i++) {
|
||||
if (!free_nodes_set.contains(i)) {
|
||||
all_nodes.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
SortInPostOrder(rep_->nodes_, &all_nodes);
|
||||
return all_nodes;
|
||||
}
|
||||
|
||||
string GraphCycles::DebugString() const {
|
||||
absl::flat_hash_set<int32> free_nodes_set;
|
||||
for (int32 free_node : rep_->free_nodes_) {
|
||||
free_nodes_set.insert(free_node);
|
||||
}
|
||||
|
||||
string result = "digraph {\n";
|
||||
for (int i = 0; i < rep_->nodes_.size(); i++) {
|
||||
if (free_nodes_set.contains(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int32 succ : rep_->nodes_[i]->out.GetSequence()) {
|
||||
absl::StrAppend(&result, " \"", i, "\" -> \"", succ, "\"\n");
|
||||
}
|
||||
}
|
||||
|
||||
absl::StrAppend(&result, "}\n");
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -16,6 +16,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
// GraphCycles detects the introduction of a cycle into a directed
|
||||
// graph that is being built up incrementally.
|
||||
//
|
||||
@ -38,8 +40,7 @@ limitations under the License.
|
||||
// FindPath() is linear in the size of the graph.
|
||||
// The current implementation uses O(|V|+|E|) space.
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
@ -117,8 +118,26 @@ class GraphCycles {
|
||||
// Expensive: should only be called from graphcycles_test.cc.
|
||||
bool CheckInvariants() const;
|
||||
|
||||
std::unordered_set<int32> Successors(int32 node) const;
|
||||
std::unordered_set<int32> Predecessors(int32 node) const;
|
||||
// Warning: Do not use these if iterating over the span and modifying the
|
||||
// GraphCycles at the same time. Instead use SuccessorsCopy/PredecessorsCopy.
|
||||
absl::Span<const int32> Successors(int32 node) const;
|
||||
absl::Span<const int32> Predecessors(int32 node) const;
|
||||
|
||||
// Return a copy of the sucessors set. This is needed for code using the
|
||||
// collection while modifying the GraphCycles.
|
||||
std::vector<int32> SuccessorsCopy(int32 node) const;
|
||||
// Return a copy of the predecessors set. This is needed for code using the
|
||||
// collection while modifying the GraphCycles.
|
||||
std::vector<int32> PredecessorsCopy(int32 node) const;
|
||||
|
||||
// Returns all nodes in post order.
|
||||
//
|
||||
// If there is a path from X to Y then X appears after Y in the
|
||||
// returned vector.
|
||||
std::vector<int32> AllNodesInPostOrder() const;
|
||||
|
||||
// Returns the graph in graphviz format.
|
||||
string DebugString() const;
|
||||
|
||||
// ----------------------------------------------------
|
||||
struct Rep;
|
||||
|
85
tensorflow/compiler/jit/graphcycles/ordered_set.h
Normal file
85
tensorflow/compiler/jit/graphcycles/ordered_set.h
Normal file
@ -0,0 +1,85 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
// This is a set data structure that provides a deterministic iteration order.
|
||||
// The iteration order of elements only depends on the sequence of
|
||||
// inserts/deletes, so as long as the inserts/deletes happen in the same
|
||||
// sequence, the set will have the same iteration order.
|
||||
//
|
||||
// Assumes that T can be cheaply copied for simplicity.
|
||||
template <typename T>
|
||||
class OrderedSet {
|
||||
public:
|
||||
// Inserts `value` into the ordered set. Returns true if the value was not
|
||||
// present in the set before the insertion.
|
||||
bool Insert(T value) {
|
||||
bool new_insertion =
|
||||
value_to_index_.insert({value, value_sequence_.size()}).second;
|
||||
if (new_insertion) {
|
||||
value_sequence_.push_back(value);
|
||||
}
|
||||
return new_insertion;
|
||||
}
|
||||
|
||||
// Removes `value` from the set. Assumes `value` is already present in the
|
||||
// set.
|
||||
void Erase(T value) {
|
||||
auto it = value_to_index_.find(value);
|
||||
DCHECK(it != value_to_index_.end());
|
||||
|
||||
// Since we don't want to move values around in `value_sequence_` we swap
|
||||
// the value in the last position and with value to be deleted and then
|
||||
// pop_back.
|
||||
value_to_index_[value_sequence_.back()] = it->second;
|
||||
std::swap(value_sequence_[it->second], value_sequence_.back());
|
||||
value_sequence_.pop_back();
|
||||
value_to_index_.erase(it);
|
||||
}
|
||||
|
||||
void Reserve(size_t new_size) {
|
||||
value_to_index_.reserve(new_size);
|
||||
value_sequence_.reserve(new_size);
|
||||
}
|
||||
|
||||
void Clear() {
|
||||
value_to_index_.clear();
|
||||
value_sequence_.clear();
|
||||
}
|
||||
|
||||
bool Contains(T value) const { return value_to_index_.contains(value); }
|
||||
size_t Size() const { return value_sequence_.size(); }
|
||||
|
||||
absl::Span<T const> GetSequence() const { return value_sequence_; }
|
||||
|
||||
private:
|
||||
// The stable order that we maintain through insertions and deletions.
|
||||
std::vector<T> value_sequence_;
|
||||
|
||||
// Maps values to their indices in `value_sequence_`.
|
||||
absl::flat_hash_map<T, int> value_to_index_;
|
||||
};
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_
|
117
tensorflow/compiler/jit/graphcycles/ordered_set_test.cc
Normal file
117
tensorflow/compiler/jit/graphcycles/ordered_set_test.cc
Normal file
@ -0,0 +1,117 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/graphcycles/ordered_set.h"
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
TEST(OrderedSetTest, Insert) {
|
||||
OrderedSet<int> ordered_set;
|
||||
EXPECT_TRUE(ordered_set.Insert(90));
|
||||
EXPECT_TRUE(ordered_set.Insert(100));
|
||||
EXPECT_TRUE(ordered_set.Insert(80));
|
||||
|
||||
EXPECT_FALSE(ordered_set.Insert(100));
|
||||
|
||||
EXPECT_EQ(ordered_set.Size(), 3);
|
||||
|
||||
EXPECT_TRUE(ordered_set.Contains(90));
|
||||
EXPECT_TRUE(ordered_set.Contains(100));
|
||||
EXPECT_TRUE(ordered_set.Contains(80));
|
||||
|
||||
EXPECT_FALSE(ordered_set.Contains(40));
|
||||
|
||||
std::array<int, 3> expected_sequence = {90, 100, 80};
|
||||
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence);
|
||||
}
|
||||
|
||||
TEST(OrderedSetTest, Erase) {
|
||||
OrderedSet<int> ordered_set;
|
||||
EXPECT_TRUE(ordered_set.Insert(90));
|
||||
EXPECT_TRUE(ordered_set.Insert(100));
|
||||
EXPECT_TRUE(ordered_set.Insert(80));
|
||||
|
||||
ordered_set.Erase(100);
|
||||
|
||||
EXPECT_EQ(ordered_set.Size(), 2);
|
||||
|
||||
EXPECT_TRUE(ordered_set.Contains(90));
|
||||
EXPECT_FALSE(ordered_set.Contains(100));
|
||||
EXPECT_TRUE(ordered_set.Contains(80));
|
||||
|
||||
std::array<int, 2> expected_sequence_0 = {90, 80};
|
||||
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence_0);
|
||||
|
||||
ordered_set.Erase(80);
|
||||
|
||||
EXPECT_EQ(ordered_set.Size(), 1);
|
||||
|
||||
EXPECT_TRUE(ordered_set.Contains(90));
|
||||
EXPECT_FALSE(ordered_set.Contains(100));
|
||||
EXPECT_FALSE(ordered_set.Contains(80));
|
||||
|
||||
std::array<int, 1> expected_sequence_1 = {90};
|
||||
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence_1);
|
||||
|
||||
ordered_set.Erase(90);
|
||||
|
||||
EXPECT_EQ(ordered_set.Size(), 0);
|
||||
|
||||
EXPECT_FALSE(ordered_set.Contains(90));
|
||||
EXPECT_FALSE(ordered_set.Contains(100));
|
||||
EXPECT_FALSE(ordered_set.Contains(80));
|
||||
|
||||
std::array<int, 0> expected_sequence_2 = {};
|
||||
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence_2);
|
||||
}
|
||||
|
||||
TEST(OrderedSetTest, Clear) {
|
||||
OrderedSet<int> ordered_set;
|
||||
EXPECT_TRUE(ordered_set.Insert(90));
|
||||
EXPECT_TRUE(ordered_set.Insert(100));
|
||||
EXPECT_TRUE(ordered_set.Insert(80));
|
||||
|
||||
ordered_set.Clear();
|
||||
|
||||
EXPECT_EQ(ordered_set.Size(), 0);
|
||||
|
||||
EXPECT_FALSE(ordered_set.Contains(90));
|
||||
EXPECT_FALSE(ordered_set.Contains(100));
|
||||
EXPECT_FALSE(ordered_set.Contains(80));
|
||||
|
||||
std::array<int, 0> expected_sequence = {};
|
||||
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence);
|
||||
}
|
||||
|
||||
TEST(OrderedSetTest, LargeInsertions) {
|
||||
const int kSize = 50 * 9000;
|
||||
|
||||
OrderedSet<int> ordered_set;
|
||||
|
||||
for (int i = 0; i < kSize; i++) {
|
||||
EXPECT_TRUE(ordered_set.Insert(i + 500));
|
||||
}
|
||||
|
||||
for (int i = 0; i < kSize; i++) {
|
||||
EXPECT_EQ(ordered_set.GetSequence()[i], i + 500);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -62,7 +62,7 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
|
||||
se::Platform::Id platform_id = nullptr;
|
||||
const XlaDevice::Metadata* xla_device_metadata = nullptr;
|
||||
std::unique_ptr<XlaAllocator> xla_allocator;
|
||||
xla::DeviceMemoryAllocator* device_allocator = nullptr;
|
||||
se::DeviceMemoryAllocator* device_allocator = nullptr;
|
||||
|
||||
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
|
||||
platform_id = se::host::kHostPlatformId;
|
||||
|
@ -40,7 +40,7 @@ class XlaPlatformInfo {
|
||||
se::Platform::Id platform_id,
|
||||
const XlaDevice::Metadata* xla_device_metadata,
|
||||
std::unique_ptr<XlaAllocator> xla_allocator,
|
||||
xla::DeviceMemoryAllocator* device_allocator)
|
||||
se::DeviceMemoryAllocator* device_allocator)
|
||||
: device_type_(device_type),
|
||||
platform_id_(platform_id),
|
||||
xla_device_metadata_(xla_device_metadata),
|
||||
@ -55,7 +55,7 @@ class XlaPlatformInfo {
|
||||
return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
|
||||
}
|
||||
|
||||
xla::DeviceMemoryAllocator* allocator() const {
|
||||
se::DeviceMemoryAllocator* allocator() const {
|
||||
return device_allocator_ ? device_allocator_ : xla_allocator_.get();
|
||||
}
|
||||
DeviceType device_type() const { return device_type_; }
|
||||
@ -86,7 +86,7 @@ class XlaPlatformInfo {
|
||||
// then device_allocator_ is null and xla_allocator_ points to an appropriate
|
||||
// XlaAllocator instance.
|
||||
std::unique_ptr<XlaAllocator> xla_allocator_;
|
||||
xla::DeviceMemoryAllocator* device_allocator_;
|
||||
se::DeviceMemoryAllocator* device_allocator_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
|
||||
};
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -270,11 +270,11 @@ TEST(XlaCompilationTest, FunctionCalls) {
|
||||
auto clusters = GetClusters(*graph);
|
||||
|
||||
EXPECT_EQ(2, clusters.size());
|
||||
EXPECT_FALSE(clusters["B"].empty());
|
||||
EXPECT_EQ(clusters["B"], clusters["C"]);
|
||||
EXPECT_FALSE(clusters["C"].empty());
|
||||
EXPECT_EQ(clusters["C"], clusters["E"]);
|
||||
EXPECT_TRUE(clusters.find("A") == clusters.cend());
|
||||
EXPECT_TRUE(clusters.find("B") == clusters.cend());
|
||||
EXPECT_TRUE(clusters.find("D") == clusters.cend());
|
||||
EXPECT_TRUE(clusters.find("E") == clusters.cend());
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, CallXlaDeviceFuncWithResourceOp) {
|
||||
@ -332,31 +332,6 @@ TEST(XlaCompilationTest, CallXlaDeviceFuncWithResourceOp) {
|
||||
EXPECT_NE(clusters["A"], "");
|
||||
}
|
||||
|
||||
// Metadata-only operators such as Shape/Rank/Size may not be the root of a
|
||||
// cluster. This is partially to work around b/26800664, and partially because
|
||||
// we should probably prefer to compile metadata operators with their producers
|
||||
// wherever possible, rather than their consumers.
|
||||
TEST(XlaCompilationTest, MetadataOpsDontStartClusters) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a =
|
||||
ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
|
||||
// While all of the following ops are notionally compilable, none is
|
||||
// permitted
|
||||
// to start a cluster. So nothing should be compiled.
|
||||
Node* b = ops::UnaryOp("Shape", a, builder.opts().WithName("B"));
|
||||
Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C"));
|
||||
Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D"));
|
||||
ops::UnaryOp("Shape", d, builder.opts().WithName("E"));
|
||||
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
|
||||
}
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||
auto clusters = GetClusters(*graph);
|
||||
EXPECT_EQ(0, clusters.size()); // Nothing should be compiled.
|
||||
}
|
||||
|
||||
static Status GradForUnaryCwise(FunctionDef* g,
|
||||
std::vector<FunctionDefHelper::Node> nodes) {
|
||||
for (auto& n : nodes) {
|
||||
@ -1137,6 +1112,45 @@ TEST(XlaCompilationTest, DontClusterMergingNodes) {
|
||||
EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, DontClusterMergingNodesOnCPU) {
|
||||
// This is similar to the 'DontClusterMergingNodes' above, except
|
||||
// MatMulCombined is placed on the CPU.
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
absl::string_view xla_gpu_dev0 = "/job:worker/replica:0/task:0/device:GPU:0";
|
||||
absl::string_view xla_gpu_dev1 = "/job:worker/replica:0/task:0/device:GPU:1";
|
||||
absl::string_view xla_cpu_dev0 = "/job:worker/replica:0/task:0/device:CPU:0";
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
|
||||
ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
|
||||
Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
|
||||
ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
|
||||
Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
|
||||
Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);
|
||||
|
||||
Output combined =
|
||||
ops::MatMul(root.WithOpName("MatMulCombined_cpu"), matmul0, matmul1);
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
for (Node* n : graph->nodes()) {
|
||||
if (absl::EndsWith(n->name(), /*suffix=*/"cpu")) {
|
||||
n->set_assigned_device_name(string(xla_cpu_dev0));
|
||||
} else if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
|
||||
n->set_assigned_device_name(string(xla_gpu_dev0));
|
||||
} else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
|
||||
n->set_assigned_device_name(string(xla_gpu_dev1));
|
||||
}
|
||||
}
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||
|
||||
// Each of the MatMuls should be in a separate cluster.
|
||||
std::unordered_map<string, string> clusters = GetClusters(*graph);
|
||||
EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
|
||||
EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul0_dev0"]);
|
||||
EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul1_dev1"]);
|
||||
EXPECT_EQ(clusters["A_dev0"], clusters["MatMul0_dev0"]);
|
||||
EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
|
||||
}
|
||||
|
||||
// TODO(b/117085735): This form of clustering should be prevented.
|
||||
TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) {
|
||||
// MatMulSource below creates data for nodes on GPU0 and GPU1 and is placed
|
||||
@ -1534,5 +1548,59 @@ TEST(XlaCompilationTest, DontClusterNodesWithForwardFromAttr) {
|
||||
EXPECT_EQ(clusters["test/z"], "");
|
||||
}
|
||||
|
||||
// Note, this relies on other implementation details to test the
|
||||
// specific heuristic we care about here, so other changes might be at fault if
|
||||
// this CL breaks. What we care about is that if a ShapeConsumingOp can be
|
||||
// connected with a producer or consumer and cannot be clustered with both, it
|
||||
// should be clustered with the producer.
|
||||
TEST(XlaCompilationTest, ClusterShapeConsumerWithProducer) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
|
||||
Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
|
||||
|
||||
Output x = ops::MatMul(root.WithOpName("test/x"), a, b);
|
||||
Output y = ops::Size(root.WithOpName("test/y"), x);
|
||||
Output z = ops::Add(root.WithOpName("test/z"), y, y);
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
// Ensure that the "Size" op can only be clustered with either the producer or
|
||||
// consumer by putting them on different devices.
|
||||
FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
|
||||
FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kCPU0);
|
||||
FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU1);
|
||||
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||
|
||||
std::unordered_map<string, string> clusters = GetClusters(*graph);
|
||||
|
||||
EXPECT_NE(clusters["test/y"], "");
|
||||
EXPECT_EQ(clusters["test/x"], clusters["test/y"]);
|
||||
EXPECT_NE(clusters["test/z"], clusters["test/y"]);
|
||||
}
|
||||
|
||||
// Test that ShapeConsuming ops are still fully clustered whenever possible.
|
||||
TEST(XlaCompilationTest, ClusterShapeConsumerWithProducerAndConsumer) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
|
||||
Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
|
||||
|
||||
Output x = ops::MatMul(root.WithOpName("test/x"), a, b);
|
||||
Output y = ops::Size(root.WithOpName("test/y"), x);
|
||||
Output z = ops::Add(root.WithOpName("test/z"), y, y);
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||
|
||||
std::unordered_map<string, string> clusters = GetClusters(*graph);
|
||||
|
||||
EXPECT_NE(clusters["test/y"], "");
|
||||
EXPECT_EQ(clusters["test/y"], clusters["test/x"]);
|
||||
EXPECT_EQ(clusters["test/y"], clusters["test/z"]);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -14,9 +14,11 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/jit/device_util.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
@ -49,6 +51,15 @@ Status FindNodesToDecluster(const Graph& graph,
|
||||
continue;
|
||||
}
|
||||
|
||||
// Assume the benefit of not outputting a larger tensor outweighs the
|
||||
// benefit of this check.
|
||||
// TODO(tpopp): Only apply this if the value being consumed is not output
|
||||
// from the cluster to another consumer.
|
||||
// TODO(tpopp): See if XlaRun can be modified to avoid this issue
|
||||
// completely.
|
||||
if (IsShapeConsumerOp(*n)) {
|
||||
continue;
|
||||
}
|
||||
// We assume the only XLA-auto-clusterable operations with side effects are
|
||||
// resource variable updates. We can't execute these twice.
|
||||
if (HasResourceInputOrOutput(*n)) {
|
||||
@ -57,7 +68,7 @@ Status FindNodesToDecluster(const Graph& graph,
|
||||
|
||||
DeviceType device_type("");
|
||||
TF_RETURN_IF_ERROR(
|
||||
DeviceToDeviceType(n->assigned_device_name(), &device_type));
|
||||
DeviceNameToDeviceType(n->assigned_device_name(), &device_type));
|
||||
TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type,
|
||||
n->def(), &input_mtypes,
|
||||
&output_mtypes));
|
||||
@ -77,8 +88,8 @@ Status FindNodesToDecluster(const Graph& graph,
|
||||
} else {
|
||||
MemoryTypeVector dst_input_mtypes, dst_output_mtypes;
|
||||
DeviceType dst_device_type("");
|
||||
TF_RETURN_IF_ERROR(
|
||||
DeviceToDeviceType(dst->assigned_device_name(), &dst_device_type));
|
||||
TF_RETURN_IF_ERROR(DeviceNameToDeviceType(dst->assigned_device_name(),
|
||||
&dst_device_type));
|
||||
TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type,
|
||||
dst->def(), &dst_input_mtypes,
|
||||
&dst_output_mtypes));
|
||||
@ -237,7 +248,7 @@ bool IsMustCompileDevice(const DeviceType& device_type) {
|
||||
Status MustCompileNode(const Node* n, bool* must_compile) {
|
||||
DeviceType device_type("");
|
||||
TF_RETURN_IF_ERROR(
|
||||
DeviceToDeviceType(n->assigned_device_name(), &device_type));
|
||||
DeviceNameToDeviceType(n->assigned_device_name(), &device_type));
|
||||
|
||||
if (IsMustCompileDevice(device_type)) {
|
||||
*must_compile = true;
|
||||
@ -340,6 +351,40 @@ Status PartiallyDeclusterGraph(Graph* graph,
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace reduce_recompilation
|
||||
|
||||
namespace decluster_root_shape_consumers {
|
||||
|
||||
Status PartiallyDeclusterGraph(Graph* graph) {
|
||||
std::vector<Node*> reverse_post_order;
|
||||
GetReversePostOrder(*graph, &reverse_post_order,
|
||||
/*stable_comparator=*/NodeComparatorName(),
|
||||
/*edge_filter=*/NotBackedge);
|
||||
|
||||
for (Node* n : reverse_post_order) {
|
||||
if (!IsShapeConsumerOp(*n)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
absl::optional<absl::string_view> cluster = GetXlaClusterForNode(*n);
|
||||
if (!cluster.has_value()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto input_belongs_to_same_cluster = [&](const Edge* e) {
|
||||
return cluster == GetXlaClusterForNode(*e->src());
|
||||
};
|
||||
|
||||
if (absl::c_any_of(n->in_edges(), input_belongs_to_same_cluster)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
VLOG(2) << "Declustering " << n->name()
|
||||
<< " because it is a root shape consumer";
|
||||
RemoveFromXlaCluster(n);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace decluster_root_shape_consumers
|
||||
} // namespace
|
||||
|
||||
Status PartiallyDeclusterPass::Run(
|
||||
@ -367,6 +412,9 @@ Status PartiallyDeclusterPass::Run(
|
||||
TF_RETURN_IF_ERROR(reduce_recompilation::PartiallyDeclusterGraph(
|
||||
graph, options.flib_def, options.session_options->env));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
decluster_root_shape_consumers::PartiallyDeclusterGraph(graph));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
@ -40,20 +40,20 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
REGISTER_OP("FakeNullary").Output("out: float");
|
||||
REGISTER_OP("FakeNullary").Output("out: int32");
|
||||
|
||||
REGISTER_OP("FakeBinary")
|
||||
.Input("host_in: float")
|
||||
.Input("device_in: float")
|
||||
.Output("host_out: float")
|
||||
.Output("device_out: float");
|
||||
.Input("host_in: int32")
|
||||
.Input("device_in: int32")
|
||||
.Output("host_out: int32")
|
||||
.Output("device_out: int32");
|
||||
|
||||
REGISTER_OP("FakeResourceVar").Output("out: resource");
|
||||
|
||||
REGISTER_OP("FakeResourceUpdate")
|
||||
.Input("in: resource")
|
||||
.Output("out: resource")
|
||||
.Output("something_else: float");
|
||||
.Output("something_else: int32");
|
||||
|
||||
class FakeBinaryOp : public OpKernel {
|
||||
public:
|
||||
@ -467,5 +467,61 @@ TEST(PartiallyDeclusterPassTest, EliminatedUnusedNodes) {
|
||||
EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer1Name), nullptr);
|
||||
}
|
||||
|
||||
TEST(PartiallyDeclusterPassTest, MetadataOpsDontStartClusters) {
|
||||
tensorflow::Scope root = tensorflow::Scope::NewRootScope();
|
||||
tensorflow::Scope in_cluster_and = root.WithXlaCluster("cluster_0");
|
||||
|
||||
Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
|
||||
Output b = ops::Shape(in_cluster_and.WithOpName("b"), a);
|
||||
Output c = ops::Rank(in_cluster_and.WithOpName("c"), b);
|
||||
Output d = ops::Size(in_cluster_and.WithOpName("d"), c);
|
||||
(void)ops::Shape(in_cluster_and.WithOpName("e"), d);
|
||||
|
||||
std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
TF_ASSERT_OK(PartiallyDecluster(&graph));
|
||||
|
||||
Node* n_b = FindNodeByName(*graph, "b");
|
||||
ASSERT_NE(n_b, nullptr);
|
||||
EXPECT_EQ(GetXlaClusterForNode(*n_b), absl::nullopt);
|
||||
|
||||
Node* n_c = FindNodeByName(*graph, "c");
|
||||
ASSERT_NE(n_c, nullptr);
|
||||
EXPECT_EQ(GetXlaClusterForNode(*n_c), absl::nullopt);
|
||||
|
||||
Node* n_d = FindNodeByName(*graph, "d");
|
||||
ASSERT_NE(n_d, nullptr);
|
||||
EXPECT_EQ(GetXlaClusterForNode(*n_d), absl::nullopt);
|
||||
|
||||
Node* n_e = FindNodeByName(*graph, "e");
|
||||
ASSERT_NE(n_e, nullptr);
|
||||
EXPECT_EQ(GetXlaClusterForNode(*n_e), absl::nullopt);
|
||||
}
|
||||
|
||||
TEST(PartiallyDeclusterPassTest, MetaConsumersArentDeclustered) {
|
||||
tensorflow::Scope root = tensorflow::Scope::NewRootScope();
|
||||
tensorflow::Scope in_cluster_and = root.WithXlaCluster("cluster_0");
|
||||
std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
|
||||
Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
|
||||
Output b = ops::Add(in_cluster_and.WithOpName("b"), a, a);
|
||||
Output c = ops::Rank(in_cluster_and.WithOpName("c"), b);
|
||||
|
||||
Output e;
|
||||
TF_ASSERT_OK(
|
||||
CreateOutputWithScope("FakeBinary", {c, c}, root.WithOpName("e"), &e));
|
||||
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
TF_ASSERT_OK(PartiallyDecluster(&graph));
|
||||
|
||||
Node* n_b = FindNodeByName(*graph, "b");
|
||||
ASSERT_NE(n_b, nullptr);
|
||||
EXPECT_EQ(GetXlaClusterForNode(*n_b), "cluster_0");
|
||||
|
||||
Node* n_c = FindNodeByName(*graph, "c");
|
||||
ASSERT_NE(n_c, nullptr);
|
||||
EXPECT_EQ(GetXlaClusterForNode(*n_c), "cluster_0");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
238
tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc
Normal file
238
tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc
Normal file
@ -0,0 +1,238 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
#include "tensorflow/cc/ops/functional_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
TEST(RearrangeFunctionArgumentForFunctionTest, Basic) {
|
||||
FunctionDefLibrary fdl;
|
||||
{
|
||||
// Function for StatefulPartitionedCall's "f", If's
|
||||
// "then_branch"/"else_branch".
|
||||
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL)
|
||||
// "ret0" = "arg1"
|
||||
// "ret1" = "arg0"
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
|
||||
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
|
||||
auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg1, 0);
|
||||
auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg0, 1);
|
||||
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
|
||||
TF_CHECK_OK(s.ToGraph(g.get()));
|
||||
FunctionDef *xla_fdef = fdl.add_function();
|
||||
TF_CHECK_OK(GraphToFunctionDef(*g, "f1", xla_fdef));
|
||||
}
|
||||
{
|
||||
// Function for While's "body".
|
||||
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL)
|
||||
// "ret0" = "arg0"
|
||||
// "ret1" = "arg1"
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
|
||||
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
|
||||
auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg0, 0);
|
||||
auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg1, 0);
|
||||
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
|
||||
TF_CHECK_OK(s.ToGraph(g.get()));
|
||||
FunctionDef *xla_fdef = fdl.add_function();
|
||||
TF_CHECK_OK(GraphToFunctionDef(*g, "f2", xla_fdef));
|
||||
}
|
||||
{
|
||||
// Function for While's "cond".
|
||||
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL)
|
||||
// "ret0" = "arg1"
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
|
||||
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
|
||||
auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg1, 0);
|
||||
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
|
||||
TF_CHECK_OK(s.ToGraph(g.get()));
|
||||
FunctionDef *xla_fdef = fdl.add_function();
|
||||
TF_CHECK_OK(GraphToFunctionDef(*g, "f3", xla_fdef));
|
||||
}
|
||||
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
|
||||
|
||||
// Build the XLA computation graph.
|
||||
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_INT32)
|
||||
// "arg0", "arg1" -> "if" (If) -> "ret0", "ret1"
|
||||
// "arg0", "arg1" -> "while" (While) -> "ret2", "ret3"
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
|
||||
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
|
||||
NameAttrList f;
|
||||
f.set_name("f1");
|
||||
auto if_op = ops::If(s.WithOpName("if"), arg1,
|
||||
std::initializer_list<Input>{arg0, arg1},
|
||||
{DT_BOOL, DT_RESOURCE}, f, f);
|
||||
auto ret0 = ops::_Retval(s.WithOpName("ret0"), if_op.output[0], 0);
|
||||
auto ret1 = ops::_Retval(s.WithOpName("ret1"), if_op.output[1], 1);
|
||||
NameAttrList cond_fn, body_fn;
|
||||
cond_fn.set_name("f3");
|
||||
body_fn.set_name("f2");
|
||||
auto while_op =
|
||||
ops::While(s.WithOpName("while"),
|
||||
std::initializer_list<Input>{arg0, arg1}, cond_fn, body_fn);
|
||||
auto ret2 = ops::_Retval(s.WithOpName("ret2"), while_op.output[0], 2);
|
||||
auto ret3 = ops::_Retval(s.WithOpName("ret3"), while_op.output[1], 3);
|
||||
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
|
||||
TF_CHECK_OK(s.ToGraph(g.get()));
|
||||
|
||||
std::vector<std::unique_ptr<FunctionBody>> fbodies;
|
||||
TF_CHECK_OK(RearrangeFunctionArguments(
|
||||
[&](const NameAttrList &function, const FunctionBody **fbody) {
|
||||
std::unique_ptr<FunctionBody> new_fbody;
|
||||
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld.Find(function.name()),
|
||||
AttrSlice(&function.attr()),
|
||||
&fld, &new_fbody));
|
||||
*fbody = new_fbody.get();
|
||||
fbodies.push_back(std::move(new_fbody));
|
||||
return Status::OK();
|
||||
},
|
||||
g.get(), &fld));
|
||||
|
||||
// Check function f1_rearrange_0, input types should be {DT_BOOL, DT_RESOURCE}
|
||||
// and output types should be {DT_BOOL}.
|
||||
const FunctionDef *f1_rewritten = fld.Find("f1_rearrange_0");
|
||||
CHECK_NE(f1_rewritten, nullptr);
|
||||
ASSERT_EQ(f1_rewritten->signature().input_arg_size(), 2);
|
||||
EXPECT_EQ(f1_rewritten->signature().input_arg(0).type(), DT_BOOL);
|
||||
EXPECT_EQ(f1_rewritten->signature().input_arg(1).type(), DT_RESOURCE);
|
||||
ASSERT_EQ(f1_rewritten->signature().output_arg_size(), 1);
|
||||
EXPECT_EQ(f1_rewritten->signature().output_arg(0).type(), DT_BOOL);
|
||||
|
||||
// Check node "if" input and output edges.
|
||||
auto node_name_index = g->BuildNodeNameIndex();
|
||||
const Node *if_node = node_name_index.at("if");
|
||||
ASSERT_NE(if_node, nullptr);
|
||||
const Node *input_node;
|
||||
TF_CHECK_OK(if_node->input_node(1, &input_node));
|
||||
EXPECT_EQ(input_node->name(), "arg1");
|
||||
TF_CHECK_OK(if_node->input_node(2, &input_node));
|
||||
EXPECT_EQ(input_node->name(), "arg0");
|
||||
const Node *ret0_node = node_name_index.at("ret0");
|
||||
ASSERT_NE(ret0_node, nullptr);
|
||||
TF_CHECK_OK(ret0_node->input_node(0, &input_node));
|
||||
EXPECT_EQ(input_node->name(), "if");
|
||||
const Node *ret1_node = node_name_index.at("ret1");
|
||||
ASSERT_NE(ret1_node, nullptr);
|
||||
TF_CHECK_OK(ret1_node->input_node(0, &input_node));
|
||||
EXPECT_EQ(input_node->name(), "arg0");
|
||||
|
||||
// Check node "while" input and output edges.
|
||||
const Node *while_node = node_name_index.at("while");
|
||||
ASSERT_NE(while_node, nullptr);
|
||||
TF_CHECK_OK(while_node->input_node(0, &input_node));
|
||||
EXPECT_EQ(input_node->name(), "arg1");
|
||||
TF_CHECK_OK(while_node->input_node(1, &input_node));
|
||||
EXPECT_EQ(input_node->name(), "arg0");
|
||||
const Node *ret2_node = node_name_index.at("ret2");
|
||||
ASSERT_NE(ret2_node, nullptr);
|
||||
TF_CHECK_OK(ret2_node->input_node(0, &input_node));
|
||||
EXPECT_EQ(input_node->name(), "arg0");
|
||||
const Node *ret3_node = node_name_index.at("ret3");
|
||||
ASSERT_NE(ret3_node, nullptr);
|
||||
TF_CHECK_OK(ret3_node->input_node(0, &input_node));
|
||||
EXPECT_EQ(input_node->name(), "while");
|
||||
}
|
||||
|
||||
TEST(RearrangeFunctionArgumentForFunctionTest,
|
||||
WhileResourceRetvalFromDifferentArgUnimplemented) {
|
||||
FunctionDefLibrary fdl;
|
||||
{
|
||||
// Function for While's "body".
|
||||
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
|
||||
// "ret0" = "arg1"
|
||||
// "ret1" = "arg0"
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
|
||||
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
|
||||
Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
|
||||
auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg1, 0);
|
||||
auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg0, 1);
|
||||
auto ret2 = ops::_Retval(s.WithOpName("ret2"), arg2, 2);
|
||||
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
|
||||
TF_CHECK_OK(s.ToGraph(g.get()));
|
||||
FunctionDef *xla_fdef = fdl.add_function();
|
||||
TF_CHECK_OK(GraphToFunctionDef(*g, "f2", xla_fdef));
|
||||
}
|
||||
{
|
||||
// Function for While's "cond".
|
||||
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
|
||||
// "ret0" = true
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
|
||||
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
|
||||
Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
|
||||
Output cond = ops::Const(s.WithOpName("const"), true, TensorShape({}));
|
||||
auto ret0 = ops::_Retval(s.WithOpName("ret0"), cond, 0);
|
||||
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
|
||||
TF_CHECK_OK(s.ToGraph(g.get()));
|
||||
FunctionDef *xla_fdef = fdl.add_function();
|
||||
TF_CHECK_OK(GraphToFunctionDef(*g, "f1", xla_fdef));
|
||||
}
|
||||
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
|
||||
|
||||
// Build the XLA computation graph.
|
||||
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
|
||||
// "arg0", "arg1" -> "while" (While)
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
|
||||
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
|
||||
Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
|
||||
NameAttrList cond_fn, body_fn;
|
||||
cond_fn.set_name("f1");
|
||||
body_fn.set_name("f2");
|
||||
auto while_op = ops::While(s.WithOpName("while"),
|
||||
std::initializer_list<Input>{arg0, arg1, arg2},
|
||||
cond_fn, body_fn);
|
||||
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
|
||||
TF_CHECK_OK(s.ToGraph(g.get()));
|
||||
|
||||
std::vector<std::unique_ptr<FunctionBody>> fbodies;
|
||||
Status status = RearrangeFunctionArguments(
|
||||
[&](const NameAttrList &function, const FunctionBody **fbody) {
|
||||
std::unique_ptr<FunctionBody> new_fbody;
|
||||
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld.Find(function.name()),
|
||||
AttrSlice(&function.attr()),
|
||||
&fld, &new_fbody));
|
||||
*fbody = new_fbody.get();
|
||||
fbodies.push_back(std::move(new_fbody));
|
||||
return Status::OK();
|
||||
},
|
||||
g.get(), &fld);
|
||||
EXPECT_EQ(status.code(), error::UNIMPLEMENTED);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -84,15 +84,6 @@ bool AlwaysForwardsRefInput(const Node& node) { return node.IsIdentity(); }
|
||||
|
||||
} // namespace
|
||||
|
||||
Status DeviceToDeviceType(const string& device, DeviceType* device_type) {
|
||||
DeviceNameUtils::ParsedName parsed;
|
||||
if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
|
||||
return errors::Internal("Malformed assigned device '", device, "'");
|
||||
}
|
||||
*device_type = DeviceType(parsed.type);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool HasForwardedRefInput(const Node& node) {
|
||||
if (AlwaysForwardsRefInput(node)) {
|
||||
for (const Edge* incoming_edge : node.in_edges()) {
|
||||
@ -226,108 +217,6 @@ void RemoveFromXlaCluster(NodeDef* node_def) {
|
||||
|
||||
void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }
|
||||
|
||||
Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
|
||||
bool allow_mixing_unknown_and_cpu,
|
||||
bool* out_can_pick_device,
|
||||
string* out_device_picked) {
|
||||
if (out_can_pick_device) {
|
||||
*out_can_pick_device = true;
|
||||
}
|
||||
|
||||
#define FAILED_TO_PICK_DEVICE(failing_status) \
|
||||
do { \
|
||||
if (out_can_pick_device) { \
|
||||
*out_can_pick_device = false; \
|
||||
return Status::OK(); \
|
||||
} else { \
|
||||
return failing_status; \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
TF_RET_CHECK(!device_names.empty()) << "No devices to choose from";
|
||||
DCHECK_NE(out_can_pick_device == nullptr, out_device_picked == nullptr);
|
||||
|
||||
absl::flat_hash_set<absl::string_view> device_names_set;
|
||||
for (absl::string_view device_name : device_names) {
|
||||
if (!device_name.empty()) {
|
||||
device_names_set.insert(device_name);
|
||||
}
|
||||
}
|
||||
|
||||
absl::optional<absl::string_view> maybe_gpu_device;
|
||||
absl::optional<absl::string_view> maybe_cpu_device;
|
||||
absl::optional<absl::string_view> maybe_unknown_device;
|
||||
|
||||
for (absl::string_view device_name : device_names_set) {
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
TF_RET_CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_name))
|
||||
<< device_name;
|
||||
if (parsed_name.type == "GPU") {
|
||||
if (maybe_gpu_device) {
|
||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||
"Multiple GPU devices ", absl::StrJoin(device_names, ", ")));
|
||||
}
|
||||
maybe_gpu_device = device_name;
|
||||
} else if (parsed_name.type == "CPU") {
|
||||
if (maybe_cpu_device) {
|
||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||
"Multiple CPU devices ", absl::StrJoin(device_names, ", ")));
|
||||
}
|
||||
maybe_cpu_device = device_name;
|
||||
} else {
|
||||
if (maybe_unknown_device) {
|
||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||
"Multiple unknown devices ", absl::StrJoin(device_names, ", ")));
|
||||
}
|
||||
maybe_unknown_device = device_name;
|
||||
}
|
||||
}
|
||||
|
||||
if (maybe_unknown_device && maybe_gpu_device) {
|
||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||
"Found both unknown and GPU devices: ", *maybe_unknown_device, ", ",
|
||||
*maybe_gpu_device));
|
||||
}
|
||||
|
||||
if (!allow_mixing_unknown_and_cpu) {
|
||||
if (maybe_unknown_device && maybe_cpu_device) {
|
||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||
"Found both unknown and CPU devices: ", *maybe_unknown_device, ", ",
|
||||
*maybe_cpu_device));
|
||||
}
|
||||
}
|
||||
|
||||
if (out_device_picked) {
|
||||
if (maybe_gpu_device) {
|
||||
*out_device_picked = string(*maybe_gpu_device);
|
||||
} else if (maybe_unknown_device) {
|
||||
*out_device_picked = string(*maybe_unknown_device);
|
||||
} else {
|
||||
*out_device_picked = string(*maybe_cpu_device);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
||||
#undef FAILED_TO_PICK_DEVICE
|
||||
}
|
||||
|
||||
Status PickDeviceForXla(absl::Span<const string> device_names,
|
||||
bool allow_mixing_unknown_and_cpu,
|
||||
string* out_device_picked) {
|
||||
return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu,
|
||||
/*out_can_pick_device=*/nullptr,
|
||||
out_device_picked);
|
||||
}
|
||||
|
||||
Status CanPickDeviceForXla(absl::Span<const string> device_names,
|
||||
bool allow_mixing_unknown_and_cpu,
|
||||
bool* out_can_pick_device) {
|
||||
return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu,
|
||||
out_can_pick_device,
|
||||
/*out_device_picked=*/nullptr);
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct XlaGlobalJitLevel {
|
||||
OptimizerOptions::GlobalJitLevel single_gpu;
|
||||
@ -425,4 +314,8 @@ bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def) {
|
||||
return name_attr_pair.second.has_func();
|
||||
});
|
||||
}
|
||||
bool IsShapeConsumerOp(const Node& node) {
|
||||
return node.type_string() == "Shape" || node.type_string() == "Rank" ||
|
||||
node.type_string() == "Size";
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
@ -46,9 +46,6 @@ extern const char* const kXlaCompileTimeConstantInputsAttr;
|
||||
|
||||
using OrderedNodeSet = std::set<Node*, NodeComparatorID>;
|
||||
|
||||
// Returns the DeviceType corresponding to 'device'.
|
||||
Status DeviceToDeviceType(const string& device, DeviceType* device_type);
|
||||
|
||||
// Returns true if `node` has a ref tensor input that it forwards to its output.
|
||||
bool HasForwardedRefInput(const Node& node);
|
||||
|
||||
@ -74,51 +71,6 @@ void RemoveFromXlaCluster(Node* node);
|
||||
// Returns true if `node` has a DT_RESOURCE typed input or output.
|
||||
bool HasResourceInputOrOutput(const Node& node);
|
||||
|
||||
// Picks the device for which XLA should compile a cluster that contains
|
||||
// operations placed in devices in `device_names`. For instance a cluster that
|
||||
// contains operations solely placed on the CPU will be compiled into a CPU
|
||||
// executable by XLA, whereas a cluster that contains operations placed on the
|
||||
// CPU and also operations placed on the GPU will be compiled into a GPU
|
||||
// executable.
|
||||
//
|
||||
// Returns a non-OK Status if no unambiguous choice of device exists.
|
||||
//
|
||||
// We choose the device using the following rules:
|
||||
//
|
||||
// - It is an error for `device_names` to contain more than one device of the
|
||||
// same type.
|
||||
// - GPU is preferred over CPU.
|
||||
// - If `allow_mixing_unknown_and_cpu` is true then unknown devices are
|
||||
// preferred over CPU.
|
||||
// - XLA devices count as "unrecognized devices".
|
||||
//
|
||||
// This set of rules above implicitly assume that XLA:GPU can compile all
|
||||
// operations in the cluster that XLA:CPU can compile, and if
|
||||
// `allow_mixing_unknown_and_cpu` then the unrecognized device can also compile
|
||||
// all operations in the cluster that XLA:CPU can compile.
|
||||
//
|
||||
// We provide the `allow_mixing_unknown_and_cpu` knob so that we can do both of
|
||||
// the following things:
|
||||
//
|
||||
// - Let MarkForCompilationPass not inject CPU-placed operations into clusters
|
||||
// that will run on unknown devices (because the unknown XLA backend may not
|
||||
// support every operation supported by CPU).
|
||||
// - Let BuildXlaOpsPass successfully infer a compilation device for a cluster
|
||||
// that contains nodes placed on both the CPU and on unknown devices. In this
|
||||
// case it is the responsibility of the optimization pass that injected the
|
||||
// CPU nodes into the cluster to ensure that these nodes can be compiled by
|
||||
// the unknown XLA backend.
|
||||
Status PickDeviceForXla(absl::Span<const string> device_names,
|
||||
bool allow_mixing_unknown_and_cpu,
|
||||
string* out_device_picked);
|
||||
|
||||
// This is like `PickDeviceForXla` except that it returns false (instead of a
|
||||
// non-OK Status) in `out_can_pick_device` if no unambiguous choice of device
|
||||
// exists.
|
||||
Status CanPickDeviceForXla(absl::Span<const string> device_names,
|
||||
bool allow_mixing_unknown_and_cpu,
|
||||
bool* out_can_pick_device);
|
||||
|
||||
// Determines the global jit level based on GraphOptimizationPassOptions,
|
||||
// --tf_xla_auto_jit and whether the graph is a single GPU graph.
|
||||
OptimizerOptions::GlobalJitLevel GetGlobalJitLevelForGraph(
|
||||
@ -131,6 +83,10 @@ bool IsSingleGpuGraph(const Graph& g);
|
||||
// Returns true if it is possible (but not guaranteed) that `n` calls a
|
||||
// function.
|
||||
bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def);
|
||||
|
||||
// Returns true if `node` an operator that consumes only the shape of its input,
|
||||
// not the data itself.
|
||||
bool IsShapeConsumerOp(const Node& node);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
|
||||
|
@ -91,67 +91,9 @@ TEST(CreateCycleDetectionGraph, ReachingEnterExit) {
|
||||
EXPECT_FALSE(ok);
|
||||
}
|
||||
|
||||
void CheckPickDeviceResult(absl::string_view expected_result,
|
||||
bool allow_mixing_unknown_and_cpu,
|
||||
absl::Span<const absl::string_view> inputs) {
|
||||
std::vector<string> inputs_string;
|
||||
absl::c_transform(inputs, std::back_inserter(inputs_string),
|
||||
[](absl::string_view sv) { return string(sv); });
|
||||
string result;
|
||||
TF_ASSERT_OK(
|
||||
PickDeviceForXla(inputs_string, allow_mixing_unknown_and_cpu, &result))
|
||||
<< "inputs = [" << absl::StrJoin(inputs, ", ")
|
||||
<< "], allow_mixing_unknown_and_cpu=" << allow_mixing_unknown_and_cpu
|
||||
<< ", expected_result=" << expected_result;
|
||||
EXPECT_EQ(result, expected_result);
|
||||
}
|
||||
|
||||
void CheckPickDeviceHasError(bool allow_mixing_unknown_and_cpu,
|
||||
absl::Span<const absl::string_view> inputs) {
|
||||
std::vector<string> inputs_string;
|
||||
absl::c_transform(inputs, std::back_inserter(inputs_string),
|
||||
[](absl::string_view sv) { return string(sv); });
|
||||
string result;
|
||||
EXPECT_FALSE(
|
||||
PickDeviceForXla(inputs_string, allow_mixing_unknown_and_cpu, &result)
|
||||
.ok());
|
||||
}
|
||||
|
||||
const char* kCPU0 = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char* kGPU0 = "/job:localhost/replica:0/task:0/device:GPU:0";
|
||||
const char* kXPU0 = "/job:localhost/replica:0/task:0/device:XPU:0";
|
||||
|
||||
const char* kCPU1 = "/job:localhost/replica:0/task:0/device:CPU:1";
|
||||
const char* kGPU1 = "/job:localhost/replica:0/task:0/device:GPU:1";
|
||||
const char* kXPU1 = "/job:localhost/replica:0/task:0/device:XPU:1";
|
||||
|
||||
TEST(PickDeviceForXla, UniqueDevice) {
|
||||
CheckPickDeviceResult(kGPU0, false, {kGPU0, kGPU0});
|
||||
}
|
||||
|
||||
TEST(PickDeviceForXla, DeviceOrder) {
|
||||
CheckPickDeviceResult(kGPU0, false, {kGPU0, kCPU0});
|
||||
CheckPickDeviceResult(kXPU0, true, {kXPU0, kCPU0});
|
||||
}
|
||||
|
||||
TEST(PickDeviceForXla, MultipleUnknownDevices) {
|
||||
CheckPickDeviceHasError(false, {kXPU0, kXPU1});
|
||||
}
|
||||
|
||||
TEST(PickDeviceForXla, GpuAndUnknown) {
|
||||
CheckPickDeviceHasError(false, {kGPU0, kXPU1});
|
||||
}
|
||||
|
||||
TEST(PickDeviceForXla, UnknownAndCpu) {
|
||||
CheckPickDeviceHasError(false, {kXPU0, kCPU1});
|
||||
}
|
||||
|
||||
TEST(PickDeviceForXla, MultipleDevicesOfSameType) {
|
||||
CheckPickDeviceHasError(false, {kCPU0, kCPU1});
|
||||
CheckPickDeviceHasError(false, {kGPU0, kGPU1});
|
||||
CheckPickDeviceHasError(false, {kXPU0, kXPU1});
|
||||
CheckPickDeviceHasError(false, {kCPU0, kCPU1, kGPU0});
|
||||
}
|
||||
|
||||
TEST(IsSingleGpuGraph, ReturnsTrue) {
|
||||
Scope root = Scope::NewRootScope().WithAssignedDevice(kGPU0).ExitOnError();
|
||||
|
@ -60,6 +60,7 @@ Status XlaCpuDeviceFactory::CreateDevices(
|
||||
registration.cluster_control_trigger = true;
|
||||
registration.elide_assert_and_checknumerics = true;
|
||||
registration.cluster_variant_ops = true;
|
||||
registration.cluster_slow_and_inaccurate_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration);
|
||||
|
||||
static XlaDeviceOpRegistrations* registrations =
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/xla_device.h"
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
@ -47,6 +48,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/platform/tracing.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
@ -380,14 +382,17 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
|
||||
AsyncOpKernel::DoneCallback done) {
|
||||
VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
|
||||
<< op_kernel->type_string();
|
||||
tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(),
|
||||
op_kernel->IsExpensive());
|
||||
profiler::TraceMe activity(
|
||||
[&] {
|
||||
return absl::StrCat(op_kernel->name(), ":", op_kernel->type_string());
|
||||
},
|
||||
profiler::GetTFTraceMeLevel(op_kernel->IsExpensive()));
|
||||
op_kernel->ComputeAsync(context, done);
|
||||
}
|
||||
|
||||
Status XlaDevice::Sync() {
|
||||
VLOG(1) << "XlaDevice::Sync";
|
||||
tracing::ScopedActivity activity("XlaDevice::Sync", /*is_expensive=*/true);
|
||||
profiler::TraceMe activity("XlaDevice::Sync", profiler::TraceMeLevel::kInfo);
|
||||
std::shared_ptr<se::Stream> stream;
|
||||
{
|
||||
mutex_lock lock(mu_);
|
||||
@ -428,13 +433,12 @@ void XlaDevice::Sync(const DoneCallback& done) {
|
||||
// that everything enqueued onto the stream (i.e., the device) at this very
|
||||
// moment--when ThenEnqueueOnBackgroundThread is called--will have finished.
|
||||
// This achieves a device-wide sync.
|
||||
stream->ThenEnqueueOnBackgroundThread(
|
||||
[stream, done](se::StreamExecutor*) {
|
||||
tracing::ScopedActivity activity("XlaDevice::Sync::Callback",
|
||||
/*is_expensive=*/true);
|
||||
done(stream->ok() ? Status::OK()
|
||||
: errors::Internal("XlaDevice::Sync() failed."));
|
||||
});
|
||||
stream->ThenEnqueueOnBackgroundThread([stream, done](se::StreamExecutor*) {
|
||||
profiler::TraceMe activity("XlaDevice::Sync::Callback",
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
done(stream->ok() ? Status::OK()
|
||||
: errors::Internal("XlaDevice::Sync() failed."));
|
||||
});
|
||||
}
|
||||
|
||||
Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
|
||||
@ -458,11 +462,13 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
|
||||
Allocator* allocator = GetAllocatorLocked(alloc_attrs);
|
||||
Tensor copy(allocator, parsed.dtype(), parsed.shape());
|
||||
Notification n;
|
||||
device_context->CopyCPUTensorToDevice(&parsed, this, ©,
|
||||
[&n, &status](const Status& s) {
|
||||
status = s;
|
||||
n.Notify();
|
||||
});
|
||||
device_context->CopyCPUTensorToDevice(
|
||||
&parsed, this, ©,
|
||||
[&n, &status](const Status& s) {
|
||||
status = s;
|
||||
n.Notify();
|
||||
},
|
||||
true /*sync_dst_compute*/);
|
||||
n.WaitForNotification();
|
||||
*tensor = copy;
|
||||
}
|
||||
|
@ -65,6 +65,9 @@ absl::optional<AllocatorStats> XlaDeviceAllocator::GetStats() {
|
||||
tf_stats.peak_bytes_in_use = se_stats->peak_bytes_in_use;
|
||||
tf_stats.largest_alloc_size = se_stats->largest_alloc_size;
|
||||
tf_stats.bytes_limit = se_stats->bytes_limit;
|
||||
tf_stats.bytes_reserved = se_stats->bytes_reserved;
|
||||
tf_stats.peak_bytes_reserved = se_stats->peak_bytes_reserved;
|
||||
tf_stats.bytes_reservable_limit = se_stats->bytes_reservable_limit;
|
||||
return tf_stats;
|
||||
}
|
||||
|
||||
@ -106,7 +109,8 @@ void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor,
|
||||
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
||||
Device* device,
|
||||
Tensor* device_tensor,
|
||||
StatusCallback done) const {
|
||||
StatusCallback done,
|
||||
bool sync_dst_compute) const {
|
||||
if (cpu_tensor->NumElements() == 0) {
|
||||
VLOG(2) << "CopyCPUTensorToDevice empty tensor";
|
||||
done(Status::OK());
|
||||
@ -242,16 +246,25 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||
cpu_tensor, &literal));
|
||||
|
||||
TensorReference ref(*device_tensor);
|
||||
const bool device_allows_sync_on_completion =
|
||||
device->AllowsSyncOnCompletion();
|
||||
// Explicitly capture device_to_host_stream to make sure the stream is alive
|
||||
// before the transfer finishes.
|
||||
transfer_manager_->TransferLiteralFromDevice(
|
||||
device_to_host_stream.get(), xla_tensor->shaped_buffer(), literal,
|
||||
[ref, xla_tensor, done, device_to_host_stream](xla::Status status) {
|
||||
done([&]() -> Status {
|
||||
VLOG(2) << "Transfer from device as literal: "
|
||||
<< xla_tensor->shaped_buffer().ToString();
|
||||
return status;
|
||||
}());
|
||||
[ref, xla_tensor, done, device_to_host_stream,
|
||||
device_allows_sync_on_completion](xla::Status status) {
|
||||
Status done_status = status;
|
||||
VLOG(2) << "Transfer from device as literal: "
|
||||
<< xla_tensor->shaped_buffer().ToString();
|
||||
// For devices don't allow sync on completion, the device execution is
|
||||
// deferred. We check the execution stream status here to avoid wrong
|
||||
// results from a failed stream being propogated to following
|
||||
// host-side ops.
|
||||
if (!device_allows_sync_on_completion) {
|
||||
done_status.Update(xla_tensor->RefreshStatusOfStreams());
|
||||
}
|
||||
done(done_status);
|
||||
ref.Unref();
|
||||
});
|
||||
}
|
||||
|
@ -61,8 +61,8 @@ class XlaDeviceContext : public DeviceContext {
|
||||
thread::ThreadPool* thread_pool);
|
||||
|
||||
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
|
||||
Tensor* device_tensor,
|
||||
StatusCallback done) const override;
|
||||
Tensor* device_tensor, StatusCallback done,
|
||||
bool sync_dst_compute) const override;
|
||||
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||
absl::string_view tensor_name, Device* device,
|
||||
Tensor* cpu_tensor, StatusCallback done) override;
|
||||
|
@ -95,6 +95,7 @@ Status XlaGpuDeviceFactory::CreateDevices(
|
||||
registration.cluster_control_trigger = true;
|
||||
registration.elide_assert_and_checknumerics = true;
|
||||
registration.cluster_variant_ops = true;
|
||||
registration.cluster_slow_and_inaccurate_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);
|
||||
|
||||
static XlaDeviceOpRegistrations* registrations =
|
||||
|
@ -63,6 +63,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
|
||||
registration.cluster_control_trigger = true;
|
||||
registration.elide_assert_and_checknumerics = true;
|
||||
registration.cluster_variant_ops = true;
|
||||
registration.cluster_slow_and_inaccurate_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER,
|
||||
registration);
|
||||
|
||||
|
@ -34,6 +34,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/util/stream_executor_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -132,7 +133,8 @@ Status LockVariables(absl::Span<VariableInfo> variables) {
|
||||
// cluster because we would not handle variable updates correctly. Any
|
||||
// locks we have already acquired will be released when the VariableInfo
|
||||
// objects are destroyed.
|
||||
return errors::Internal("Duplicate variable passed to XLA cluster");
|
||||
// TODO(b/128495870) Add support for passing aliased resource variables.
|
||||
return errors::Unimplemented("Duplicate variable passed to XLA cluster");
|
||||
}
|
||||
VLOG(4) << "Acquiring lock for variable "
|
||||
<< reinterpret_cast<void*>(variable);
|
||||
@ -166,11 +168,11 @@ Status SnapshotResourceVariables(OpKernelContext* ctx,
|
||||
}
|
||||
|
||||
XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped)
|
||||
: xla::DeviceMemoryAllocator(platform), wrapped_(wrapped) {}
|
||||
: se::DeviceMemoryAllocator(platform), wrapped_(wrapped) {}
|
||||
|
||||
XlaAllocator::~XlaAllocator() {}
|
||||
|
||||
xla::StatusOr<xla::OwningDeviceMemory> XlaAllocator::Allocate(
|
||||
xla::StatusOr<se::OwningDeviceMemory> XlaAllocator::Allocate(
|
||||
int device_ordinal, uint64 size, bool retry_on_failure) {
|
||||
AllocationAttributes attrs;
|
||||
attrs.no_retry_on_failure = !retry_on_failure;
|
||||
@ -182,8 +184,8 @@ xla::StatusOr<xla::OwningDeviceMemory> XlaAllocator::Allocate(
|
||||
"Out of memory while trying to allocate ", size, " bytes.");
|
||||
}
|
||||
}
|
||||
return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
|
||||
device_ordinal, this);
|
||||
return se::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
|
||||
device_ordinal, this);
|
||||
}
|
||||
|
||||
Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
|
||||
@ -192,7 +194,7 @@ Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
|
||||
}
|
||||
|
||||
XlaComputationLaunchContext::XlaComputationLaunchContext(
|
||||
xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
|
||||
xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
|
||||
bool allocate_xla_tensors, bool use_multiple_streams)
|
||||
: client_(client),
|
||||
xla_allocator_(xla_allocator),
|
||||
@ -242,7 +244,8 @@ void XlaComputationLaunchContext::PopulateInputs(
|
||||
CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
|
||||
arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
|
||||
} else {
|
||||
CHECK(xla::ShapeUtil::Equal(shape, on_device_shape))
|
||||
CHECK(xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape,
|
||||
on_device_shape))
|
||||
<< "On-device shape "
|
||||
<< xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
|
||||
<< " not the same as on-host shape "
|
||||
@ -371,7 +374,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
} else {
|
||||
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
|
||||
ctx->expected_output_dtype(i), shape, buffer, allocator);
|
||||
output.set_buffer(xla::OwningDeviceMemory(), {output_num});
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
ctx->set_output(i, output_tensor);
|
||||
}
|
||||
++output_num;
|
||||
@ -432,7 +435,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
*variable_infos[i].var()->tensor() = output_tensor;
|
||||
} else {
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
output.set_buffer(xla::OwningDeviceMemory(), {output_num});
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
|
||||
write.type, write.shape, buffer, allocator);
|
||||
*variable_infos[i].var()->tensor() = output_tensor;
|
||||
|
@ -23,14 +23,13 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/xla_tensor.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
|
||||
#include "tensorflow/compiler/xla/service/owning_device_memory.h"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/resource_var.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/stream_executor/device_memory_allocator.h"
|
||||
|
||||
namespace tensorflow {
|
||||
class XlaAllocator;
|
||||
@ -108,11 +107,11 @@ Status LockVariables(absl::Span<VariableInfo> variables)
|
||||
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
|
||||
// Assumes that the Tensorflow allocator permits asynchronous deallocation:
|
||||
// see comment on `AllowsAsynchronousDeallocation()`.
|
||||
class XlaAllocator : public xla::DeviceMemoryAllocator {
|
||||
class XlaAllocator : public se::DeviceMemoryAllocator {
|
||||
public:
|
||||
XlaAllocator(const se::Platform* platform, Allocator* wrapped);
|
||||
~XlaAllocator() override;
|
||||
xla::StatusOr<xla::OwningDeviceMemory> Allocate(
|
||||
xla::StatusOr<se::OwningDeviceMemory> Allocate(
|
||||
int device_ordinal, uint64 size, bool retry_on_failure) override;
|
||||
Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override;
|
||||
|
||||
@ -129,6 +128,50 @@ class XlaAllocator : public xla::DeviceMemoryAllocator {
|
||||
Allocator* wrapped_;
|
||||
};
|
||||
|
||||
// Adapter class that wraps per-device TF allocators as an XLA allocator.
|
||||
// Assumes that the Tensorflow allocator permits asynchronous deallocation;
|
||||
// see comment on `AllowsAsynchronousDeallocation()`.
|
||||
class MultiDeviceAdapter : public se::DeviceMemoryAllocator {
|
||||
public:
|
||||
MultiDeviceAdapter(
|
||||
const se::Platform* platform,
|
||||
std::vector<std::unique_ptr<tensorflow::Allocator>> tf_allocators)
|
||||
: DeviceMemoryAllocator(platform),
|
||||
tf_allocators_(std::move(tf_allocators)) {
|
||||
for (const auto& tf_allocator : tf_allocators_) {
|
||||
per_device_allocators_.emplace_back(platform, tf_allocator.get());
|
||||
}
|
||||
}
|
||||
|
||||
xla::StatusOr<se::OwningDeviceMemory> Allocate(
|
||||
int device_ordinal, uint64 size, bool retry_on_failure) override {
|
||||
CHECK_LT(device_ordinal, per_device_allocators_.size());
|
||||
return per_device_allocators_[device_ordinal].Allocate(device_ordinal, size,
|
||||
retry_on_failure);
|
||||
}
|
||||
|
||||
Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override {
|
||||
CHECK_LT(device_ordinal, per_device_allocators_.size());
|
||||
return per_device_allocators_[device_ordinal].Deallocate(device_ordinal,
|
||||
mem);
|
||||
}
|
||||
|
||||
// The Tensorflow BFC allocator used on GPU allows host-side deallocation
|
||||
// before GPU execution takes place. Tensorflow uses the ordering of the main
|
||||
// compute stream to enforce a happens-before relationship between a memory
|
||||
// allocation and code that reuses the same memory. If Tensorflow adds
|
||||
// support for multiple GPU streams or allocators with different ordering
|
||||
// requirements, this code may need to change.
|
||||
// (This attribute has no effect on CPU.)
|
||||
bool AllowsAsynchronousDeallocation() const override { return true; }
|
||||
|
||||
private:
|
||||
std::vector<tensorflow::XlaAllocator> per_device_allocators_;
|
||||
// The wrapped TF allocators backing per_device_allocators_ (XlaAllocator does
|
||||
// not take ownership of its underlying Allocator).
|
||||
std::vector<std::unique_ptr<tensorflow::Allocator>> tf_allocators_;
|
||||
};
|
||||
|
||||
// Helper class to perform the marshalling of TensorFlow inputs and outputs to
|
||||
// ShapedBuffers suitable for passing to an XLA computation.
|
||||
class XlaComputationLaunchContext {
|
||||
@ -142,7 +185,7 @@ class XlaComputationLaunchContext {
|
||||
// because we track inter-stream dependencies through events inside XlaTensor
|
||||
// objects.
|
||||
XlaComputationLaunchContext(xla::LocalClient* client,
|
||||
xla::DeviceMemoryAllocator* xla_allocator,
|
||||
se::DeviceMemoryAllocator* xla_allocator,
|
||||
bool allocate_xla_tensors,
|
||||
bool use_multiple_streams);
|
||||
|
||||
@ -186,7 +229,7 @@ class XlaComputationLaunchContext {
|
||||
|
||||
private:
|
||||
xla::LocalClient* client_;
|
||||
xla::DeviceMemoryAllocator* xla_allocator_;
|
||||
se::DeviceMemoryAllocator* xla_allocator_;
|
||||
bool allocate_xla_tensors_;
|
||||
bool use_multiple_streams_;
|
||||
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_;
|
||||
|
@ -59,11 +59,11 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype,
|
||||
xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
|
||||
uint64 size =
|
||||
client->backend().transfer_manager()->GetByteSizeRequirement(subshape);
|
||||
TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer,
|
||||
TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory buffer,
|
||||
client->backend().memory_allocator()->Allocate(
|
||||
device_ordinal, size, /*retry_on_failure=*/false));
|
||||
// Move our buffer into shaped_buffer, which takes ownership of it.
|
||||
index_to_buffer.second = buffer.Forget();
|
||||
index_to_buffer.second = buffer.Release();
|
||||
}
|
||||
|
||||
VLOG(4) << shaped_buffer.ToString();
|
||||
@ -97,6 +97,15 @@ void XlaTensor::ResetDefinitionEvent(std::shared_ptr<se::Event> event,
|
||||
streams_defined_on_ = {stream};
|
||||
}
|
||||
|
||||
Status XlaTensor::RefreshStatusOfStreams() {
|
||||
mutex_lock lock(mu_);
|
||||
Status status;
|
||||
for (se::Stream* stream : streams_defined_on_) {
|
||||
status.Update(stream->RefreshStatus());
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
// The pointer tag, OR-ed into the XlaTensor's address to distinguish it from
|
||||
// device-side tensors, which are either CPU or GPU memory pointers. This works
|
||||
// because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits.
|
||||
|
@ -102,6 +102,10 @@ class XlaTensor {
|
||||
void ResetDefinitionEvent(std::shared_ptr<se::Event> event,
|
||||
se::Stream* stream);
|
||||
|
||||
// Refresh the status of streams_defined_on_. Return the first not-OK stream's
|
||||
// status or OK.
|
||||
Status RefreshStatusOfStreams();
|
||||
|
||||
// Convert from a raw pointer to an XlaTensor, removing the pointer tag.
|
||||
static XlaTensor* FromOpaquePointer(void* ptr);
|
||||
// Convert to a raw pointer from an XlaTensor, adding the pointer tag.
|
||||
|
@ -65,6 +65,7 @@ py_test(
|
||||
name = "xla_test_test",
|
||||
size = "small",
|
||||
srcs = ["xla_test_test.py"],
|
||||
python_version = "PY2",
|
||||
deps = [
|
||||
":xla_test",
|
||||
],
|
||||
@ -458,10 +459,6 @@ tf_xla_py_test(
|
||||
name = "extract_image_patches_op_test",
|
||||
size = "small",
|
||||
srcs = ["extract_image_patches_op_test.py"],
|
||||
tags = [
|
||||
"manual",
|
||||
"notap",
|
||||
],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
|
@ -41,7 +41,7 @@ class AdadeltaOptimizerTest(xla_test.XLATestCase):
|
||||
all_lr = [1.0, 0.5, 0.1]
|
||||
|
||||
for dtype in self.float_types:
|
||||
with self.cached_session(), self.test_scope():
|
||||
with self.session(), self.test_scope():
|
||||
for grad in all_grad:
|
||||
for lr in all_lr:
|
||||
var0_init = [1.0, 2.0]
|
||||
|
@ -33,7 +33,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
|
||||
|
||||
def testAdagradDAWithoutRegularizationBasic1(self):
|
||||
for dtype in self.float_types:
|
||||
with self.cached_session(), self.test_scope():
|
||||
with self.session(), self.test_scope():
|
||||
global_step = resource_variable_ops.ResourceVariable(
|
||||
0, dtype=dtypes.int64)
|
||||
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
|
||||
@ -69,7 +69,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
|
||||
|
||||
def testAdagradDAwithoutRegularizationBasic2(self):
|
||||
for dtype in self.float_types:
|
||||
with self.cached_session(), self.test_scope():
|
||||
with self.session(), self.test_scope():
|
||||
global_step = resource_variable_ops.ResourceVariable(
|
||||
0, dtype=dtypes.int64)
|
||||
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
|
||||
@ -100,7 +100,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
|
||||
|
||||
def testAdagradDAWithL1(self):
|
||||
for dtype in self.float_types:
|
||||
with self.cached_session(), self.test_scope():
|
||||
with self.session(), self.test_scope():
|
||||
global_step = resource_variable_ops.ResourceVariable(
|
||||
0, dtype=dtypes.int64)
|
||||
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
|
||||
@ -131,7 +131,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
|
||||
|
||||
def testAdagradDAWithL1_L2(self):
|
||||
for dtype in self.float_types:
|
||||
with self.cached_session(), self.test_scope():
|
||||
with self.session(), self.test_scope():
|
||||
global_step = resource_variable_ops.ResourceVariable(
|
||||
0, dtype=dtypes.int64)
|
||||
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
|
||||
|
@ -32,7 +32,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
|
||||
|
||||
def testBasic(self):
|
||||
for dtype in self.float_types:
|
||||
with self.cached_session(), self.test_scope():
|
||||
with self.session(), self.test_scope():
|
||||
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
|
||||
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
|
||||
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
|
||||
@ -59,7 +59,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
|
||||
|
||||
def testTensorLearningRate(self):
|
||||
for dtype in self.float_types:
|
||||
with self.cached_session(), self.test_scope():
|
||||
with self.session(), self.test_scope():
|
||||
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
|
||||
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
|
||||
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
|
||||
@ -87,7 +87,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
|
||||
|
||||
def testSharing(self):
|
||||
for dtype in self.float_types:
|
||||
with self.cached_session(), self.test_scope():
|
||||
with self.session(), self.test_scope():
|
||||
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
|
||||
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
|
||||
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
|
||||
|
@ -56,7 +56,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
|
||||
# TODO: test fails for float16 due to excessive precision requirements.
|
||||
if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
|
||||
continue
|
||||
with self.cached_session(), self.test_scope():
|
||||
with self.session(), self.test_scope():
|
||||
variable_scope.get_variable_scope().set_use_resource(True)
|
||||
|
||||
# Initialize variables for numpy implementation.
|
||||
@ -99,7 +99,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
|
||||
# TODO: test fails for float16 due to excessive precision requirements.
|
||||
if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
|
||||
continue
|
||||
with self.cached_session(), self.test_scope():
|
||||
with self.session(), self.test_scope():
|
||||
variable_scope.get_variable_scope().set_use_resource(True)
|
||||
|
||||
# Initialize variables for numpy implementation.
|
||||
@ -142,7 +142,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
|
||||
# TODO: test fails for float16 due to excessive precision requirements.
|
||||
if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
|
||||
continue
|
||||
with self.cached_session(), self.test_scope():
|
||||
with self.session(), self.test_scope():
|
||||
variable_scope.get_variable_scope().set_use_resource(True)
|
||||
|
||||
# Initialize variables for numpy implementation.
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user