Merge branch 'master' into micro_port_conv

This commit is contained in:
Jens Elofsson 2019-04-24 10:22:03 +02:00 committed by GitHub
commit 4aee543969
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2343 changed files with 154106 additions and 73192 deletions

View File

@ -18,10 +18,11 @@ about: Use this template for reporting a bug or a performance issue.
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture [script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with
python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**

View File

@ -25,7 +25,7 @@ networks research. The system is general enough to be applicable in a wide
variety of other domains, as well.
TensorFlow provides stable Python and C APIs as well as non-guaranteed backwards
compatible API's for C++, Go, Java, JavaScript and Swift.
compatible API's for C++, Go, Java, JavaScript, and Swift.
Keep up to date with release announcements and security updates by
subscribing to
@ -85,7 +85,7 @@ guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's
uphold this code.**
**We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for
tracking requests and bugs, so please see
tracking requests and bugs, please see
[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss)
for general questions and discussion, and please direct specific questions to
[Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).**
@ -115,14 +115,15 @@ The TensorFlow project strives to abide by generally accepted best practices in
### Community Supported Builds
Build Type | Status | Artifacts
-------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
--------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA
**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/)
**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/)
**Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
**Linux CPU with Intel® MKL-DNN** <br> **Supports Python 2.7, 3.4, 3.5 and 3.6** | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.13.1 pypi](https://pypi.org/project/intel-tensorflow/)
**Linux CPU with Intel® MKL-DNN** <br> **Supports Python 2.7, 3.4, 3.5, and 3.6** | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.13.1 pypi](https://pypi.org/project/intel-tensorflow/)
**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br> Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 pypi](https://tensorflow.pypi.thoth-station.ninja/index/)
## For more information

View File

@ -1,3 +1,10 @@
# Release 1.12.2
## Bug Fixes and Other Changes
* Fixes a potential security vulnerability where carefully crafted GIF images
can produce a null pointer dereference during decoding.
# Release 1.13.0
## Major Features and Improvements
@ -15,44 +22,86 @@
## Bug Fixes and Other Changes
* Documentation
* Update the doc with the details about the rounding mode used in quantize_and_dequantize_v2.
* Clarify that tensorflow::port::InitMain() _should_ be called before using the TensorFlow library. Programs failing to do this are not portable to all platforms.
* Update the doc with the details about the rounding mode used in
quantize_and_dequantize_v2.
* Clarify that tensorflow::port::InitMain() _should_ be called before
using the TensorFlow library. Programs failing to do this are not
portable to all platforms.
* Deprecations and Symbol renames.
* Removing deprecations for the following endpoints: `tf.acos`, `tf.acosh`, `tf.add`, `tf.as_string`, `tf.asin`, `tf.asinh`, `tf.atan`, `tf.atan2`, `tf.atanh`, `tf.cos`, `tf.cosh`, `tf.equal`, `tf.exp`, `tf.floor`, `tf.greater`, `tf.greater_equal`, `tf.less`, `tf.less_equal`, `tf.log`, `tf.logp1`, `tf.logical_and`, `tf.logical_not`, `tf.logical_or`, `tf.maximum`, `tf.minimum`, `tf.not_equal`, `tf.sin`, `tf.sinh`, `tf.tan`
* Removing deprecations for the following endpoints: `tf.acos`,
`tf.acosh`, `tf.add`, `tf.as_string`, `tf.asin`, `tf.asinh`, `tf.atan`,
`tf.atan2`, `tf.atanh`, `tf.cos`, `tf.cosh`, `tf.equal`, `tf.exp`,
`tf.floor`, `tf.greater`, `tf.greater_equal`, `tf.less`,
`tf.less_equal`, `tf.log`, `tf.logp1`, `tf.logical_and`,
`tf.logical_not`, `tf.logical_or`, `tf.maximum`, `tf.minimum`,
`tf.not_equal`, `tf.sin`, `tf.sinh`, `tf.tan`
* Deprecate `tf.data.Dataset.shard`.
* Deprecate `saved_model.loader.load` which is replaced by `saved_model.load` and `saved_model.main_op`, which will be replaced by `saved_model.main_op` in V2.
* Deprecate tf.QUANTIZED_DTYPES. The official new symbol is tf.dtypes.QUANTIZED_DTYPES.
* Deprecate `saved_model.loader.load` which is replaced by
`saved_model.load` and `saved_model.main_op`, which will be replaced by
`saved_model.main_op` in V2.
* Deprecate tf.QUANTIZED_DTYPES. The official new symbol is
tf.dtypes.QUANTIZED_DTYPES.
* Update sklearn imports for deprecated packages.
* Deprecate `Variable.count_up_to` and `tf.count_up_to` in favor of `Dataset.range`.
* Export `confusion_matrix` op as `tf.math.confusion_matrix` instead of `tf.train.confusion_matrix`.
* Add `tf.dtypes.` endpoint for every constant in dtypes.py; moving endpoints in versions.py to corresponding endpoints in `tf.sysconfig.` and `tf.version.`; moving all constants under `tf.saved_model` submodules to `tf.saved_model` module. New endpoints are added in V1 and V2 but existing endpoint removals are only applied in V2.
* Deprecates behavior where device assignment overrides collocation constraints inside a collocation context manager.
* Deprecate `Variable.count_up_to` and `tf.count_up_to` in favor of
`Dataset.range`.
* Export `confusion_matrix` op as `tf.math.confusion_matrix` instead of
`tf.train.confusion_matrix`.
* Add `tf.dtypes.` endpoint for every constant in dtypes.py. Moving
endpoints in versions.py to corresponding endpoints in `tf.sysconfig.`
and `tf.version.`. Moving all constants under `tf.saved_model`
submodules to `tf.saved_model` module. New endpoints are added in V1 and
V2 but existing endpoint removals are only applied in V2.
* Deprecates behavior where device assignment overrides collocation
constraints inside a collocation context manager.
* Keras & Python API
* Add to Keras functionality analogous to `tf.register_tensor_conversion_function`.
* Subclassed Keras models can now be saved through `tf.contrib.saved_model.save_keras_model`.
* Add to Keras functionality analogous to
`tf.register_tensor_conversion_function`.
* Subclassed Keras models can now be saved through
`tf.contrib.saved_model.save_keras_model`.
* `LinearOperator.matmul` now returns a new `LinearOperator`.
* New ops and improved op functionality
* Add a Nearest Neighbor Resize op.
* Add an `ignore_unknown` argument to `parse_values` which suppresses ValueError for unknown hyperparameter types. Such * Add `tf.linalg.matvec` convenience function.
* `tf.einsum()`raises `ValueError` for unsupported equations like `"ii->"`.
* Add an `ignore_unknown` argument to `parse_values` which suppresses
ValueError for unknown hyperparameter types. Such * Add
`tf.linalg.matvec` convenience function.
* `tf.einsum()`raises `ValueError` for unsupported equations like
`"ii->"`.
* Add DCT-I and IDCT-I in `tf.signal.dct` and `tf.signal.idct`.
* Add LU decomposition op.
* Add quantile loss to gradient boosted trees in estimator.
* Add `round_mode` to `QuantizeAndDequantizeV2` op to select rounding algorithm.
* Add `unicode_encode`, `unicode_decode`, `unicode_decode_with_offsets`, `unicode_split`, `unicode_split_with_offset`, and `unicode_transcode` ops. Amongst other things, this Op adds the ability to encode, decode, and transcode a variety of input text encoding formats into the main Unicode encodings (UTF-8, UTF-16-BE, UTF-32-BE)
* Add "unit" attribute to the substr op, which allows obtaining the substring of a string containing unicode characters.
* Add `round_mode` to `QuantizeAndDequantizeV2` op to select rounding
algorithm.
* Add `unicode_encode`, `unicode_decode`, `unicode_decode_with_offsets`,
`unicode_split`, `unicode_split_with_offset`, and `unicode_transcode`
ops. Amongst other things, this Op adds the ability to encode, decode,
and transcode a variety of input text encoding formats into the main
Unicode encodings (UTF-8, UTF-16-BE, UTF-32-BE)
* Add "unit" attribute to the substr op, which allows obtaining the
substring of a string containing unicode characters.
* Broadcasting support for Ragged Tensors.
* `SpaceToDepth` supports uint8 data type.
* Support multi-label quantile regression in estimator.
* We now use "div" as the default partition_strategy in `tf.nn.safe_embedding_lookup_sparse`, `tf.nn.sampled_softmax` and `tf.nn.nce_loss`.
hyperparameter are ignored.
* We now use "div" as the default partition_strategy in
`tf.nn.safe_embedding_lookup_sparse`, `tf.nn.sampled_softmax` and
`tf.nn.nce_loss`. hyperparameter are ignored.
* Performance
* Improve performance of GPU cumsum/cumprod by up to 300x.
* Added support for weight decay in most TPU embedding optimizers, including AdamW and MomentumW.
* Added support for weight decay in most TPU embedding optimizers,
including AdamW and MomentumW.
* TensorFlow 2.0 Development
* Add a command line tool to convert to TF2.0, tf_upgrade_v2
* Merge `tf.spectral` into `tf.signal` for TensorFlow 2.0.
* Change the default recurrent activation function for LSTM from 'hard_sigmoid' to 'sigmoid' in 2.0. Historically recurrent activation is 'hard_sigmoid' since it is fast than 'sigmoid'. With new unified backend between CPU and GPU mode, since the CuDNN kernel is using sigmoid, we change the default for CPU mode to sigmoid as well. With that, the default LSTM will be compatible with both CPU and GPU kernel. This will enable user with GPU to use CuDNN kernel by default and get a 10x performance boost in training. Note that this is checkpoint breaking change. If user want to use their 1.x pre-trained checkpoint, please construct the layer with LSTM(recurrent_activation='hard_sigmoid') to fallback to 1.x behavior.
* Change the default recurrent activation function for LSTM from
'hard_sigmoid' to 'sigmoid' in 2.0. Historically recurrent activation is
'hard_sigmoid' since it is fast than 'sigmoid'. With new unified backend
between CPU and GPU mode, since the CuDNN kernel is using sigmoid, we
change the default for CPU mode to sigmoid as well. With that, the
default LSTM will be compatible with both CPU and GPU kernel. This will
enable user with GPU to use CuDNN kernel by default and get a 10x
performance boost in training. Note that this is checkpoint breaking
change. If user want to use their 1.x pre-trained checkpoint, please
construct the layer with LSTM(recurrent_activation='hard_sigmoid') to
fallback to 1.x behavior.
* TensorFlow Lite
* Move from `tensorflow/contrib/lite` to `tensorflow/lite`.
* Add experimental Java API for injecting TensorFlow Lite delegates
@ -60,53 +109,98 @@
* `tf.contrib`:
* Add Apache Ignite Filesystem plugin to support accessing Apache IGFS.
* Dropout now takes `rate` argument, `keep_prob` is deprecated.
* Estimator occurrences references `tf.contrib.estimator` were changed to `tf.estimator`:
* `tf.contrib.estimator.BaselineEstimator` with `tf.estimator.BaselineEstimator`
* `tf.contrib.estimator.DNNLinearCombinedEstimator` with `tf.estimator.DNNLinearCombinedEstimator`
* Estimator occurrences references `tf.contrib.estimator` were changed to
`tf.estimator`:
* `tf.contrib.estimator.BaselineEstimator` with
`tf.estimator.BaselineEstimator`
* `tf.contrib.estimator.DNNLinearCombinedEstimator` with
`tf.estimator.DNNLinearCombinedEstimator`
* `tf.contrib.estimator.DNNEstimator` with `tf.estimator.DNNEstimator`
* `tf.contrib.estimator.LinearEstimator` with `tf.estimator.LinearEstimator`
* `tf.contrib.estimator.InMemoryEvaluatorHook` and tf.estimator.experimental.InMemoryEvaluatorHook`.
* `tf.contrib.estimator.make_stop_at_checkpoint_step_hook` with `tf.estimator.experimental.make_stop_at_checkpoint_step_hook`.
* Expose `tf.distribute.Strategy as the new name for tf.contrib.distribute.DistributionStrategy.
* `tf.contrib.estimator.LinearEstimator` with
`tf.estimator.LinearEstimator`
* `tf.contrib.estimator.InMemoryEvaluatorHook` and
tf.estimator.experimental.InMemoryEvaluatorHook`.
* `tf.contrib.estimator.make_stop_at_checkpoint_step_hook` with
`tf.estimator.experimental.make_stop_at_checkpoint_step_hook`.
* Expose `tf.distribute.Strategy as the new name for
tf.contrib.distribute.DistributionStrategy.
* Migrate linear optimizer from contrib to core.
* Move `tf.contrib.signal` to `tf.signal` (preserving aliases in tf.contrib.signal).
* Users of `tf.contrib.estimator.export_all_saved_models` and related should switch to `tf.estimator.Estimator.experimental_export_all_saved_models`.
* Move `tf.contrib.signal` to `tf.signal` (preserving aliases in
tf.contrib.signal).
* Users of `tf.contrib.estimator.export_all_saved_models` and related
should switch to
`tf.estimator.Estimator.experimental_export_all_saved_models`.
* tf.data:
* Add `tf.data.experimental.StatsOptions()`, to configure options to collect statistics from `tf.data.Dataset` pipeline using `StatsAggregator`. Add nested option, `experimental_stats` (which takes a `tf.data.experimen tal.StatsOptions` object), to `tf.data.Options`. Deprecates `tf.data.experimental.set_stats_agregator`.
* Add `tf.data.experimental.StatsOptions()`, to configure options to
collect statistics from `tf.data.Dataset` pipeline using
`StatsAggregator`. Add nested option, `experimental_stats` (which takes
a `tf.data.experimen tal.StatsOptions` object), to `tf.data.Options`.
Deprecates `tf.data.experimental.set_stats_agregator`.
* Performance optimizations:
* Add `tf.data.experimental.OptimizationOptions()`, to configure options to enable `tf.data` performance optimizations. Add nested option, `experimental_optimization` (which takes a `tf.data.experimental.OptimizationOptions` object), to `tf.data.Options`. Remove performance optimization options from `tf.data.Options`, and add them under `tf.data.experimental.OptimizationOptions` instead.
* Enable `map_and_batch_fusion` and `noop_elimination` optimizations by default. They can be disabled by configuring `tf.data.experimental.OptimizationOptions` to set `map_and_batch = False` or `noop_elimination = False` respectively. To disable all default optimizations, set `apply_default_optimizations = False`.
* Add `tf.data.experimental.OptimizationOptions()`, to configure options
to enable `tf.data` performance optimizations. Add nested option,
`experimental_optimization` (which takes a
`tf.data.experimental.OptimizationOptions` object), to
`tf.data.Options`. Remove performance optimization options from
`tf.data.Options`, and add them under
`tf.data.experimental.OptimizationOptions` instead.
* Enable `map_and_batch_fusion` and `noop_elimination` optimizations by
default. They can be disabled by configuring
`tf.data.experimental.OptimizationOptions` to set `map_and_batch =
False` or `noop_elimination = False` respectively. To disable all
default optimizations, set `apply_default_optimizations = False`.
* Support parallel map in `map_and_filter_fusion`.
* Disable static optimizations for input pipelines that use non-resource `tf.Variable`s.
* Disable static optimizations for input pipelines that use non-resource
`tf.Variable`s.
* Add NUMA-aware MapAndBatch dataset.
* Deprecate `tf.data.Dataset.make_one_shot_iterator()` in V1, removed it from V2, and added tf.compat.v1.data.make_one_shot_iterator()`.
* Deprecate `tf.data.Dataset.make_initializable_iterator()` in V1, removed it from V2, and added `tf.compat.v1.data.make_initializable_iterator()`.
* Deprecate `tf.data.Dataset.make_one_shot_iterator()` in V1, removed it
from V2, and added tf.compat.v1.data.make_one_shot_iterator()`.
* Deprecate `tf.data.Dataset.make_initializable_iterator()` in V1, removed
it from V2, and added `tf.compat.v1.data.make_initializable_iterator()`.
* Enable nested dataset support in core `tf.data` transformations.
* For `tf.data.Dataset` implementers: Added `tf.data.Dataset._element_structured property` to replace `Dataset.output_{types,shapes,classes}`.
* Make `num_parallel_calls` of `tf.data.Dataset.interleave` and `tf.data.Dataset.map` work in Eager mode.
* For `tf.data.Dataset` implementers: Added
`tf.data.Dataset._element_structured property` to replace
`Dataset.output_{types,shapes,classes}`.
* Make `num_parallel_calls` of `tf.data.Dataset.interleave` and
`tf.data.Dataset.map` work in Eager mode.
* Toolchains
* Fixed OpenSSL compatibility by avoiding `EVP_MD_CTX_destroy`.
* Added bounds checking to printing deprecation warnings.
* Upgraded CUDA dependency to 10.0
* To build with Android NDK r14b, add "#include <linux/compiler.h>" to android-ndk-r14b/platforms/android-14/arch-*/usr/include/linux/futex.h
* Removed `:android_tensorflow_lib_selective_registration*` targets, use `:android_tensorflow_lib_lite*` targets instead.
* To build with Android NDK r14b, add "#include <linux/compiler.h>" to
android-ndk-r14b/platforms/android-14/arch-*/usr/include/linux/futex.h
* Removed `:android_tensorflow_lib_selective_registration*` targets, use
`:android_tensorflow_lib_lite*` targets instead.
* XLA
* Move `RoundToEven` function to xla/client/lib/math.h.
* A new environment variable `TF_XLA_DEBUG_OPTIONS_PASSTHROUGH` set to "1" or "true" allows the debug options passed within an XRTCompile op to be passed directly to the XLA compilation backend. If such variable is not set (service side), only a restricted set will be passed through.
* Allow the XRTCompile op to return the ProgramShape resulted form the XLA compilation as a second return argument.
* A new environment variable `TF_XLA_DEBUG_OPTIONS_PASSTHROUGH` set to "1"
or "true" allows the debug options passed within an XRTCompile op to be
passed directly to the XLA compilation backend. If such variable is not
set (service side), only a restricted set will be passed through.
* Allow the XRTCompile op to return the ProgramShape resulted form the XLA
compilation as a second return argument.
* XLA HLO graphs can now be rendered as SVG/HTML.
* Estimator
* Replace all occurences of `tf.contrib.estimator.BaselineEstimator` with `tf.estimator.BaselineEstimator`
* Replace all occurences of `tf.contrib.estimator.DNNLinearCombinedEstimator` with `tf.estimator.DNNLinearCombinedEstimator`
* Replace all occurrences of `tf.contrib.estimator.DNNEstimator` with `tf.estimator.DNNEstimator`
* Replace all occurrences of `tf.contrib.estimator.LinearEstimator` with `tf.estimator.LinearEstimator`
* Users of `tf.contrib.estimator.export_all_saved_models` and related should switch to `tf.estimator.Estimator.experimental_export_all_saved_models`.
* Replace all occurences of `tf.contrib.estimator.BaselineEstimator` with
`tf.estimator.BaselineEstimator`
* Replace all occurences of
`tf.contrib.estimator.DNNLinearCombinedEstimator` with
`tf.estimator.DNNLinearCombinedEstimator`
* Replace all occurrences of `tf.contrib.estimator.DNNEstimator` with
`tf.estimator.DNNEstimator`
* Replace all occurrences of `tf.contrib.estimator.LinearEstimator` with
`tf.estimator.LinearEstimator`
* Users of `tf.contrib.estimator.export_all_saved_models` and related
should switch to
`tf.estimator.Estimator.experimental_export_all_saved_models`.
* Update `regression_head` to the new Head API for Canned Estimator V2.
* Switch `multi_class_head` to Head API for Canned Estimator V2.
* Replace all occurences of `tf.contrib.estimator.InMemoryEvaluatorHook` and `tf.contrib.estimator.make_stop_at_checkpoint_step_hook` with `tf.estimator.experimental.InMemoryEvaluatorHook` and `tf.estimator.experimental.make_stop_at_checkpoint_step_hook`
* Replace all occurences of `tf.contrib.estimator.InMemoryEvaluatorHook`
and `tf.contrib.estimator.make_stop_at_checkpoint_step_hook` with
`tf.estimator.experimental.InMemoryEvaluatorHook` and
`tf.estimator.experimental.make_stop_at_checkpoint_step_hook`
* Migrate linear optimizer from contrib to core.
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:

View File

@ -43,47 +43,37 @@ remote_config_workspace()
# Apple and Swift rules.
http_archive(
name = "build_bazel_rules_apple",
sha256 = "4b90786009fa8df25230442244bad2832ba8d6bc4987f68150a7de59c8827e90",
strip_prefix = "rules_apple-0.14.0",
urls = ["https://github.com/bazelbuild/rules_apple/archive/0.14.0.tar.gz"],
)
http_file(
name = "xctestrunner",
executable = 1,
urls = ["https://github.com/google/xctestrunner/releases/download/0.2.6/ios_test_runner.par"],
)
http_archive(
name = "bazel_skylib",
sha256 = "2c62d8cd4ab1e65c08647eb4afe38f51591f43f7f0885e7769832fa137633dcb",
strip_prefix = "bazel-skylib-0.7.0",
urls = ["https://github.com/bazelbuild/bazel-skylib/archive/0.7.0.tar.gz"],
)
sha256 = "8f32e2839fba28d549e1670dbed83606dd339a9f7489118e481814d61738270f",
urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.14.0/rules_apple.0.14.0.tar.gz"],
) # https://github.com/bazelbuild/rules_apple/releases
http_archive(
name = "build_bazel_apple_support",
sha256 = "835663c4bb02f4bf01dce8a2a176df7fa682dbb867d3698ae12258c1628bb8f0",
strip_prefix = "apple_support-0.5.0",
urls = ["https://github.com/bazelbuild/apple_support/archive/0.5.0.tar.gz"],
)
sha256 = "7356dbd44dea71570a929d1d4731e870622151a5f27164d966dda97305f33471",
urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.6.0/apple_support.0.6.0.tar.gz"],
) # https://github.com/bazelbuild/apple_support/releases
http_archive(
name = "bazel_skylib",
sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e",
urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.8.0/bazel-skylib.0.8.0.tar.gz"],
) # https://github.com/bazelbuild/bazel-skylib/releases
http_archive(
name = "build_bazel_rules_swift",
sha256 = "32d124878cd49775d84f59ba90440c8b23b7c775aec8fec1978f751c76ddee8a",
strip_prefix = "rules_swift-0.7.0",
urls = ["https://github.com/bazelbuild/rules_swift/archive/0.7.0.tar.gz"],
)
sha256 = "31aad005a9c4e56b256125844ad05eb27c88303502d74138186f9083479f93a6",
urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.8.0/rules_swift.0.8.0.tar.gz"],
) # https://github.com/bazelbuild/rules_swift/releases
http_archive(
name = "com_github_apple_swift_swift_protobuf",
type = "zip",
strip_prefix = "swift-protobuf-1.2.0/",
urls = ["https://github.com/apple/swift-protobuf/archive/1.2.0.zip"],
)
# Use swift_rules_dependencies to fetch the tolchains.
# Since we defined all the "git_repository" rules above, the following call will
# skip redefining them.
strip_prefix = "swift-protobuf-1.4.0/",
urls = ["https://github.com/apple/swift-protobuf/archive/1.4.0.zip"],
) # https://github.com/apple/swift-protobuf/releases
http_file(
name = "xctestrunner",
executable = 1,
urls = ["https://github.com/google/xctestrunner/releases/download/0.2.7/ios_test_runner.par"],
) # https://github.com/google/xctestrunner/releases
# Use `swift_rules_dependencies` to fetch the toolchains. With the
# `git_repository` rules above, the following call will skip redefining them.
load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies")
swift_rules_dependencies()

View File

@ -33,13 +33,11 @@ except ImportError:
from distutils.spawn import find_executable as which
# pylint: enable=g-import-not-at-top
_DEFAULT_CUDA_VERSION = '10.0'
_DEFAULT_CUDA_VERSION = '10'
_DEFAULT_CUDNN_VERSION = '7'
_DEFAULT_TENSORRT_VERSION = '5'
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0'
_DEFAULT_CUDA_PATH = '/usr/local/cuda'
_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
_DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION)
_TF_OPENCL_VERSION = '1.2'
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
@ -69,11 +67,6 @@ IOS_FILES = [
'tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec',
]
if platform.machine() == 'ppc64le':
_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/powerpc64le-linux-gnu/'
else:
_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine()
class UserInputError(Exception):
pass
@ -300,9 +293,9 @@ def get_var(environ_cp,
Args:
environ_cp: copy of the os.environ.
var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
query_item: string for feature related to the variable, e.g. "Hadoop File
System".
var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
query_item: string for feature related to the variable, e.g. "CUDA for
Nvidia GPUs".
enabled_by_default: boolean for default behavior.
question: optional string for how to ask for user input.
yes_reply: optional string for reply when feature is enabled.
@ -383,9 +376,9 @@ def set_build_var(environ_cp,
Args:
environ_cp: copy of the os.environ.
var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
query_item: string for feature related to the variable, e.g. "Hadoop File
System".
var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
query_item: string for feature related to the variable, e.g. "CUDA for
Nvidia GPUs".
option_name: string for option to define in .bazelrc.
enabled_by_default: boolean for default behavior.
bazel_config_name: Name for Bazel --config argument to enable build feature.
@ -418,9 +411,9 @@ def set_action_env_var(environ_cp,
Args:
environ_cp: copy of the os.environ.
var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
query_item: string for feature related to the variable, e.g. "Hadoop File
System".
var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
query_item: string for feature related to the variable, e.g. "CUDA for
Nvidia GPUs".
enabled_by_default: boolean for default behavior.
question: optional string for how to ask for user input.
yes_reply: optional string for reply when feature is enabled.
@ -463,8 +456,8 @@ def check_bazel_version(min_version, max_version):
"""Check installed bazel version is between min_version and max_version.
Args:
min_version: string for minimum bazel version.
max_version: string for maximum bazel version.
min_version: string for minimum bazel version (must exist!).
max_version: string for maximum bazel version (must exist!).
Returns:
The bazel version detected.
@ -577,7 +570,7 @@ def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var,
Args:
environ_cp: copy of the os.environ.
var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
ask_for_var: string for how to ask for user input.
var_default: default value string.
@ -862,154 +855,39 @@ def reformat_version_sequence(version_str, sequence_count):
return '.'.join(v[:sequence_count])
def set_tf_cuda_paths(environ_cp):
"""Set TF_CUDA_PATHS."""
ask_cuda_paths = (
'Please specify the comma-separated list of base paths to look for CUDA '
'libraries and headers. [Leave empty to use the default]: ')
tf_cuda_paths = get_from_env_or_user_or_default(environ_cp, 'TF_CUDA_PATHS',
ask_cuda_paths, '')
if tf_cuda_paths:
environ_cp['TF_CUDA_PATHS'] = tf_cuda_paths
def set_tf_cuda_version(environ_cp):
"""Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION."""
"""Set TF_CUDA_VERSION."""
ask_cuda_version = (
'Please specify the CUDA SDK version you want to use. '
'[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
# Configure the Cuda SDK version to use.
tf_cuda_version = get_from_env_or_user_or_default(environ_cp,
'TF_CUDA_VERSION',
ask_cuda_version,
_DEFAULT_CUDA_VERSION)
tf_cuda_version = reformat_version_sequence(str(tf_cuda_version), 2)
# Find out where the CUDA toolkit is installed
default_cuda_path = _DEFAULT_CUDA_PATH
if is_windows() or is_cygwin():
default_cuda_path = cygpath(
environ_cp.get('CUDA_PATH', _DEFAULT_CUDA_PATH_WIN))
elif is_linux():
# If the default doesn't exist, try an alternative default.
if (not os.path.exists(default_cuda_path)
) and os.path.exists(_DEFAULT_CUDA_PATH_LINUX):
default_cuda_path = _DEFAULT_CUDA_PATH_LINUX
ask_cuda_path = ('Please specify the location where CUDA %s toolkit is'
' installed. Refer to README.md for more details. '
'[Default is %s]: ') % (tf_cuda_version, default_cuda_path)
cuda_toolkit_path = get_from_env_or_user_or_default(environ_cp,
'CUDA_TOOLKIT_PATH',
ask_cuda_path,
default_cuda_path)
if is_windows() or is_cygwin():
cuda_toolkit_path = cygpath(cuda_toolkit_path)
if is_windows():
cuda_rt_lib_paths = ['lib/x64/cudart.lib']
elif is_linux():
cuda_rt_lib_paths = [
'%s/libcudart.so.%s' % (x, tf_cuda_version) for x in [
'lib64',
'lib/powerpc64le-linux-gnu',
'lib/x86_64-linux-gnu',
]
]
elif is_macos():
cuda_rt_lib_paths = ['lib/libcudart.%s.dylib' % tf_cuda_version]
cuda_toolkit_paths_full = [
os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths
]
if any(os.path.exists(x) for x in cuda_toolkit_paths_full):
break
# Reset and retry
print('Invalid path to CUDA %s toolkit. %s cannot be found' %
(tf_cuda_version, cuda_toolkit_paths_full))
environ_cp['TF_CUDA_VERSION'] = ''
environ_cp['CUDA_TOOLKIT_PATH'] = ''
else:
raise UserInputError('Invalid TF_CUDA_SETTING setting was provided %d '
'times in a row. Assuming to be a scripting mistake.' %
_DEFAULT_PROMPT_ASK_ATTEMPTS)
# Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION
environ_cp['CUDA_TOOLKIT_PATH'] = cuda_toolkit_path
write_action_env_to_bazelrc('CUDA_TOOLKIT_PATH', cuda_toolkit_path)
environ_cp['TF_CUDA_VERSION'] = tf_cuda_version
write_action_env_to_bazelrc('TF_CUDA_VERSION', tf_cuda_version)
def set_tf_cudnn_version(environ_cp):
"""Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION."""
"""Set TF_CUDNN_VERSION."""
ask_cudnn_version = (
'Please specify the cuDNN version you want to use. '
'[Leave empty to default to cuDNN %s]: ') % _DEFAULT_CUDNN_VERSION
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
tf_cudnn_version = get_from_env_or_user_or_default(environ_cp,
'TF_CUDNN_VERSION',
ask_cudnn_version,
_DEFAULT_CUDNN_VERSION)
tf_cudnn_version = reformat_version_sequence(str(tf_cudnn_version), 1)
default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH')
ask_cudnn_path = (r'Please specify the location where cuDNN %s library is '
'installed. Refer to README.md for more details. [Default'
' is %s]: ') % (tf_cudnn_version, default_cudnn_path)
cudnn_install_path = get_from_env_or_user_or_default(
environ_cp, 'CUDNN_INSTALL_PATH', ask_cudnn_path, default_cudnn_path)
# Result returned from "read" will be used unexpanded. That make "~"
# unusable. Going through one more level of expansion to handle that.
cudnn_install_path = os.path.realpath(
os.path.expanduser(cudnn_install_path))
if is_windows() or is_cygwin():
cudnn_install_path = cygpath(cudnn_install_path)
if is_windows():
cuda_dnn_lib_path = 'lib/x64/cudnn.lib'
cuda_dnn_lib_alt_path = 'lib/x64/cudnn.lib'
elif is_linux():
cuda_dnn_lib_path = 'lib64/libcudnn.so.%s' % tf_cudnn_version
cuda_dnn_lib_alt_path = 'libcudnn.so.%s' % tf_cudnn_version
elif is_macos():
cuda_dnn_lib_path = 'lib/libcudnn.%s.dylib' % tf_cudnn_version
cuda_dnn_lib_alt_path = 'libcudnn.%s.dylib' % tf_cudnn_version
cuda_dnn_lib_path_full = os.path.join(cudnn_install_path, cuda_dnn_lib_path)
cuda_dnn_lib_alt_path_full = os.path.join(cudnn_install_path,
cuda_dnn_lib_alt_path)
if os.path.exists(cuda_dnn_lib_path_full) or os.path.exists(
cuda_dnn_lib_alt_path_full):
break
# Try another alternative for Linux
if is_linux():
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
cudnn_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
cudnn_path_from_ldconfig = re.search('.*libcudnn.so .* => (.*)',
cudnn_path_from_ldconfig)
if cudnn_path_from_ldconfig:
cudnn_path_from_ldconfig = cudnn_path_from_ldconfig.group(1)
if os.path.exists('%s.%s' %
(cudnn_path_from_ldconfig, tf_cudnn_version)):
cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig)
break
# Reset and Retry
print(
'Invalid path to cuDNN %s toolkit. None of the following files can be '
'found:' % tf_cudnn_version)
print(cuda_dnn_lib_path_full)
print(cuda_dnn_lib_alt_path_full)
if is_linux():
print('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version))
environ_cp['TF_CUDNN_VERSION'] = ''
else:
raise UserInputError('Invalid TF_CUDNN setting was provided %d '
'times in a row. Assuming to be a scripting mistake.' %
_DEFAULT_PROMPT_ASK_ATTEMPTS)
# Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION
environ_cp['CUDNN_INSTALL_PATH'] = cudnn_install_path
write_action_env_to_bazelrc('CUDNN_INSTALL_PATH', cudnn_install_path)
environ_cp['TF_CUDNN_VERSION'] = tf_cudnn_version
write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version)
def is_cuda_compatible(lib, cuda_ver, cudnn_ver):
@ -1041,253 +919,38 @@ def is_cuda_compatible(lib, cuda_ver, cudnn_ver):
return cudnn_ok and cuda_ok
def set_tf_tensorrt_install_path(environ_cp):
"""Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION.
Adapted from code contributed by Sami Kama (https://github.com/samikama).
Args:
environ_cp: copy of the os.environ.
Raises:
ValueError: if this method was called under non-Linux platform.
UserInputError: if user has provided invalid input multiple times.
"""
def set_tf_tensorrt_version(environ_cp):
"""Set TF_TENSORRT_VERSION."""
if not is_linux():
raise ValueError('Currently TensorRT is only supported on Linux platform.')
# Ask user whether to add TensorRT support.
if str(int(get_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT',
False))) != '1':
if not int(environ_cp.get('TF_NEED_TENSORRT', False)):
return
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
ask_tensorrt_path = (r'Please specify the location where TensorRT is '
'installed. [Default is %s]:') % (
_DEFAULT_TENSORRT_PATH_LINUX)
trt_install_path = get_from_env_or_user_or_default(
environ_cp, 'TENSORRT_INSTALL_PATH', ask_tensorrt_path,
_DEFAULT_TENSORRT_PATH_LINUX)
# Result returned from "read" will be used unexpanded. That make "~"
# unusable. Going through one more level of expansion to handle that.
trt_install_path = os.path.realpath(os.path.expanduser(trt_install_path))
def find_libs(search_path):
"""Search for libnvinfer.so in "search_path"."""
fl = set()
if os.path.exists(search_path) and os.path.isdir(search_path):
fl.update([
os.path.realpath(os.path.join(search_path, x))
for x in os.listdir(search_path)
if 'libnvinfer.so' in x
])
return fl
possible_files = find_libs(trt_install_path)
possible_files.update(find_libs(os.path.join(trt_install_path, 'lib')))
possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64')))
cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION'])
cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION'])
nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$')
highest_ver = [0, None, None]
for lib_file in possible_files:
if is_cuda_compatible(lib_file, cuda_ver, cudnn_ver):
matches = nvinfer_pattern.search(lib_file)
if not matches.groups():
continue
ver_str = matches.group(1)
ver = convert_version_to_int(ver_str) if len(ver_str) else 0
if ver > highest_ver[0]:
highest_ver = [ver, ver_str, lib_file]
if highest_ver[1] is not None:
trt_install_path = os.path.dirname(highest_ver[2])
tf_tensorrt_version = highest_ver[1]
break
# Try another alternative from ldconfig.
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
ldconfig_output = run_shell([ldconfig_bin, '-p'])
search_result = re.search('.*libnvinfer.so\\.?([0-9.]*).* => (.*)',
ldconfig_output)
if search_result:
libnvinfer_path_from_ldconfig = search_result.group(2)
if os.path.exists(libnvinfer_path_from_ldconfig):
if is_cuda_compatible(libnvinfer_path_from_ldconfig, cuda_ver,
cudnn_ver):
trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig)
tf_tensorrt_version = search_result.group(1)
break
# Reset and Retry
if possible_files:
print('TensorRT libraries found in one the following directories',
'are not compatible with selected cuda and cudnn installations')
print(trt_install_path)
print(os.path.join(trt_install_path, 'lib'))
print(os.path.join(trt_install_path, 'lib64'))
if search_result:
print(libnvinfer_path_from_ldconfig)
else:
print(
'Invalid path to TensorRT. None of the following files can be found:')
print(trt_install_path)
print(os.path.join(trt_install_path, 'lib'))
print(os.path.join(trt_install_path, 'lib64'))
if search_result:
print(libnvinfer_path_from_ldconfig)
else:
raise UserInputError('Invalid TF_TENSORRT setting was provided %d '
'times in a row. Assuming to be a scripting mistake.' %
_DEFAULT_PROMPT_ASK_ATTEMPTS)
# Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION
environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path
write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path)
ask_tensorrt_version = (
'Please specify the TensorRT version you want to use. '
'[Leave empty to default to TensorRT %s]: ') % _DEFAULT_TENSORRT_VERSION
tf_tensorrt_version = get_from_env_or_user_or_default(
environ_cp, 'TF_TENSORRT_VERSION', ask_tensorrt_version,
_DEFAULT_TENSORRT_VERSION)
environ_cp['TF_TENSORRT_VERSION'] = tf_tensorrt_version
write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_tensorrt_version)
def set_tf_nccl_install_path(environ_cp):
"""Set NCCL_INSTALL_PATH, NCCL_HDR_PATH and TF_NCCL_VERSION.
Args:
environ_cp: copy of the os.environ.
Raises:
ValueError: if this method was called under non-Linux platform.
UserInputError: if user has provided invalid input multiple times.
"""
def set_tf_nccl_version(environ_cp):
"""Set TF_NCCL_VERSION."""
if not is_linux():
raise ValueError('Currently NCCL is only supported on Linux platforms.')
raise ValueError('Currently NCCL is only supported on Linux platform.')
if 'TF_NCCL_VERSION' in environ_cp:
return
ask_nccl_version = (
'Please specify the locally installed NCCL version you want to use. '
'[Default is to use https://github.com/nvidia/nccl]: ')
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
'[Leave empty to use http://github.com/nvidia/nccl]: ')
tf_nccl_version = get_from_env_or_user_or_default(environ_cp,
'TF_NCCL_VERSION',
ask_nccl_version, '')
if not tf_nccl_version:
break # No need to get install path, building the open source code.
tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1)
# Look with ldconfig first if we can find the library in paths
# like /usr/lib/x86_64-linux-gnu and the header file in the corresponding
# include directory. This is where the NCCL .deb packages install them.
# First check to see if NCCL is in the ldconfig.
# If its found, use that location.
if is_linux():
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
nccl2_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
nccl2_path_from_ldconfig = re.search('.*libnccl.so .* => (.*)',
nccl2_path_from_ldconfig)
if nccl2_path_from_ldconfig:
nccl2_path_from_ldconfig = nccl2_path_from_ldconfig.group(1)
if os.path.exists('%s.%s' % (nccl2_path_from_ldconfig, tf_nccl_version)):
nccl_install_path = os.path.dirname(nccl2_path_from_ldconfig)
print('NCCL libraries found in ' + nccl2_path_from_ldconfig)
# Check if this is the main system lib location
if re.search('.*linux-gnu', nccl_install_path):
trunc_nccl_install_path = '/usr'
print('This looks like a system path.')
else:
trunc_nccl_install_path = nccl_install_path + '/..'
# Look for header
nccl_hdr_path = trunc_nccl_install_path + '/include'
print('Assuming NCCL header path is ' + nccl_hdr_path)
if os.path.exists(nccl_hdr_path + '/nccl.h'):
# Set NCCL_INSTALL_PATH
environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path
write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path)
# Set NCCL_HDR_PATH
environ_cp['NCCL_HDR_PATH'] = nccl_hdr_path
write_action_env_to_bazelrc('NCCL_HDR_PATH', nccl_hdr_path)
break
else:
print(
'The header for NCCL2 cannot be found. Please install the libnccl-dev package.'
)
else:
print('NCCL2 is listed by ldconfig but the library is not found. '
'Your ldconfig is out of date. Please run sudo ldconfig.')
else:
# NCCL is not found in ldconfig. Ask the user for the location.
default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH')
ask_nccl_path = (
r'Please specify the location where NCCL %s library is '
'installed. Refer to README.md for more details. [Default '
'is %s]:') % (tf_nccl_version, default_nccl_path)
nccl_install_path = get_from_env_or_user_or_default(
environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path)
# Result returned from "read" will be used unexpanded. That make "~"
# unusable. Going through one more level of expansion to handle that.
nccl_install_path = os.path.realpath(
os.path.expanduser(nccl_install_path))
if is_windows() or is_cygwin():
nccl_install_path = cygpath(nccl_install_path)
nccl_lib_path = ''
if is_windows():
nccl_lib_path = 'lib/x64/nccl.lib'
elif is_linux():
nccl_lib_filename = 'libnccl.so.%s' % tf_nccl_version
nccl_lpath = '%s/lib/%s' % (nccl_install_path, nccl_lib_filename)
if not os.path.exists(nccl_lpath):
for relative_path in NCCL_LIB_PATHS:
path = '%s/%s%s' % (nccl_install_path, relative_path,
nccl_lib_filename)
if os.path.exists(path):
print('NCCL found at ' + path)
nccl_lib_path = path
break
else:
nccl_lib_path = nccl_lpath
elif is_macos():
nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version
nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path)
nccl_hdr_path = os.path.join(
os.path.dirname(nccl_lib_path), '../include/nccl.h')
print('Assuming NCCL header path is ' + nccl_hdr_path)
if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path):
# Set NCCL_INSTALL_PATH
environ_cp['NCCL_INSTALL_PATH'] = os.path.dirname(nccl_lib_path)
write_action_env_to_bazelrc('NCCL_INSTALL_PATH',
os.path.dirname(nccl_lib_path))
# Set NCCL_HDR_PATH
environ_cp['NCCL_HDR_PATH'] = os.path.dirname(nccl_hdr_path)
write_action_env_to_bazelrc('NCCL_HDR_PATH',
os.path.dirname(nccl_hdr_path))
break
# Reset and Retry
print(
'Invalid path to NCCL %s toolkit, %s or %s not found. Please use the '
'O/S agnostic package of NCCL 2' %
(tf_nccl_version, nccl_lib_path, nccl_hdr_path))
environ_cp['TF_NCCL_VERSION'] = ''
else:
raise UserInputError('Invalid TF_NCCL setting was provided %d '
'times in a row. Assuming to be a scripting mistake.' %
_DEFAULT_PROMPT_ASK_ATTEMPTS)
# Set TF_NCCL_VERSION
environ_cp['TF_NCCL_VERSION'] = tf_nccl_version
write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version)
def get_native_cuda_compute_capabilities(environ_cp):
"""Get native cuda compute capabilities.
@ -1644,6 +1307,66 @@ def configure_ios():
symlink_force(filepath, new_filepath)
def validate_cuda_config(environ_cp):
"""Run find_cuda_config.py and return cuda_toolkit_path, or None."""
def maybe_encode_env(env):
"""Encodes unicode in env to str on Windows python 2.x."""
if not is_windows() or sys.version_info[0] != 2:
return env
for k, v in env.items():
if isinstance(k, unicode):
k = k.encode('ascii')
if isinstance(v, unicode):
v = v.encode('ascii')
env[k] = v
return env
cuda_libraries = ['cuda', 'cudnn']
if is_linux():
if 'TF_TENSORRT_VERSION' in environ_cp: # if env variable exists
cuda_libraries.append('tensorrt')
if environ_cp.get('TF_NCCL_VERSION', None): # if env variable not empty
cuda_libraries.append('nccl')
proc = subprocess.Popen(
[environ_cp['PYTHON_BIN_PATH'], 'third_party/gpus/find_cuda_config.py'] +
cuda_libraries,
stdout=subprocess.PIPE,
env=maybe_encode_env(environ_cp))
if proc.wait():
# Errors from find_cuda_config.py were sent to stderr.
print('Asking for detailed CUDA configuration...\n')
return False
config = dict(
tuple(line.decode('ascii').rstrip().split(': ')) for line in proc.stdout)
print('Found CUDA %s in:' % config['cuda_version'])
print(' %s' % config['cuda_library_dir'])
print(' %s' % config['cuda_include_dir'])
print('Found cuDNN %s in:' % config['cudnn_version'])
print(' %s' % config['cudnn_library_dir'])
print(' %s' % config['cudnn_include_dir'])
if 'tensorrt_version' in config:
print('Found TensorRT %s in:' % config['tensorrt_version'])
print(' %s' % config['tensorrt_library_dir'])
print(' %s' % config['tensorrt_include_dir'])
if config.get('nccl_version', None):
print('Found NCCL %s in:' % config['nccl_version'])
print(' %s' % config['nccl_library_dir'])
print(' %s' % config['nccl_include_dir'])
print('\n')
environ_cp['CUDA_TOOLKIT_PATH'] = config['cuda_toolkit_path']
return True
def main():
global _TF_WORKSPACE_ROOT
global _TF_BAZELRC
@ -1664,7 +1387,7 @@ def main():
# environment variables.
environ_cp = dict(os.environ)
current_bazel_version = check_bazel_version('0.22.0', '0.24.1')
current_bazel_version = check_bazel_version('0.24.1', '0.24.1')
_TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)
reset_tf_configure_bazelrc()
@ -1720,11 +1443,39 @@ def main():
set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)
if (environ_cp.get('TF_NEED_CUDA') == '1' and
'TF_CUDA_CONFIG_REPO' not in environ_cp):
set_action_env_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False)
environ_save = dict(environ_cp)
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
if validate_cuda_config(environ_cp):
cuda_env_names = [
'TF_CUDA_VERSION', 'TF_CUBLAS_VERSION', 'TF_CUDNN_VERSION',
'TF_TENSORRT_VERSION', 'TF_NCCL_VERSION', 'TF_CUDA_PATHS',
'CUDA_TOOLKIT_PATH'
]
for name in cuda_env_names:
if name in environ_cp:
write_action_env_to_bazelrc(name, environ_cp[name])
break
# Restore settings changed below if CUDA config could not be validated.
environ_cp = dict(environ_save)
set_tf_cuda_version(environ_cp)
set_tf_cudnn_version(environ_cp)
if is_linux():
set_tf_tensorrt_install_path(environ_cp)
set_tf_nccl_install_path(environ_cp)
set_tf_tensorrt_version(environ_cp)
set_tf_nccl_version(environ_cp)
set_tf_cuda_paths(environ_cp)
else:
raise UserInputError(
'Invalid CUDA setting were provided %d '
'times in a row. Assuming to be a scripting mistake.' %
_DEFAULT_PROMPT_ASK_ATTEMPTS)
set_tf_cuda_compute_capabilities(environ_cp)
if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get(

View File

@ -420,6 +420,9 @@ config_setting(
values = {"cpu": "x64_windows"},
)
# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
# Instead, please use public APIs or public build rules TF provides.
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
package_group(
name = "internal",
packages = [

View File

@ -12,7 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bring in all of the public TensorFlow interface into this module."""
"""
Top-level module of TensorFlow. By convention, we refer to this module as
`tf` instead of `tensorflow`, following the common practice of importing
TensorFlow via the command `import tensorflow as tf`.
The primary function of this module is to import all of the public TensorFlow
interfaces into a single place. The interfaces themselves are located in
sub-modules, as described below.
Note that the file `__init__.py` in the TensorFlow source code tree is actually
only a placeholder to enable test cases to run. The TensorFlow build replaces
this file with a file generated from [`api_template.__init__.py`](https://www.github.com/tensorflow/tensorflow/blob/master/tensorflow/api_template.__init__.py)
"""
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
@ -20,10 +32,13 @@ from __future__ import print_function as _print_function
import distutils as _distutils
import inspect as _inspect
import logging as _logging
import os as _os
import site as _site
import sys as _sys
from tensorflow.python.tools import module_util as _module_util
# API IMPORTS PLACEHOLDER
# Make sure directory containing top level submodules is in
@ -37,25 +52,29 @@ if not hasattr(_current_module, '__path__'):
elif _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
# pylint: disable=g-bad-import-order
from tensorflow.python.tools import component_api_helper as _component_api_helper
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorboard.summary._tf.summary'),
error_msg="Limited tf.summary API due to missing TensorBoard installation")
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=(
'tensorflow_estimator.python.estimator.api._v2.estimator'))
# Hook external TensorFlow modules.
try:
from tensorboard.summary._tf import summary
_current_module.__path__ = (
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
except ImportError:
_logging.warning(
"Limited tf.summary API due to missing TensorBoard installation.")
try:
from tensorflow_estimator.python.estimator.api._v2 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
except ImportError:
pass
try:
from tensorflow.python.keras.api._v2 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
except ImportError:
pass
if not hasattr(_current_module, 'estimator'):
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=(
'tensorflow_estimator.python.estimator.api.estimator'))
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorflow.python.keras.api._v2.keras'))
# Enable TF2 behaviors
from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top

View File

@ -26,24 +26,37 @@ import sys as _sys
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
from tensorflow.python.tools import module_util as _module_util
# API IMPORTS PLACEHOLDER
from tensorflow.python.tools import component_api_helper as _component_api_helper
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=(
'tensorflow_estimator.python.estimator.api._v1.estimator'))
# Make sure directory containing top level submodules is in
# the __path__ so that "from tensorflow.foo import bar" works.
# We're using bitwise, but there's nothing special about that.
_API_MODULE = bitwise # pylint: disable=undefined-variable
_current_module = _sys.modules[__name__]
if not hasattr(_current_module, 'estimator'):
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=(
'tensorflow_estimator.python.estimator.api.estimator'))
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorflow.python.keras.api._v1.keras'))
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
if not hasattr(_current_module, '__path__'):
__path__ = [_tf_api_dir]
elif _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
# Hook external TensorFlow modules.
try:
from tensorflow_estimator.python.estimator.api._v1 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
except ImportError:
pass
try:
from tensorflow.python.keras.api._v1 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
except ImportError:
pass
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
_CONTRIB_WARNING = """
WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.
@ -66,17 +79,6 @@ from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-
# The 'app' module will be imported as part of the placeholder section above.
app.flags = flags # pylint: disable=undefined-variable
# Also use 'app' module (choice is arbitrary) to derive the API directory below.
_API_MODULE = app # pylint: disable=undefined-variable
# Make sure directory containing top level submodules is in
# the __path__ so that "from tensorflow.foo import bar" works.
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
if not hasattr(_current_module, '__path__'):
__path__ = [_tf_api_dir]
elif _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
# running under pip.
# TODO(gunan): Enable setting an environment variable to define arbitrary plugin
@ -118,7 +120,11 @@ if _running_from_pip_package():
# pylint: disable=undefined-variable
try:
del python
if '__all__' in vars():
vars()['__all__'].remove('python')
del core
if '__all__' in vars():
vars()['__all__'].remove('core')
except NameError:
# Don't fail if these modules are not available.
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
@ -129,6 +135,8 @@ except NameError:
# others don't exist.
try:
del compiler
if '__all__' in vars():
vars()['__all__'].remove('compiler')
except NameError:
pass
# pylint: enable=undefined-variable

View File

@ -104,6 +104,7 @@ tf_cuda_library(
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"@com_google_absl//absl/strings",
"//tensorflow/cc/saved_model:loader_lite",
"//tensorflow/cc:gradients",
"//tensorflow/cc:ops",
@ -145,6 +146,7 @@ tf_cuda_library(
"//tensorflow/core:lib_platform",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"@com_google_absl//absl/strings",
],
)

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"

View File

@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/c_api_internal.h"
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include "absl/strings/match.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -352,6 +352,16 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
argdef->set_type(node->output_type(idx));
const string& input_name = node_names.GetInputName(node->name());
argdef->set_name(input_name);
auto& arg_attrs = (*fdef->mutable_arg_attr())[i];
for (const auto& attr : node->attrs()) {
// Only copy internal attributes. These attributes will be applied to
// _Arg/Placeholder nodes when this FunctionDef is converted to graph, and
// normal attributes for nodes cannot be applied to those _Arg/Placeholder
// nodes.
if (absl::StartsWith(attr.first, "_")) {
arg_attrs.mutable_attr()->insert(attr);
}
}
tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
}

View File

@ -1278,6 +1278,46 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
EXPECT_EQ(func_->fdef.signature().attr(1).type(), "int");
}
void NodeWithAttrHelper(TF_Graph* graph, TF_Status* s, const char* name,
const char* attr_name, const char* attr_value,
TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
TF_SetAttrType(desc, "dtype", TF_INT32);
TF_SetAttrString(desc, attr_name, attr_value, strlen(attr_value));
*op = TF_FinishOperation(desc, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
TEST_F(CApiFunctionTest, GraphToFunctionDefWithArgAttr) {
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
TF_NewGraph(), TF_DeleteGraph);
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
TF_DeleteStatus);
TF_Operation* node;
NodeWithAttrHelper(func_graph.get(), s.get(), "node", "_test_attr", "value",
&node);
TF_Output inputs[] = {{node, 0}};
TF_Output outputs[] = {};
func_ = TF_GraphToFunction(
func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
/*opers=*/nullptr, 1, inputs, 0, outputs,
/*output_names=*/nullptr,
/*opts=*/nullptr, /*description=*/nullptr, s.get());
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
ASSERT_NE(func_, nullptr);
// Verify that FunctionDef ArgDef has attributes.
ASSERT_EQ(func_->fdef.arg_attr_size(), 1);
auto arg_attrs = func_->fdef.arg_attr().find(0);
ASSERT_NE(arg_attrs, func_->fdef.arg_attr().end());
auto iter = arg_attrs->second.attr().find("_test_attr");
ASSERT_NE(iter, arg_attrs->second.attr().end());
EXPECT_EQ(iter->second.s(), "value");
}
TEST_F(CApiFunctionTest, SetGradientAndRun) {
// Define the function and its grad
DefineFunction(func_name_, &func_);

View File

@ -29,8 +29,7 @@ namespace checkpoint {
class TensorSliceReader;
CheckpointReader::CheckpointReader(const string& filename,
TF_Status* out_status)
CheckpointReader::CheckpointReader(const string& filename, TF_Status* status)
: reader_(nullptr),
v2_reader_(nullptr),
var_to_shape_map_(nullptr),
@ -43,7 +42,7 @@ CheckpointReader::CheckpointReader(const string& filename,
v2_reader_.reset(
new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */));
if (!v2_reader_->status().ok()) {
Set_TF_Status_from_Status(out_status, v2_reader_->status());
Set_TF_Status_from_Status(status, v2_reader_->status());
return;
}
auto result = BuildV2VarMaps();
@ -52,7 +51,7 @@ CheckpointReader::CheckpointReader(const string& filename,
} else {
reader_.reset(new TensorSliceReader(filename));
if (!reader_->status().ok()) {
Set_TF_Status_from_Status(out_status, reader_->status());
Set_TF_Status_from_Status(status, reader_->status());
return;
}
var_to_shape_map_.reset(

View File

@ -39,7 +39,7 @@ class TensorSliceReader;
// variables.
class CheckpointReader {
public:
CheckpointReader(const string& filepattern, TF_Status* out_status);
CheckpointReader(const string& filename, TF_Status* status);
bool HasTensor(const string& name) const;
const string DebugString() const;

View File

@ -70,6 +70,7 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/profiler/lib:profiler_eager_lib",
"//tensorflow/core/profiler/lib:profiler_session",
"//tensorflow/core:gpu_runtime",
],
@ -110,6 +111,7 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
"//tensorflow/core/profiler/lib:profiler_eager_lib",
"//tensorflow/core/profiler/lib:profiler_session",
],
)
@ -200,6 +202,7 @@ tf_cuda_library(
"//conditions:default": [],
}) + [
"@com_google_absl//absl/memory",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
@ -236,7 +239,6 @@ tf_cuda_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/profiler:protos_all_cc",
"@com_google_absl//absl/strings",
],
)

18
tensorflow/c/eager/c_api.cc Executable file → Normal file
View File

@ -21,6 +21,9 @@ limitations under the License.
#include <string>
#include <vector>
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h" // NOLINT
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
@ -38,11 +41,15 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#endif // !IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
@ -88,6 +95,7 @@ string DeviceName(const tensorflow::Device* d) {
return (d == nullptr) ? "cpu:0" : d->name();
}
#if !defined(IS_MOBILE_PLATFORM)
tensorflow::Status GetAllRemoteDevices(
const std::vector<string>& remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
@ -143,7 +151,9 @@ tensorflow::Status CreateRemoteContexts(
request.mutable_server_def()->set_task_index(parsed_name.task);
request.set_async(async);
request.set_keep_alive_secs(keep_alive_secs);
auto* eager_client = remote_eager_workers->GetClient(remote_worker);
tensorflow::eager::EagerClient* eager_client;
TF_RETURN_IF_ERROR(
remote_eager_workers->GetClient(remote_worker, &eager_client));
if (eager_client == nullptr) {
return tensorflow::errors::Internal(
"Cannot find a client for the given target:", remote_worker);
@ -243,6 +253,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
keep_alive_secs);
#undef LOG_AND_RETURN_IF_ERROR
}
#endif // !IS_MOBILE_PLATFORM
tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,
TFE_TensorHandle* input) {
@ -396,6 +407,10 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
const void* proto,
size_t proto_len,
TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
status->status = tensorflow::errors::Unimplemented(
"TFE_ContextSetServerDef not supported on mobile");
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::ServerDef server_def;
if (!server_def.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument(
@ -404,6 +419,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
}
status->status =
UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx);
#endif // !IS_MOBILE_PLATFORM
}
void TFE_ContextSetThreadLocalDevicePlacementPolicy(

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
@ -98,80 +99,422 @@ bool TFE_ProfilerClientStartTracing(const char* service_addr,
return s.ok();
}
static tensorflow::mutex gauges_map_lock(tensorflow::LINKER_INITIALIZED);
static std::unordered_map<string,
tensorflow::monitoring::Gauge<tensorflow::int64, 1>*>*
get_gauges_map() EXCLUSIVE_LOCKS_REQUIRED(gauges_map_lock) {
static std::unordered_map<
string, tensorflow::monitoring::Gauge<tensorflow::int64, 1>*>*
gauges_map = new std::unordered_map<
string, tensorflow::monitoring::Gauge<tensorflow::int64, 1>*>;
return gauges_map;
}
static tensorflow::mutex counters_map_lock(tensorflow::LINKER_INITIALIZED);
static std::unordered_map<string, tensorflow::monitoring::Counter<1>*>*
get_counters_map() EXCLUSIVE_LOCKS_REQUIRED(counters_map_lock) {
static std::unordered_map<string, tensorflow::monitoring::Counter<1>*>*
counters_map =
new std::unordered_map<string, tensorflow::monitoring::Counter<1>*>;
return counters_map;
}
static tensorflow::mutex samplers_map_lock(tensorflow::LINKER_INITIALIZED);
static std::unordered_map<string, tensorflow::monitoring::Sampler<1>*>*
get_samplers_map() EXCLUSIVE_LOCKS_REQUIRED(samplers_map_lock) {
static std::unordered_map<string, tensorflow::monitoring::Sampler<1>*>*
samplers_map =
new std::unordered_map<string, tensorflow::monitoring::Sampler<1>*>;
return samplers_map;
}
void TFE_MonitoringSetGauge(const char* name, const char* label,
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
int64_t value) {
tensorflow::mutex_lock l(gauges_map_lock);
auto gauges_map = get_gauges_map();
if (gauges_map->find(name) == gauges_map->end()) {
gauges_map->emplace(
name, tensorflow::monitoring::Gauge<tensorflow::int64, 1>::New(
name,
tensorflow::strings::StrCat(
name, " :Gauge metric collected from Python API."),
"metric_descriptor"));
}
gauges_map->at(name)->GetCell(label)->Set(value);
cell->cell.IncrementBy(value);
}
void TFE_MonitoringAddCounter(const char* name, const char* label,
int64_t TFE_MonitoringCounterCellValue(TFE_MonitoringCounterCell* cell) {
return cell->cell.value();
}
TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(const char* name,
TF_Status* status,
const char* description) {
auto* result = new TFE_MonitoringCounter0({name, description});
Set_TF_Status_from_Status(status, result->counter->GetStatus());
if (!result->counter->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteCounter0(TFE_MonitoringCounter0* counter) {
delete counter;
}
TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0(
TFE_MonitoringCounter0* counter) {
return static_cast<TFE_MonitoringCounterCell*>(
static_cast<void*>(counter->counter->GetCell()));
}
TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(const char* name,
TF_Status* status,
const char* description,
const char* label1) {
auto* result = new TFE_MonitoringCounter1({name, description, label1});
Set_TF_Status_from_Status(status, result->counter->GetStatus());
if (!result->counter->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteCounter1(TFE_MonitoringCounter1* counter) {
delete counter;
}
TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1(
TFE_MonitoringCounter1* counter, const char* label1) {
return static_cast<TFE_MonitoringCounterCell*>(
static_cast<void*>(counter->counter->GetCell(label1)));
}
TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(const char* name,
TF_Status* status,
const char* description,
const char* label1,
const char* label2) {
auto* result =
new TFE_MonitoringCounter2({name, description, label1, label2});
Set_TF_Status_from_Status(status, result->counter->GetStatus());
if (!result->counter->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteCounter2(TFE_MonitoringCounter2* counter) {
delete counter;
}
TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
TFE_MonitoringCounter2* counter, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringCounterCell*>(
static_cast<void*>(counter->counter->GetCell(label1, label2)));
}
void TFE_MonitoringIntGaugeCellSet(TFE_MonitoringIntGaugeCell* cell,
int64_t value) {
tensorflow::mutex_lock l(counters_map_lock);
auto counters_map = get_counters_map();
if (counters_map->find(name) == counters_map->end()) {
counters_map->emplace(
name, tensorflow::monitoring::Counter<1>::New(
name,
tensorflow::strings::StrCat(
name, " :Counter metric collected from Python API."),
"metric_descriptor"));
}
counters_map->at(name)->GetCell(label)->IncrementBy(value);
cell->cell.Set(value);
}
void TFE_MonitoringAddSampler(const char* name, const char* label,
int64_t TFE_MonitoringIntGaugeCellValue(TFE_MonitoringIntGaugeCell* cell) {
return cell->cell.value();
}
TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(const char* name,
TF_Status* status,
const char* description) {
auto* result = new TFE_MonitoringIntGauge0({name, description});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteIntGauge0(TFE_MonitoringIntGauge0* gauge) {
delete gauge;
}
TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge0(
TFE_MonitoringIntGauge0* gauge) {
return static_cast<TFE_MonitoringIntGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell()));
}
TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(const char* name,
TF_Status* status,
const char* description,
const char* label1) {
auto* result = new TFE_MonitoringIntGauge1({name, description, label1});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteIntGauge1(TFE_MonitoringIntGauge1* gauge) {
delete gauge;
}
TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge1(
TFE_MonitoringIntGauge1* gauge, const char* label1) {
return static_cast<TFE_MonitoringIntGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1)));
}
TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(const char* name,
TF_Status* status,
const char* description,
const char* label1,
const char* label2) {
auto* result =
new TFE_MonitoringIntGauge2({name, description, label1, label2});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteIntGauge2(TFE_MonitoringIntGauge2* gauge) {
delete gauge;
}
TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge2(
TFE_MonitoringIntGauge2* gauge, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringIntGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
}
void TFE_MonitoringStringGaugeCellSet(TFE_MonitoringStringGaugeCell* cell,
const char* value) {
cell->cell.Set({value});
}
const void TFE_MonitoringStringGaugeCellValue(
TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf) {
tensorflow::string value = cell->cell.value();
void* data = tensorflow::port::Malloc(value.length());
value.copy(static_cast<char*>(data), value.length(), 0);
buf->data = data;
buf->length = value.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
}
TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0(
const char* name, TF_Status* status, const char* description) {
auto* result = new TFE_MonitoringStringGauge0({name, description});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteStringGauge0(TFE_MonitoringStringGauge0* gauge) {
delete gauge;
}
TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge0(
TFE_MonitoringStringGauge0* gauge) {
return static_cast<TFE_MonitoringStringGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell()));
}
TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1(
const char* name, TF_Status* status, const char* description,
const char* label1) {
auto* result = new TFE_MonitoringStringGauge1({name, description, label1});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteStringGauge1(TFE_MonitoringStringGauge1* gauge) {
delete gauge;
}
TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge1(
TFE_MonitoringStringGauge1* gauge, const char* label1) {
return static_cast<TFE_MonitoringStringGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1)));
}
TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2(
const char* name, TF_Status* status, const char* description,
const char* label1, const char* label2) {
auto* result =
new TFE_MonitoringStringGauge2({name, description, label1, label2});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteStringGauge2(TFE_MonitoringStringGauge2* gauge) {
delete gauge;
}
TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge2(
TFE_MonitoringStringGauge2* gauge, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringStringGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
}
void TFE_MonitoringBoolGaugeCellSet(TFE_MonitoringBoolGaugeCell* cell,
bool value) {
cell->cell.Set(value);
}
bool TFE_MonitoringBoolGaugeCellValue(TFE_MonitoringBoolGaugeCell* cell) {
return cell->cell.value();
}
TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(const char* name,
TF_Status* status,
const char* description) {
auto* result = new TFE_MonitoringBoolGauge0({name, description});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteBoolGauge0(TFE_MonitoringBoolGauge0* gauge) {
delete gauge;
}
TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge0(
TFE_MonitoringBoolGauge0* gauge) {
return static_cast<TFE_MonitoringBoolGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell()));
}
TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(const char* name,
TF_Status* status,
const char* description,
const char* label1) {
auto* result = new TFE_MonitoringBoolGauge1({name, description, label1});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteBoolGauge1(TFE_MonitoringBoolGauge1* gauge) {
delete gauge;
}
TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge1(
TFE_MonitoringBoolGauge1* gauge, const char* label1) {
return static_cast<TFE_MonitoringBoolGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1)));
}
TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(const char* name,
TF_Status* status,
const char* description,
const char* label1,
const char* label2) {
auto* result =
new TFE_MonitoringBoolGauge2({name, description, label1, label2});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteBoolGauge2(TFE_MonitoringBoolGauge2* gauge) {
delete gauge;
}
TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge2(
TFE_MonitoringBoolGauge2* gauge, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringBoolGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
}
void TFE_MonitoringSamplerCellAdd(TFE_MonitoringSamplerCell* cell,
double value) {
tensorflow::mutex_lock l(samplers_map_lock);
auto samplers_map = get_samplers_map();
if (samplers_map->find(name) == samplers_map->end()) {
samplers_map->emplace(
name, tensorflow::monitoring::Sampler<1>::New(
{name,
tensorflow::strings::StrCat(
name, " :Counter metric collected from Python API."),
"metric_descriptor"},
{tensorflow::monitoring::Buckets::Exponential(1, 2, 30)}));
cell->cell.Add(value);
}
samplers_map->at(name)->GetCell(label)->Add(value);
void TFE_MonitoringSamplerCellValue(TFE_MonitoringSamplerCell* cell,
TF_Buffer* buf) {
string content;
cell->cell.value().SerializeToString(&content);
void* data = tensorflow::port::Malloc(content.length());
content.copy(static_cast<char*>(data), content.length(), 0);
buf->data = data;
buf->length = content.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
}
TFE_MonitoringBuckets* TFE_MonitoringNewExponentialBuckets(double scale,
double growth_factor,
int bucket_count) {
return new TFE_MonitoringBuckets([scale, growth_factor, bucket_count]() {
return tensorflow::monitoring::Buckets::Exponential(scale, growth_factor,
bucket_count);
});
}
void TFE_MonitoringDeleteBuckets(TFE_MonitoringBuckets* buckets) {
delete buckets;
}
TFE_MonitoringSampler0* TFE_MonitoringNewSampler0(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
const char* description) {
auto* result = new TFE_MonitoringSampler0(
{name, buckets->create_buckets(), description});
Set_TF_Status_from_Status(status, result->sampler->GetStatus());
if (!result->sampler->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteSampler0(TFE_MonitoringSampler0* sampler) {
delete sampler;
}
TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0(
TFE_MonitoringSampler0* sampler) {
return static_cast<TFE_MonitoringSamplerCell*>(
static_cast<void*>(sampler->sampler->GetCell()));
}
TFE_MonitoringSampler1* TFE_MonitoringNewSampler1(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
const char* description, const char* label1) {
auto* result = new TFE_MonitoringSampler1(
{name, buckets->create_buckets(), description, label1});
Set_TF_Status_from_Status(status, result->sampler->GetStatus());
if (!result->sampler->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteSampler1(TFE_MonitoringSampler1* sampler) {
delete sampler;
}
TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1(
TFE_MonitoringSampler1* sampler, const char* label1) {
return static_cast<TFE_MonitoringSamplerCell*>(
static_cast<void*>(sampler->sampler->GetCell(label1)));
}
TFE_MonitoringSampler2* TFE_MonitoringNewSampler2(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
const char* description, const char* label1, const char* label2) {
auto* result = new TFE_MonitoringSampler2(
{name, buckets->create_buckets(), description, label1, label2});
Set_TF_Status_from_Status(status, result->sampler->GetStatus());
if (!result->sampler->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteSampler2(TFE_MonitoringSampler2* sampler) {
delete sampler;
}
TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
TFE_MonitoringSampler2* sampler, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringSamplerCell*>(
static_cast<void*>(sampler->sampler->GetCell(label1, label2)));
}

View File

@ -87,24 +87,228 @@ TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
const char* service_addr, const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms, int num_tracing_attempts);
// Set the value of a Gauge metric. If the metric with given name does not
// exist, it will create a new Gauge metric. Right now it only supports type
// int64, consider to add more type supports if needed.
TF_CAPI_EXPORT extern void TFE_MonitoringSetGauge(const char* name,
const char* label,
int64_t value);
// TODO(fishx): Move these monitoring APIs into a separate file.
// -----------------------------------------------------------------------------
// Monitoring Counter APIs.
// These APIs de-templated monitoring Counter for swig.
// Increase a Counter metric by the given value. If the metric with given name
// does not exist, it will create a new Counter metric.
TF_CAPI_EXPORT extern void TFE_MonitoringAddCounter(const char* name,
const char* label,
int64_t value);
typedef struct TFE_MonitoringCounterCell TFE_MonitoringCounterCell;
// Add the given value to a Sampler metric. If the metric with given name
// does not exist, it will create a new Sampler metric.
TF_CAPI_EXPORT extern void TFE_MonitoringAddSampler(const char* name,
const char* label,
double value);
// Atomically increments the value of the cell. The value must be non-negative.
TF_CAPI_EXPORT extern void TFE_MonitoringCounterCellIncrementBy(
TFE_MonitoringCounterCell* cell, int64_t value);
// Retrieves the current value of the cell.
TF_CAPI_EXPORT extern int64_t TFE_MonitoringCounterCellValue(
TFE_MonitoringCounterCell* cell);
// APIs for Counter without label.
typedef struct TFE_MonitoringCounter0 TFE_MonitoringCounter0;
// Returns a new Counter metric object. The caller should manage lifetime of
// the object. Using duplicate metric name will crash the program with fatal
// error.
TF_CAPI_EXPORT extern TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(
const char* name, TF_Status* status, const char* description);
// Deletes the Counter object.
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter0(
TFE_MonitoringCounter0* counter);
// Retrieves the cell from the Counter object. The Counter object will manage
// lifetime of the cell.
TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0(
TFE_MonitoringCounter0* counter);
// APIs for Counter with 1 label.
typedef struct TFE_MonitoringCounter1 TFE_MonitoringCounter1;
TF_CAPI_EXPORT extern TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(
const char* name, TF_Status* status, const char* description,
const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter1(
TFE_MonitoringCounter1* counter);
TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1(
TFE_MonitoringCounter1* counter, const char* label1);
// APIs for Counter with 2 labels.
typedef struct TFE_MonitoringCounter2 TFE_MonitoringCounter2;
TF_CAPI_EXPORT extern TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(
const char* name, TF_Status* status, const char* description,
const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter2(
TFE_MonitoringCounter2* counter);
TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
TFE_MonitoringCounter2* counter, const char* label1, const char* label2);
// -----------------------------------------------------------------------------
// Monitoring Gauge APIs.
// These APIs de-templated monitoring Gauge for swig.
typedef struct TFE_MonitoringIntGaugeCell TFE_MonitoringIntGaugeCell;
// Atomically set the value of the cell.
TF_CAPI_EXPORT extern void TFE_MonitoringIntGaugeCellSet(
TFE_MonitoringIntGaugeCell* cell, int64_t value);
// Retrieves the current value of the cell.
TF_CAPI_EXPORT extern int64_t TFE_MonitoringIntGaugeCellValue(
TFE_MonitoringIntGaugeCell* cell);
// APIs for Int Gauge without label.
typedef struct TFE_MonitoringIntGauge0 TFE_MonitoringIntGauge0;
TF_CAPI_EXPORT extern TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(
const char* name, TF_Status* out_status, const char* description);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge0(
TFE_MonitoringIntGauge0* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
TFE_MonitoringGetCellIntGauge0(TFE_MonitoringIntGauge0* gauge);
// APIs for Int Gauge with 1 label.
typedef struct TFE_MonitoringIntGauge1 TFE_MonitoringIntGauge1;
TF_CAPI_EXPORT extern TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(
const char* name, TF_Status* out_status, const char* description,
const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge1(
TFE_MonitoringIntGauge1* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
TFE_MonitoringGetCellIntGauge1(TFE_MonitoringIntGauge1* gauge,
const char* label1);
// APIs for Int Gauge with 2 label.
typedef struct TFE_MonitoringIntGauge2 TFE_MonitoringIntGauge2;
TF_CAPI_EXPORT extern TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(
const char* name, TF_Status* out_status, const char* description,
const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge2(
TFE_MonitoringIntGauge2* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
TFE_MonitoringGetCellIntGauge2(TFE_MonitoringIntGauge2* gauge,
const char* label1, const char* label2);
typedef struct TFE_MonitoringStringGaugeCell TFE_MonitoringStringGaugeCell;
TF_CAPI_EXPORT extern void TFE_MonitoringStringGaugeCellSet(
TFE_MonitoringStringGaugeCell* cell, const char* value);
// Retrieves the string value and saves it in buffer.
TF_CAPI_EXPORT extern const void TFE_MonitoringStringGaugeCellValue(
TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf);
// APIs for String Gauge without label.
typedef struct TFE_MonitoringStringGauge0 TFE_MonitoringStringGauge0;
TF_CAPI_EXPORT extern TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0(
const char* name, TF_Status* out_status, const char* description);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge0(
TFE_MonitoringStringGauge0* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
TFE_MonitoringGetCellStringGauge0(TFE_MonitoringStringGauge0* gauge);
// APIs for String Gauge with 1 label.
typedef struct TFE_MonitoringStringGauge1 TFE_MonitoringStringGauge1;
TF_CAPI_EXPORT extern TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1(
const char* name, TF_Status* out_status, const char* description,
const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge1(
TFE_MonitoringStringGauge1* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
TFE_MonitoringGetCellStringGauge1(TFE_MonitoringStringGauge1* gauge,
const char* label1);
// APIs for String Gauge with 2 label.
typedef struct TFE_MonitoringStringGauge2 TFE_MonitoringStringGauge2;
TF_CAPI_EXPORT extern TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2(
const char* name, TF_Status* out_status, const char* description,
const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge2(
TFE_MonitoringStringGauge2* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
TFE_MonitoringGetCellStringGauge2(TFE_MonitoringStringGauge2* gauge,
const char* label1, const char* label2);
typedef struct TFE_MonitoringBoolGaugeCell TFE_MonitoringBoolGaugeCell;
TF_CAPI_EXPORT extern void TFE_MonitoringBoolGaugeCellSet(
TFE_MonitoringBoolGaugeCell* cell, bool value);
TF_CAPI_EXPORT extern bool TFE_MonitoringBoolGaugeCellValue(
TFE_MonitoringBoolGaugeCell* cell);
// APIs for Bool Gauge without label.
typedef struct TFE_MonitoringBoolGauge0 TFE_MonitoringBoolGauge0;
TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(
const char* name, TF_Status* out_status, const char* description);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge0(
TFE_MonitoringBoolGauge0* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
TFE_MonitoringGetCellBoolGauge0(TFE_MonitoringBoolGauge0* gauge);
// APIs for Bool Gauge with 1 label.
typedef struct TFE_MonitoringBoolGauge1 TFE_MonitoringBoolGauge1;
TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(
const char* name, TF_Status* out_status, const char* description,
const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge1(
TFE_MonitoringBoolGauge1* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
TFE_MonitoringGetCellBoolGauge1(TFE_MonitoringBoolGauge1* gauge,
const char* label1);
// APIs for Bool Gauge with 2 label.
typedef struct TFE_MonitoringBoolGauge2 TFE_MonitoringBoolGauge2;
TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(
const char* name, TF_Status* out_status, const char* description,
const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge2(
TFE_MonitoringBoolGauge2* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
TFE_MonitoringGetCellBoolGauge2(TFE_MonitoringBoolGauge2* gauge,
const char* label1, const char* label2);
// -----------------------------------------------------------------------------
// Monitoring Sampler APIs.
// These APIs de-templated monitoring Sampler for swig.
typedef struct TFE_MonitoringSamplerCell TFE_MonitoringSamplerCell;
// Atomically add the value of the cell.
TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellAdd(
TFE_MonitoringSamplerCell* cell, double value);
// Retrieves the current value of the cell. The return value is a HistogramProto
// saved in buffer.
TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellValue(
TFE_MonitoringSamplerCell* cell, TF_Buffer* buf);
// APIs for sampler buckets
typedef struct TFE_MonitoringBuckets TFE_MonitoringBuckets;
TF_CAPI_EXPORT extern TFE_MonitoringBuckets*
TFE_MonitoringNewExponentialBuckets(double scale, double growth_factor,
int bucket_count);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBuckets(
TFE_MonitoringBuckets* buckets);
// APIs for Sampler without label.
typedef struct TFE_MonitoringSampler0 TFE_MonitoringSampler0;
TF_CAPI_EXPORT extern TFE_MonitoringSampler0* TFE_MonitoringNewSampler0(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
const char* description);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler0(
TFE_MonitoringSampler0* sampler);
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0(
TFE_MonitoringSampler0* sampler);
// APIs for Sampler with 1 label.
typedef struct TFE_MonitoringSampler1 TFE_MonitoringSampler1;
TF_CAPI_EXPORT extern TFE_MonitoringSampler1* TFE_MonitoringNewSampler1(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
const char* description, const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler1(
TFE_MonitoringSampler1* sampler);
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1(
TFE_MonitoringSampler1* sampler, const char* label1);
// APIs for Sampler with 2 label.
typedef struct TFE_MonitoringSampler2 TFE_MonitoringSampler2;
TF_CAPI_EXPORT extern TFE_MonitoringSampler2* TFE_MonitoringNewSampler2(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
const char* description, const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2(
TFE_MonitoringSampler2* sampler);
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
TFE_MonitoringSampler2* sampler, const char* label1, const char* label2);
#ifdef __cplusplus
} /* end extern "C" */

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include <string.h>
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/cc/profiler/profiler.h"
#include "tensorflow/core/lib/monitoring/collection_registry.h"
@ -24,7 +25,7 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/profiler/trace_events.pb.h"
#include "tensorflow/core/protobuf/trace_events.pb.h"
using tensorflow::string;
@ -80,11 +81,15 @@ void ExecuteWithProfiling(bool async) {
profiler_result->length}));
string profile_proto_str = profile_proto.DebugString();
if (!gpu_device_name.empty()) {
EXPECT_TRUE(HasSubstr(profile_proto_str, "GPU:0"));
EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0"));
// device name with "stream:all" is collected by Device Tracer.
EXPECT_TRUE(HasSubstr(profile_proto_str, "stream:all"));
// TODO(fishx): move following check out from this if statement.
// This is collected by TraceMe
EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU"));
}
EXPECT_TRUE(HasSubstr(profile_proto_str, "CPU:0"));
EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:CPU:0"));
EXPECT_TRUE(HasSubstr(profile_proto_str, "MatMul"));
TF_DeleteBuffer(profiler_result);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
@ -126,25 +131,15 @@ TEST(CAPI, MultipleProfilerSession) {
TFE_DeleteProfilerContext(profiler_context);
}
TEST(CAPI, MonitoringSetGauge) {
TFE_MonitoringSetGauge("test/gauge", "label", 1);
auto* collection_registry = monitoring::CollectionRegistry::Default();
monitoring::CollectionRegistry::CollectMetricsOptions options;
std::unique_ptr<monitoring::CollectedMetrics> metrics =
collection_registry->CollectMetrics(options);
EXPECT_EQ("test/gauge", metrics->point_set_map.at("test/gauge")->metric_name);
EXPECT_EQ(1,
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
TFE_MonitoringSetGauge("test/gauge", "label", 5);
metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(5,
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
}
TEST(CAPI, MonitoringAddCounter) {
TFE_MonitoringAddCounter("test/counter", "label", 1);
TEST(CAPI, MonitoringCounter0) {
TF_Status* status = TF_NewStatus();
auto* counter =
TFE_MonitoringNewCounter0("test/counter", status, "description");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
auto* cell = TFE_MonitoringGetCellCounter0(counter);
TFE_MonitoringCounterCellIncrementBy(cell, 1);
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell), 1);
auto* collection_registry = monitoring::CollectionRegistry::Default();
monitoring::CollectionRegistry::CollectMetricsOptions options;
std::unique_ptr<monitoring::CollectedMetrics> metrics =
@ -155,14 +150,92 @@ TEST(CAPI, MonitoringAddCounter) {
EXPECT_EQ(
1, metrics->point_set_map.at("test/counter")->points.at(0)->int64_value);
TFE_MonitoringAddCounter("test/counter", "label", 5);
TFE_MonitoringCounterCellIncrementBy(cell, 5);
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell), 6);
metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(
6, metrics->point_set_map.at("test/counter")->points.at(0)->int64_value);
TFE_MonitoringDeleteCounter0(counter);
metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(metrics->point_set_map.end(),
metrics->point_set_map.find("test/counter"));
}
TEST(CAPI, MonitoringAddSampler) {
TFE_MonitoringAddSampler("test/sampler", "label", 1.0);
TEST(CAPI, MonitoringCounterMultiple) {
TF_Status* status = TF_NewStatus();
auto* counter1 = TFE_MonitoringNewCounter1("test/counter1", status,
"description", "label1");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell1 = TFE_MonitoringGetCellCounter1(counter1, "test");
TFE_MonitoringCounterCellIncrementBy(cell1, 1);
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell1), 1);
auto* counter2 = TFE_MonitoringNewCounter2("test/counter2", status,
"description", "label1", "label2");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
auto* cell2 = TFE_MonitoringGetCellCounter2(counter2, "foo", "bar");
TFE_MonitoringCounterCellIncrementBy(cell2, 2);
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell2), 2);
TFE_MonitoringDeleteCounter1(counter1);
TFE_MonitoringDeleteCounter2(counter2);
}
TEST(CAPI, MonitoringGauge0) {
TF_Status* status = TF_NewStatus();
auto* gauge = TFE_MonitoringNewIntGauge0("test/gauge", status, "test");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell = TFE_MonitoringGetCellIntGauge0(gauge);
TFE_MonitoringIntGaugeCellSet(cell, 1);
EXPECT_EQ(TFE_MonitoringIntGaugeCellValue(cell), 1);
auto* collection_registry = monitoring::CollectionRegistry::Default();
monitoring::CollectionRegistry::CollectMetricsOptions options;
std::unique_ptr<monitoring::CollectedMetrics> metrics =
collection_registry->CollectMetrics(options);
EXPECT_EQ("test/gauge", metrics->point_set_map.at("test/gauge")->metric_name);
EXPECT_EQ(1,
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
TFE_MonitoringIntGaugeCellSet(cell, 5);
metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(5,
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
TF_DeleteStatus(status);
}
TEST(CAPI, MonitoringMultipleGauge) {
TF_Status* status = TF_NewStatus();
auto* gauge1 =
TFE_MonitoringNewBoolGauge1("test/gauge1", status, "test", "label1");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell1 = TFE_MonitoringGetCellBoolGauge1(gauge1, "foo");
TFE_MonitoringBoolGaugeCellSet(cell1, true);
EXPECT_TRUE(TFE_MonitoringBoolGaugeCellValue(cell1));
auto* gauge2 = TFE_MonitoringNewStringGauge2("test/gauge2", status, "test",
"label1", "label2");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell2 = TFE_MonitoringGetCellStringGauge2(gauge2, "foo", "bar");
TFE_MonitoringStringGaugeCellSet(cell2, "str");
auto* buf = new TF_Buffer;
TFE_MonitoringStringGaugeCellValue(cell2, buf);
string data(static_cast<const char*>(buf->data), buf->length);
delete buf;
EXPECT_EQ(data, "str");
TF_DeleteStatus(status);
}
TEST(CAPI, MonitoringSampler0) {
TF_Status* status = TF_NewStatus();
auto* buckets = TFE_MonitoringNewExponentialBuckets(1.0, 2.0, 2);
auto* sampler =
TFE_MonitoringNewSampler0("test/sampler", buckets, status, "test");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell = TFE_MonitoringGetCellSampler0(sampler);
TFE_MonitoringSamplerCellAdd(cell, 1.0);
auto* collection_registry = monitoring::CollectionRegistry::Default();
monitoring::CollectionRegistry::CollectMetricsOptions options;
std::unique_ptr<monitoring::CollectedMetrics> metrics =
@ -174,11 +247,48 @@ TEST(CAPI, MonitoringAddSampler) {
->points.at(0)
->histogram_value.sum());
TFE_MonitoringAddSampler("test/sampler", "label", 5.0);
TFE_MonitoringSamplerCellAdd(cell, 5.0);
metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(6.0, metrics->point_set_map.at("test/sampler")
->points.at(0)
->histogram_value.sum());
TFE_MonitoringDeleteBuckets(buckets);
TF_DeleteStatus(status);
}
TEST(CAPI, MonitoringMultipleSampler) {
TF_Status* status = TF_NewStatus();
auto* buckets = TFE_MonitoringNewExponentialBuckets(1.0, 2.0, 2);
auto* sampler1 = TFE_MonitoringNewSampler1("test/sampler1", buckets, status,
"test", "label1");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell1 = TFE_MonitoringGetCellSampler1(sampler1, "foo");
TFE_MonitoringSamplerCellAdd(cell1, 1.0);
TFE_MonitoringSamplerCellAdd(cell1, 2.0);
TF_Buffer* result1 = TF_NewBuffer();
TFE_MonitoringSamplerCellValue(cell1, result1);
tensorflow::HistogramProto hitogram1;
EXPECT_TRUE(hitogram1.ParseFromString(
{reinterpret_cast<const char*>(result1->data), result1->length}));
EXPECT_EQ(hitogram1.sum(), 3.0);
delete result1;
auto* sampler2 = TFE_MonitoringNewSampler2("test/sampler2", buckets, status,
"test", "label1", "label2");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell2 = TFE_MonitoringGetCellSampler2(sampler2, "foo", "bar");
TFE_MonitoringSamplerCellAdd(cell2, 2.0);
TFE_MonitoringSamplerCellAdd(cell2, 3.0);
TF_Buffer* result2 = TF_NewBuffer();
TFE_MonitoringSamplerCellValue(cell2, result2);
tensorflow::HistogramProto hitogram2;
EXPECT_TRUE(hitogram2.ParseFromString(
{reinterpret_cast<const char*>(result2->data), result2->length}));
EXPECT_EQ(hitogram2.sum(), 5.0);
delete result2;
TFE_MonitoringDeleteBuckets(buckets);
TF_DeleteStatus(status);
}
} // namespace

View File

@ -15,8 +15,6 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
#define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
#include "tensorflow/c/eager/c_api.h"
#include <algorithm>
#include <cstddef>
#include <map>
@ -28,6 +26,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
@ -37,19 +36,14 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
@ -133,6 +127,124 @@ struct TFE_Profiler {
std::unique_ptr<tensorflow::ProfilerSession> profiler;
};
struct TFE_MonitoringCounterCell {
tensorflow::monitoring::CounterCell cell;
};
template <int NumLabels>
struct TFE_MonitoringCounter {
template <typename... LabelDesc>
TFE_MonitoringCounter(const char* name, const char* description,
LabelDesc&&... label) {
counter = absl::WrapUnique(tensorflow::monitoring::Counter<NumLabels>::New(
name, description, label...));
}
std::unique_ptr<tensorflow::monitoring::Counter<NumLabels>> counter;
};
struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
struct TFE_MonitoringIntGaugeCell {
tensorflow::monitoring::GaugeCell<tensorflow::int64> cell;
};
struct TFE_MonitoringStringGaugeCell {
tensorflow::monitoring::GaugeCell<tensorflow::string> cell;
};
struct TFE_MonitoringBoolGaugeCell {
tensorflow::monitoring::GaugeCell<bool> cell;
};
template <typename ValueType, int NumLabels>
struct TFE_MonitoringGauge {
template <typename... LabelDesc>
TFE_MonitoringGauge(const char* name, const char* description,
LabelDesc&&... label) {
gauge = absl::WrapUnique(
tensorflow::monitoring::Gauge<ValueType, NumLabels>::New(
name, description, label...));
}
std::unique_ptr<tensorflow::monitoring::Gauge<ValueType, NumLabels>> gauge;
};
struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge<tensorflow::int64, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge<tensorflow::int64, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge<tensorflow::int64, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge<tensorflow::string, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge<tensorflow::string, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge<tensorflow::string, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge<bool, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge<bool, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBuckets {
TFE_MonitoringBuckets(
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
fn) {
create_buckets = fn;
}
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
create_buckets;
};
struct TFE_MonitoringSamplerCell {
tensorflow::monitoring::SamplerCell cell;
};
template <int NumLabels>
struct TFE_MonitoringSampler {
template <typename... LabelDesc>
TFE_MonitoringSampler(
const char* name,
std::unique_ptr<tensorflow::monitoring::Buckets> buckets,
const char* description, LabelDesc&&... label) {
sampler = absl::WrapUnique(tensorflow::monitoring::Sampler<NumLabels>::New(
{name, description, label...}, std::move(buckets)));
}
std::unique_ptr<tensorflow::monitoring::Sampler<NumLabels>> sampler;
};
struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
namespace tensorflow {
// Set an AttrValue on the op. Doesn't handle the list types.
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,

View File

@ -42,15 +42,20 @@ namespace {
const int kRightMargin = 79;
// Converts:
// bazel-out/.../genfiles/(external/YYY/)?XX
// bazel-out/.../(bin|genfiles)/(external/YYY/)?XX
// to: XX.
string GetPath(const string& dot_h_fname) {
auto pos = dot_h_fname.find("/genfiles/");
auto pos = dot_h_fname.find("/bin/");
string result = dot_h_fname;
if (pos != string::npos) {
// - 1 account for the terminating null character (\0) in "/genfiles/".
result = dot_h_fname.substr(pos + sizeof("/bin/") - 1);
} else {
pos = dot_h_fname.find("/genfiles/");
if (pos != string::npos) {
result = dot_h_fname.substr(pos + sizeof("/genfiles/") - 1);
}
}
if (result.size() > sizeof("external/") &&
result.compare(0, sizeof("external/") - 1, "external/") == 0) {
result = result.substr(sizeof("external/") - 1);

View File

@ -18,27 +18,41 @@ from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
import logging as _logging
import os as _os
import sys as _sys
from tensorflow.python.tools import module_util as _module_util
# pylint: disable=g-bad-import-order
# API IMPORTS PLACEHOLDER
from tensorflow.python.tools import component_api_helper as _component_api_helper
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorboard.summary._tf.summary'),
error_msg=(
# Hook external TensorFlow modules.
_current_module = _sys.modules[__name__]
try:
from tensorboard.summary._tf import summary
_current_module.__path__ = (
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
except ImportError:
_logging.warning(
"Limited tf.compat.v2.summary API due to missing TensorBoard "
"installation"))
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=(
'tensorflow_estimator.python.estimator.api._v2.estimator'))
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorflow.python.keras.api._v2.keras'))
"installation.")
try:
from tensorflow_estimator.python.estimator.api._v2 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
except ImportError:
pass
try:
from tensorflow.python.keras.api._v2 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
except ImportError:
pass
# We would like the following to work for fully enabling 2.0 in a 1.0 install:
#

View File

@ -19,18 +19,30 @@ from __future__ import division as _division
from __future__ import print_function as _print_function
import os as _os
import sys as _sys
from tensorflow.python.tools import module_util as _module_util
# pylint: disable=g-bad-import-order
# API IMPORTS PLACEHOLDER
from tensorflow.python.tools import component_api_helper as _component_api_helper
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=(
'tensorflow_estimator.python.estimator.api._v1.estimator'))
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorflow.python.keras.api._v1.keras'))
# Hook external TensorFlow modules.
_current_module = _sys.modules[__name__]
try:
from tensorflow_estimator.python.estimator.api._v1 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
except ImportError:
pass
try:
from tensorflow.python.keras.api._v1 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
except ImportError:
pass
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
app.flags = flags # pylint: disable=undefined-variable

View File

@ -163,7 +163,10 @@ def tf_library(
header_file = name + ".h"
metadata_object_file = name + "_tfcompile_metadata.o"
function_object_file = name + "_tfcompile_function.o"
ep = ("__" + native.package_name() + "__" + name).replace("/", "_")
# The XLA backends morph kernal name prefix __ that is not in the form of
# __xla_.
ep = ("__xla_" + native.package_name() + "__" + name).replace("/", "_")
if type(tfcompile_flags) == type(""):
flags = tfcompile_flags
else:
@ -171,6 +174,20 @@ def tf_library(
"'" + arg.replace("'", "'\\''") + "'"
for arg in (tfcompile_flags or [])
])
# Do this before we append the `select` into `flags`, because doing so
# transforms `flags` into a variable of type `select`, and we can't call
# `find` on such an object.
need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1
# Pass --target_cpu=haswell to tfcompile if compiling for Haswell (bazel
# build --cpu=haswell). We put it at the beginning of the flags list so
# that tfcompile_flags can override if if desired.
flags = select({
"//tools/target_cpu:haswell": "--target_cpu=haswell ",
"//conditions:default": "",
}) + flags
if enable_xla_hlo_profiling:
profiling_flag = "--xla_hlo_profile"
else:
@ -248,7 +265,6 @@ def tf_library(
# The cc_library rule packaging up the header and object file, and needed
# kernel implementations.
need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1)
native.cc_library(
name = name,
srcs = [function_object_file, metadata_object_file],

View File

@ -321,6 +321,7 @@ cc_library(
deps = [
":compilation_passes",
"//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
"//tensorflow/compiler/tf2xla:rearrange_function_argument_pass_registration",
"//tensorflow/core:core_cpu_internal",
],
alwayslink = 1,
@ -371,6 +372,7 @@ cc_library(
srcs = ["resource_operation_safety_analysis.cc"],
hdrs = ["resource_operation_safety_analysis.h"],
deps = [
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/core:framework",
@ -521,6 +523,7 @@ cc_library(
":device_info_cache",
":encapsulate_util",
":flags",
":resource_operation_safety_analysis",
":shape_inference_helpers",
":union_find",
":xla_cluster_util",
@ -565,7 +568,6 @@ cc_library(
hdrs = ["xla_cluster_util.h"],
deps = [
":flags",
":resource_operation_safety_analysis",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@ -660,6 +662,7 @@ tf_cc_test(
"introduce_floating_point_jitter_pass_test.cc",
"mark_for_compilation_pass_test.cc",
"partially_decluster_pass_test.cc",
"rearrange_function_argument_pass_test.cc",
],
deps = [
":common",
@ -680,6 +683,7 @@ tf_cc_test(
"//tensorflow/cc:scope",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:rearrange_function_argument_pass",
"//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:test_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
@ -732,44 +736,6 @@ tf_cc_test(
],
)
cc_library(
name = "xla_fusion_optimizer",
srcs = ["xla_fusion_optimizer.cc"],
hdrs = ["xla_fusion_optimizer.h"],
visibility = ["//visibility:public"],
deps = [
":common",
":compilation_passes",
":union_find",
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"@com_google_absl//absl/strings",
],
)
tf_cuda_cc_test(
name = "xla_fusion_optimizer_test",
srcs = ["xla_fusion_optimizer_test.cc"],
deps = [
":common",
":xla_cluster_util",
":xla_fusion_optimizer",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/core:graph",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler/utils:grappler_test",
],
)
cc_library(
name = "node_matchers",
testonly = True,

View File

@ -14,7 +14,10 @@ cc_library(
hdrs = ["graphcycles.h"],
deps = [
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
],
)

View File

@ -34,7 +34,10 @@ limitations under the License.
#include <algorithm>
#include <unordered_set>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
@ -402,4 +405,53 @@ std::unordered_set<int32> GraphCycles::Predecessors(int32 node) const {
return rep_->nodes_[node]->in;
}
namespace {
void SortInPostOrder(absl::Span<Node* const> nodes,
std::vector<int32>* to_sort) {
absl::c_sort(*to_sort, [&](int32 a, int32 b) {
DCHECK(a == b || nodes[a]->rank != nodes[b]->rank);
return nodes[a]->rank > nodes[b]->rank;
});
}
} // namespace
std::vector<int32> GraphCycles::AllNodesInPostOrder() const {
absl::flat_hash_set<int32> free_nodes_set;
absl::c_copy(rep_->free_nodes_,
std::inserter(free_nodes_set, free_nodes_set.begin()));
std::vector<int32> all_nodes;
all_nodes.reserve(rep_->nodes_.size() - free_nodes_set.size());
for (int64 i = 0, e = rep_->nodes_.size(); i < e; i++) {
if (!free_nodes_set.contains(i)) {
all_nodes.push_back(i);
}
}
SortInPostOrder(rep_->nodes_, &all_nodes);
return all_nodes;
}
string GraphCycles::DebugString() const {
absl::flat_hash_set<int32> free_nodes_set;
for (int32 free_node : rep_->free_nodes_) {
free_nodes_set.insert(free_node);
}
string result = "digraph {\n";
for (int i = 0; i < rep_->nodes_.size(); i++) {
if (free_nodes_set.contains(i)) {
continue;
}
for (int32 succ : rep_->nodes_[i]->out) {
absl::StrAppend(&result, " \"", i, "\" -> \"", succ, "\"\n");
}
}
absl::StrAppend(&result, "}\n");
return result;
}
} // namespace tensorflow

View File

@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
#include <vector>
// GraphCycles detects the introduction of a cycle into a directed
// graph that is being built up incrementally.
//
@ -120,6 +122,15 @@ class GraphCycles {
std::unordered_set<int32> Successors(int32 node) const;
std::unordered_set<int32> Predecessors(int32 node) const;
// Returns all nodes in post order.
//
// If there is a path from X to Y then X appears after Y in the
// returned vector.
std::vector<int32> AllNodesInPostOrder() const;
// Returns the graph in graphviz format.
string DebugString() const;
// ----------------------------------------------------
struct Rep;

View File

@ -39,6 +39,10 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 25,
// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
// FunctionalizeControlFlowPass: 27
//
// from
// third_party/tensorflow/compiler/tf2xla/rearrange_function_argument_pass_registration.cc
// RearrangeFunctionArgumentPass: 28
//
// This pass looks at the graph and all associated FunctionDefs, and turns
// traditional control flow structure (Switch/Merge/etc.) into functional
// control flow structure (XlaIf/XlaWhile). Following passes must

File diff suppressed because it is too large Load Diff

View File

@ -270,11 +270,11 @@ TEST(XlaCompilationTest, FunctionCalls) {
auto clusters = GetClusters(*graph);
EXPECT_EQ(2, clusters.size());
EXPECT_FALSE(clusters["B"].empty());
EXPECT_EQ(clusters["B"], clusters["C"]);
EXPECT_FALSE(clusters["C"].empty());
EXPECT_EQ(clusters["C"], clusters["E"]);
EXPECT_TRUE(clusters.find("A") == clusters.cend());
EXPECT_TRUE(clusters.find("B") == clusters.cend());
EXPECT_TRUE(clusters.find("D") == clusters.cend());
EXPECT_TRUE(clusters.find("E") == clusters.cend());
}
TEST(XlaCompilationTest, CallXlaDeviceFuncWithResourceOp) {
@ -332,31 +332,6 @@ TEST(XlaCompilationTest, CallXlaDeviceFuncWithResourceOp) {
EXPECT_NE(clusters["A"], "");
}
// Metadata-only operators such as Shape/Rank/Size may not be the root of a
// cluster. This is partially to work around b/26800664, and partially because
// we should probably prefer to compile metadata operators with their producers
// wherever possible, rather than their consumers.
TEST(XlaCompilationTest, MetadataOpsDontStartClusters) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a =
ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
// While all of the following ops are notionally compilable, none is
// permitted
// to start a cluster. So nothing should be compiled.
Node* b = ops::UnaryOp("Shape", a, builder.opts().WithName("B"));
Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C"));
Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D"));
ops::UnaryOp("Shape", d, builder.opts().WithName("E"));
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(0, clusters.size()); // Nothing should be compiled.
}
static Status GradForUnaryCwise(FunctionDef* g,
std::vector<FunctionDefHelper::Node> nodes) {
for (auto& n : nodes) {
@ -774,7 +749,7 @@ TEST(XlaCompilationTest, ChainOfOps) {
ASSERT_EQ(cluster_sets.size(), 1);
std::vector<string> expected_clustered_nodes_a = {
"AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"};
"AssignmentW1", "ConstN0", "ReadR0", "ValueToAssignW1"};
ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
}

View File

@ -340,6 +340,46 @@ Status PartiallyDeclusterGraph(Graph* graph,
return Status::OK();
}
} // namespace reduce_recompilation
namespace decluster_root_shape_consumers {
// Returns true if `node` an operator that consumes only the shape of its input,
// not the data itself.
bool IsShapeConsumerOp(const Node& node) {
return node.type_string() == "Shape" || node.type_string() == "Rank" ||
node.type_string() == "Size";
}
Status PartiallyDeclusterGraph(Graph* graph) {
std::vector<Node*> reverse_post_order;
GetReversePostOrder(*graph, &reverse_post_order,
/*stable_comparator=*/NodeComparatorName(),
/*edge_filter=*/NotBackedge);
for (Node* n : reverse_post_order) {
if (!IsShapeConsumerOp(*n)) {
continue;
}
absl::optional<absl::string_view> cluster = GetXlaClusterForNode(*n);
if (!cluster.has_value()) {
continue;
}
auto input_belongs_to_same_cluster = [&](const Edge* e) {
return cluster == GetXlaClusterForNode(*e->src());
};
if (absl::c_any_of(n->in_edges(), input_belongs_to_same_cluster)) {
continue;
}
VLOG(2) << "Declustering " << n->name()
<< " because it is a root shape consumer";
RemoveFromXlaCluster(n);
}
return Status::OK();
}
} // namespace decluster_root_shape_consumers
} // namespace
Status PartiallyDeclusterPass::Run(
@ -367,6 +407,9 @@ Status PartiallyDeclusterPass::Run(
TF_RETURN_IF_ERROR(reduce_recompilation::PartiallyDeclusterGraph(
graph, options.flib_def, options.session_options->env));
TF_RETURN_IF_ERROR(
decluster_root_shape_consumers::PartiallyDeclusterGraph(graph));
return Status::OK();
}
} // namespace tensorflow

View File

@ -467,5 +467,37 @@ TEST(PartiallyDeclusterPassTest, EliminatedUnusedNodes) {
EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer1Name), nullptr);
}
TEST(PartiallyDeclusterPassTest, MetadataOpsDontStartClusters) {
tensorflow::Scope root = tensorflow::Scope::NewRootScope();
tensorflow::Scope in_cluster_and = root.WithXlaCluster("cluster_0");
Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
Output b = ops::Shape(in_cluster_and.WithOpName("b"), a);
Output c = ops::Rank(in_cluster_and.WithOpName("c"), b);
Output d = ops::Size(in_cluster_and.WithOpName("d"), c);
(void)ops::Shape(in_cluster_and.WithOpName("e"), d);
std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(PartiallyDecluster(&graph));
Node* n_b = FindNodeByName(*graph, "b");
ASSERT_NE(n_b, nullptr);
EXPECT_EQ(GetXlaClusterForNode(*n_b), absl::nullopt);
Node* n_c = FindNodeByName(*graph, "c");
ASSERT_NE(n_c, nullptr);
EXPECT_EQ(GetXlaClusterForNode(*n_c), absl::nullopt);
Node* n_d = FindNodeByName(*graph, "d");
ASSERT_NE(n_d, nullptr);
EXPECT_EQ(GetXlaClusterForNode(*n_d), absl::nullopt);
Node* n_e = FindNodeByName(*graph, "e");
ASSERT_NE(n_e, nullptr);
EXPECT_EQ(GetXlaClusterForNode(*n_e), absl::nullopt);
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,214 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/rearrange_function_argument_pass.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
class RearrangeFunctionArgumentForFunctionTest : public ::testing::Test {
public:
void SetUp() override {
SessionOptions session_options;
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
session_options, "/job:localhost/replica:0/task:0", &devices));
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
}
Status RearrangeFunctionArgumentTest(
const string &func_name, const string &new_func_name,
const protobuf::Map<string, tensorflow::AttrValue> &attrs,
FunctionLibraryDefinition *fld, bool *modified) {
OptimizerOptions opts;
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, fld, opts,
/*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
auto flr = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
return RearrangeFunctionArgumentForFunction(
func_name, new_func_name, attrs, fld, flr,
&canonicalized_name_to_new_name, modified);
}
private:
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
};
TEST_F(RearrangeFunctionArgumentForFunctionTest, Basic) {
FunctionDefLibrary fdl;
{
// Function for StatefulPartitionedCall's "f", If's
// "then_branch"/"else_branch".
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL)
// "ret0" = "arg1"
// "ret1" = "arg0"
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg1, 0);
auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg0, 1);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "f1", xla_fdef));
}
{
// Function for While's "body".
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL)
// "ret0" = "arg0"
// "ret1" = "arg1"
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg0, 0);
auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg1, 0);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "f2", xla_fdef));
}
{
// Function for While's "cond".
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL)
// "ret0" = "arg1"
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg1, 0);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "f3", xla_fdef));
}
{
// Build the XLA computation func.
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_INT32)
// "arg0", "arg1" -> "call" (StatefulPartitionedCall) -> "ret0", "ret1"
// "arg0", "arg1" -> "if" (If) -> "ret2", "ret3"
// "arg0", "arg1" -> "while" (While) -> "ret4", "ret5"
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
NameAttrList f;
f.set_name("f1");
auto call = ops::StatefulPartitionedCall(
s.WithOpName("call"), {arg0, arg1},
std::vector<DataType>{DT_BOOL, DT_RESOURCE}, f);
auto ret0 = ops::_Retval(s.WithOpName("ret0"), call.output[0], 0);
auto ret1 = ops::_Retval(s.WithOpName("ret1"), call.output[1], 1);
auto if_op = ops::If(s.WithOpName("if"), arg1,
std::initializer_list<Input>{arg0, arg1},
{DT_BOOL, DT_RESOURCE}, f, f);
auto ret2 = ops::_Retval(s.WithOpName("ret2"), if_op.output[0], 2);
auto ret3 = ops::_Retval(s.WithOpName("ret3"), if_op.output[1], 3);
NameAttrList cond_fn, body_fn;
cond_fn.set_name("f3");
body_fn.set_name("f2");
auto while_op =
ops::While(s.WithOpName("while"),
std::initializer_list<Input>{arg0, arg1}, cond_fn, body_fn);
auto ret4 = ops::_Retval(s.WithOpName("ret4"), while_op.output[0], 4);
auto ret5 = ops::_Retval(s.WithOpName("ret5"), while_op.output[1], 5);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
}
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
bool modified;
protobuf::Map<string, tensorflow::AttrValue> attrs;
TF_CHECK_OK(RearrangeFunctionArgumentTest("cluster", "cluster_rewritten",
attrs, &fld, &modified));
// Check function f1_rearrange_0, input types should be {DT_BOOL, DT_RESOURCE}
// and output types should be {DT_BOOL}.
const FunctionDef *f1_rewritten = fld.Find("f1_rearrange_0");
CHECK_NE(f1_rewritten, nullptr);
ASSERT_EQ(f1_rewritten->signature().input_arg_size(), 2);
EXPECT_EQ(f1_rewritten->signature().input_arg(0).type(), DT_BOOL);
EXPECT_EQ(f1_rewritten->signature().input_arg(1).type(), DT_RESOURCE);
ASSERT_EQ(f1_rewritten->signature().output_arg_size(), 1);
EXPECT_EQ(f1_rewritten->signature().output_arg(0).type(), DT_BOOL);
// Check node "call" input and output edges.
std::unique_ptr<FunctionBody> xla_fbody;
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
AttrSlice(), &fld, &xla_fbody));
auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();
const Node *call_node = node_name_index.at("call");
ASSERT_NE(call_node, nullptr);
const Node *input_node;
TF_CHECK_OK(call_node->input_node(0, &input_node));
EXPECT_EQ(input_node->name(), "arg1");
TF_CHECK_OK(call_node->input_node(1, &input_node));
EXPECT_EQ(input_node->name(), "arg0");
const Node *ret0_node = xla_fbody->ret_nodes[0];
TF_CHECK_OK(ret0_node->input_node(0, &input_node));
EXPECT_EQ(input_node->name(), "call");
const Node *ret1_node = xla_fbody->ret_nodes[1];
TF_CHECK_OK(ret1_node->input_node(0, &input_node));
EXPECT_EQ(input_node->name(), "arg0");
// Check node "if" input and output edges.
const Node *if_node = node_name_index.at("if");
ASSERT_NE(if_node, nullptr);
TF_CHECK_OK(if_node->input_node(1, &input_node));
EXPECT_EQ(input_node->name(), "arg1");
TF_CHECK_OK(if_node->input_node(2, &input_node));
EXPECT_EQ(input_node->name(), "arg0");
const Node *ret2_node = xla_fbody->ret_nodes[2];
TF_CHECK_OK(ret2_node->input_node(0, &input_node));
EXPECT_EQ(input_node->name(), "if");
const Node *ret3_node = xla_fbody->ret_nodes[3];
TF_CHECK_OK(ret3_node->input_node(0, &input_node));
EXPECT_EQ(input_node->name(), "arg0");
// Check node "while" input and output edges.
const Node *while_node = node_name_index.at("while");
ASSERT_NE(while_node, nullptr);
TF_CHECK_OK(while_node->input_node(0, &input_node));
EXPECT_EQ(input_node->name(), "arg1");
TF_CHECK_OK(while_node->input_node(1, &input_node));
EXPECT_EQ(input_node->name(), "arg0");
const Node *ret4_node = xla_fbody->ret_nodes[4];
TF_CHECK_OK(ret4_node->input_node(0, &input_node));
EXPECT_EQ(input_node->name(), "arg0");
const Node *ret5_node = xla_fbody->ret_nodes[5];
TF_CHECK_OK(ret5_node->input_node(0, &input_node));
EXPECT_EQ(input_node->name(), "while");
}
} // namespace tensorflow

View File

@ -84,6 +84,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/algorithm.h"
@ -93,22 +94,6 @@ limitations under the License.
namespace tensorflow {
namespace {
// Returns true if `n` may call a function.
Status MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def,
bool* out_result) {
if (flib_def->Contains(n.type_string())) {
*out_result = true;
} else {
*out_result =
std::any_of(n.def().attr().begin(), n.def().attr().end(),
[](const std::pair<string, AttrValue>& name_attr_pair) {
return name_attr_pair.second.has_func();
});
}
return Status::OK();
}
// Maps `n` to the XlaResourceOpKind corresponding to its operation. If `n` is
// not a resource operation recognized by XLA then sets `out_resource_op_kind`
// to nullopt.
@ -134,9 +119,7 @@ Status XlaResourceOpKindForNode(
// We conservatively assume that functions will both read and write resource
// variables. In the future we may consider doing some form of
// inter-procedural analysis.
bool may_call_function;
TF_RETURN_IF_ERROR(MayCallFunction(n, flib_def, &may_call_function));
if (may_call_function) {
if (MayCallFunction(n, flib_def)) {
*out_resource_op_kind = XlaResourceOpKind::kReadWrite;
} else {
*out_resource_op_kind = absl::nullopt;

View File

@ -25,7 +25,6 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -227,28 +226,6 @@ void RemoveFromXlaCluster(NodeDef* node_def) {
void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }
Status AdjustCycleDetectionGraphForResourceOps(
const Graph* graph, const FunctionLibraryDefinition* flib_def,
const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
GraphCycles* cycles) {
std::vector<std::pair<int, int>> unsafe_deps;
TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs(
*graph, flib_def, resource_ops_to_ignore, &unsafe_deps));
// An edge {P,Q} in `unsafe_deps` denotes that P and Q, both of which are
// operations that interact with resource variables, must not be put in the
// same cluster. We enforce this constraint by creating a phantom node, X,
// and adding edges P->X and X->Q. MarkForCompilation then cannot cluster P
// and Q together since that would create a cycle with X.
for (std::pair<int, int> unsafe_dep : unsafe_deps) {
int phantom_node_id = cycles->NewNode();
CHECK(cycles->InsertEdge(unsafe_dep.first, phantom_node_id));
CHECK(cycles->InsertEdge(phantom_node_id, unsafe_dep.second));
}
return Status::OK();
}
Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
bool allow_mixing_unknown_and_cpu,
bool* out_can_pick_device,
@ -436,4 +413,16 @@ OptimizerOptions::GlobalJitLevel GetGlobalJitLevelForGraph(
return result;
}
bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def) {
if (flib_def->Contains(n.type_string())) {
return true;
}
// This is a conservative check: there may be nodes with a `func`
// attribute that do not make function calls.
return absl::c_any_of(n.def().attr(),
[](const std::pair<string, AttrValue>& name_attr_pair) {
return name_attr_pair.second.has_func();
});
}
} // namespace tensorflow

View File

@ -74,13 +74,6 @@ void RemoveFromXlaCluster(Node* node);
// Returns true if `node` has a DT_RESOURCE typed input or output.
bool HasResourceInputOrOutput(const Node& node);
// Adds edges to `cycles` to prevent clustering resource operations that cannot
// be legally clustered.
Status AdjustCycleDetectionGraphForResourceOps(
const Graph* graph, const FunctionLibraryDefinition* flib_def,
const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
GraphCycles* cycles);
// Picks the device for which XLA should compile a cluster that contains
// operations placed in devices in `device_names`. For instance a cluster that
// contains operations solely placed on the CPU will be compiled into a CPU
@ -134,6 +127,10 @@ OptimizerOptions::GlobalJitLevel GetGlobalJitLevelForGraph(
// Returns true if `g` is a single-GPU graph. A single-GPU graph uses exactly
// one GPU (and any number of CPUs).
bool IsSingleGpuGraph(const Graph& g);
// Returns true if it is possible (but not guaranteed) that `n` calls a
// function.
bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_

View File

@ -30,10 +30,17 @@ namespace tensorflow {
class XlaCpuDeviceFactory : public DeviceFactory {
public:
Status ListPhysicalDevices(std::vector<string>* devices) override;
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) override;
};
Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
return Status::OK();
}
Status XlaCpuDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {

View File

@ -1,352 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_fusion_optimizer.h"
#include <atomic>
#include <deque>
#include <unordered_map>
#include <unordered_set>
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
namespace tensorflow {
// Is 'node' an operator that consumes only the shape of its input, not the
// data itself?
static bool IsShapeConsumerOp(const Node& node) {
return node.type_string() == "Shape" || node.type_string() == "ShapeN" ||
node.type_string() == "Rank" || node.type_string() == "Size";
}
// Returns true if the op can be decomposed into XLA ops for which
// there are fusible elemental implementations.
static bool IsXlaFusible(const NodeDef& node) {
static const std::unordered_set<std::string>* elementwise_ops =
new std::unordered_set<std::string>(
{// tf2xla/kernels/aggregate_ops.cc
"AddN",
// tf2xla/kernels/binary_ops.cc
"Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv",
"FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift",
"LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
"ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference",
"TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater",
"GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad",
"SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual",
// tf2xla/kernels/unary_ops.cc
"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
"Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp",
"Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal",
"Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round",
"Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
"Square", "Tan", "Tanh", "Real", "Imag",
// tf2xla/kernels/bcast_ops.cc
"BroadcastArgs", "BroadcastGradientArgs",
// tf2xla/kernels/bias_ops.cc
"BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/,
// tf2xla/kernels/cast_op.cc
"Cast",
// tf2xla/kernels/concat_op.cc
"Concat", "ConcatV2", "ConcatOffset",
// tf2xla/kernels/const_op.cc
"Const",
// tf2xla/kernels/elu_op.cc
"Elu", "EluGrad", "Selu", "SeluGrad",
// tf2xla/kernels/fill_op.cc
"Fill",
// tf2xla/kernels/identity_op.cc
"Identity", "IdentityN", "PreventGradient",
"StopGradient", /*"Snapshot",*/
// tf2xla/kernels/index_ops.cc
"ArgMax", "ArgMin",
// tf2xla/kernels/mirror_pad_op.cc
"MirrorPad",
// tf2xla/kernels/one_hot_op.cc
"OneHot",
// tf2xla/kernels/pack_op.cc
"Pack",
// tf2xla/kernels/pad_op.cc
"Pad", "PadV2",
// tf2xla/kernels/relu_op.cc
"Relu", "Relu6", "ReluGrad", "Relu6Grad",
// tf2xla/kernels/reshape_op.cc
"Reshape",
// tf2xla/kernels/reverse_op.cc
"Reverse", "ReverseV2",
// tf2xla/kernels/reverse_sequence_op.cc
"ReverseSequence",
// tf2xla/kernels/shape_op.cc
"Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze",
"ZerosLike", "OnesLike",
// tf2xla/kernels/slice_op.cc
"Slice",
// tf2xla/kernels/split_op.cc
"Split", "SplitV",
// tf2xla/kernels/strided_slice_op.cc
"StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign",
// tf2xla/kernels/tile_ops.cc
"Tile",
// tf2xla/kernels/transpose_op.cc
"Transpose", "InvertPermutation",
// tf2xla/kernels/unpack_op.cc
"Unpack"});
return elementwise_ops->count(node.op()) > 0;
}
Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
const grappler::GrapplerItem& item,
GraphDef* output) {
VLOG(2) << "Here at fusion optimizer";
// TODO(hpucha): Implement encapsulation and replacing with XlaLaunch op.
// Once that happens, the expected interaction between this optimizer and when
// the global_jit_level is set is as follows: Fusion optimizer will replace
// appropriate fusion clusters with XlaLaunch nodes. The remaining graph can
// be further compiled where possible via mark_for_compilation_pass. Note that
// this might lead to inefficient clustering, and it is best to use either the
// fusion optimizer or the global_jit flag, and not combine the two.
// Create a Graph out of GraphDef. This is required currently because the
// helpers around clustering, encapsulation etc work on graphs.
FunctionLibraryDefinition function_library(OpRegistry::Global(),
item.graph.library());
Graph graph(function_library);
ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
shape_refiner.set_require_shape_inference_fns(false);
shape_refiner.set_disable_constant_propagation(true);
ImportGraphDefOptions options;
// Graph optimization happens at the late stage of graph execution, when
// colocation constraints are already validated previously and the device
// placement of nodes has also completed, so there is no need to validate
// colocation constraints again.
options.validate_colocation_constraints = false;
options.validate_shape = false;
TF_RETURN_IF_ERROR(
ImportGraphDef(options, item.graph, &graph, &shape_refiner));
// Collect nodes that can be fused via XLA, while ignoring those that
// explicitly ask for XLA: (*) nodes that are marked to be compiled
// explicitly. (*) nodes assigned to XLA device.
OrderedNodeSet compilation_candidates;
for (Node* node : graph.op_nodes()) {
// If there is a _XlaCompile annotation, ignore the node if it is
// true. Nodes are marked with this attr via experimental_jit_scope, and
// will be handled by the mark_for_compilation pass.
bool compile = false;
Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
if (status.ok() && compile) {
continue;
}
// If there is already a _XlaCluster annotation, ignore the node. Nodes are
// marked with this attr to indicate they are already part of a cluster and
// hence ignored.
status = GetNodeAttr(node->attrs(), kXlaClusterAttr, &compile);
if (status.ok()) {
continue;
}
// If there is an explicit XLA device placement, ignore the node.
DeviceType device_type("");
TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type));
if (device_type.type_string().find("XLA") != string::npos) continue;
// Assume all fusible ops are registered.
// TODO(hpucha): Check for registration if possible.
if (!IsXlaFusible(node->def())) {
continue;
}
// XLA does not offer guaranteed aliasing between the input and output of
// the XLA cluster so it can't implement the forward-tensor-ref semantic.
// Leave such nodes out of XLA clusters.
if (HasForwardedRefInput(*node)) {
continue;
}
compilation_candidates.insert(node);
}
if (compilation_candidates.empty()) {
VLOG(2) << "No compilable candidates";
*output = item.graph;
return Status::OK();
}
GraphCycles cycles;
TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
CreateCycleDetectionGraph(&graph, &cycles));
if (!cycle_detection_graph_ok) {
return Status::OK();
}
TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
&graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles));
// TODO(hpucha): Make clustering more robust. There are two known issues that
// we need to mitigate: (a) Non-resource variables can cause deadlocks
// when clustering changes order of execution. See b/77263461 for a specific
// example. (b) Queue operations can also cause deadlocks. See b/77261498 for
// example.
struct Cluster {
// Identifies the node that represents this cluster in the cycle detection
// graph.
int representative = -1;
};
// Each compilation candidate belongs to a cluster. The cluster's
// representative names the node in the 'cycles' graph that represents the
// cluster.
std::vector<UnionFind<Cluster>> clusters(graph.num_node_ids());
std::deque<UnionFind<Cluster>*> worklist;
for (Node* node : compilation_candidates) {
Cluster& cluster = clusters[node->id()].Get();
cluster.representative = node->id();
worklist.push_back(&clusters[node->id()]);
}
std::unique_ptr<DeadnessAnalysis> deadness_analysis;
TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness_analysis));
// Repeatedly contract edges between clusters that are on the same device,
// provided the contraction would not create a cycle. This is a simplified
// version of the clustering in mark_for_compilation_pass that also deals with
// nodes that are explicitly tagged to be compiled/clustered.
while (!worklist.empty()) {
int from = worklist.front()->Get().representative;
worklist.pop_front();
Node* node_from = graph.FindNodeId(from);
if (node_from->IsControlFlow()) {
// Control flow nodes aren't compilation candidates and should never
// appear.
return errors::Internal(
"Found control flow node in clustering worklist: ",
node_from->type_string());
}
for (int to : cycles.Successors(from)) {
if (to >= graph.num_node_ids()) {
// Node is a "frame" node that is present only in the cycle detection
// graph. No clustering is possible.
continue;
}
Node* node_to = graph.FindNodeId(to);
if (compilation_candidates.find(node_to) ==
compilation_candidates.cend()) {
continue;
}
// Do not cluster across devices.
if (node_from->def().device() != node_to->def().device()) {
VLOG(2) << "Devices " << node_from->def().device() << " "
<< node_to->def().device();
VLOG(2) << "Device names " << node_from->assigned_device_name() << " "
<< node_to->assigned_device_name();
continue;
}
// Ops that consume shapes cannot be the root of a cluster. This is an
// optimization.
if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) {
continue;
}
TF_ASSIGN_OR_RETURN(
DeadnessAnalysis::DeadnessPredicate pred_from,
deadness_analysis->GetPredicateFor(node_from, Graph::kControlSlot));
TF_ASSIGN_OR_RETURN(
DeadnessAnalysis::DeadnessPredicate pred_to,
deadness_analysis->GetPredicateFor(node_to, Graph::kControlSlot));
if (pred_from != pred_to) {
continue;
}
// If contracting the edge would create a cycle, bail out.
// However, just because we can't merge the clusters now does not mean
// we won't be able to merge them in the future.
// e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge
// 1->3. But if we first contract 1->2 then we can later contract 1->3.
if (!cycles.ContractEdge(from, to)) continue;
// Merge the clusters. ContractEdge uses 'from' as the number of the
// merged node, so make sure 'from' is the chosen representative.
clusters[from].Merge(&clusters[to]);
worklist.push_back(&clusters[from]);
break;
}
}
// Count the number of non-trivial elements in each cluster.
std::vector<int> effective_cluster_sizes(graph.num_node_ids());
for (const Node* n : compilation_candidates) {
int cluster = clusters[n->id()].Get().representative;
// Identity nodes will be removed if the node gets marked for compilation.
// Therefore we don't want to count them towards the effective cluster size.
if (n->def().op() != "Identity") {
effective_cluster_sizes[cluster]++;
}
}
const int min_cluster_size = 2;
int num_clusters = 0;
for (auto size : effective_cluster_sizes) {
if (size >= min_cluster_size) {
VLOG(3) << "Cluster " << num_clusters << " " << size;
num_clusters++;
}
}
// Names for each cluster.
std::unordered_map<int, string> cluster_names;
// Sequence number generator to ensure clusters have unique names.
static std::atomic<int64> cluster_sequence_num;
for (Node* n : compilation_candidates) {
int cluster = clusters[n->id()].Get().representative;
// Compile if this is a cluster of >= min_cluster_size compilable operators.
if (effective_cluster_sizes[cluster] >= min_cluster_size) {
string& name = cluster_names[cluster];
if (name.empty()) {
name = absl::StrCat("cluster_", cluster_sequence_num++);
}
n->AddAttr(kXlaClusterAttr, name);
VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
}
}
graph.ToGraphDef(output);
return Status::OK();
}
REGISTER_GRAPH_OPTIMIZER_AS(XlaFusionOptimizer, "xla-fusion");
} // namespace tensorflow

View File

@ -1,49 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
#define TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
namespace tensorflow {
// Optimizes graphs by fusing ops where possible, resulting in more efficient
// execution.
class XlaFusionOptimizer : public grappler::CustomGraphOptimizer {
public:
XlaFusionOptimizer() {}
~XlaFusionOptimizer() override {}
Status Init(
const RewriterConfig_CustomGraphOptimizer* config = nullptr) override {
return Status::OK();
}
string name() const override { return "xla-fusion"; };
Status Optimize(grappler::Cluster* cluster,
const grappler::GrapplerItem& item,
GraphDef* output) override;
void Feedback(grappler::Cluster* cluster, const grappler::GrapplerItem& item,
const GraphDef& optimize_output, double result) override {
// Nothing to do for XlaFusionOptimizer.
}
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_

View File

@ -1,208 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_fusion_optimizer.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");
class XlaFusionOptimizerTest : public grappler::GrapplerTest {
protected:
std::unordered_map<string, string> GetClusters(const GraphDef& graph) {
std::unordered_map<string, string> ids;
for (const NodeDef& node : graph.node()) {
string cluster;
if (GetNodeAttr(AttrSlice(node), kXlaClusterAttr, &cluster).ok()) {
CHECK(!cluster.empty());
ids[node.name()] = cluster;
}
}
return ids;
}
};
TEST_F(XlaFusionOptimizerTest, Chains) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a =
ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
Node* d =
ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_EQ(4, clusters.size());
EXPECT_EQ(clusters["B"], clusters["C"]);
EXPECT_EQ(clusters["E"], clusters["F"]);
EXPECT_NE(clusters["B"], clusters["E"]);
EXPECT_TRUE(clusters.find("A") == clusters.cend());
EXPECT_TRUE(clusters.find("D") == clusters.cend());
}
TEST_F(XlaFusionOptimizerTest, FusibleOps) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp(
"Placeholder",
builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT));
Node* b = ops::SourceOp(
"Placeholder",
builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT));
Node* c = ops::BinaryOp("Add", a, b, builder.opts().WithName("C"));
ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D"));
ops::UnaryOp("Abs", c, builder.opts().WithName("E"));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_EQ(2, clusters.size());
EXPECT_EQ(clusters["C"], clusters["E"]);
EXPECT_TRUE(clusters.find("D") == clusters.cend());
}
TEST_F(XlaFusionOptimizerTest, IgnoreExplicitXLAAttrs) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp(
"Placeholder",
builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT));
Node* b = ops::SourceOp(
"Placeholder",
builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT));
Node* c = ops::BinaryOp(
"Add", a, b,
builder.opts().WithName("C").WithDevice("/device:XLA_CPU"));
ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D"));
Node* e = ops::UnaryOp("Abs", c, builder.opts().WithName("E"));
ops::UnaryOp("Cos", e,
builder.opts().WithName("F").WithAttr(kXlaCompileAttr, true));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_TRUE(clusters.empty());
}
TEST_F(XlaFusionOptimizerTest, UncompilableCycles) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_FLOAT)
.WithAttr("value", Tensor()));
Node* b =
ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
ops::BinaryOp("Mul", a, b, builder.opts().WithName("C"));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_TRUE(clusters.empty());
}
TEST_F(XlaFusionOptimizerTest, CompilableCycles) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_FLOAT)
.WithAttr("value", Tensor()));
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
ops::BinaryOp("Mul", a, b, builder.opts().WithName("C"));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_EQ(3, clusters.size());
EXPECT_EQ(clusters["A"], clusters["B"]);
EXPECT_EQ(clusters["A"], clusters["C"]);
}
TEST_F(XlaFusionOptimizerTest, ResourcesClusteringDisallowed) {
Scope root = Scope::NewRootScope().ExitOnError();
Output var_handle =
ops::VarHandleOp(root.WithOpName("Var"), DT_FLOAT, TensorShape({}));
Output to_assign = ops::Const(root.WithOpName("Const"), 10.0f);
Output begin = ops::Const(root.WithOpName("begin"), 0);
Output end = ops::Const(root.WithOpName("end"), 1);
Output strides = ops::Const(root.WithOpName("strides"), 1);
ops::ResourceStridedSliceAssign assign_1(
root.WithOpName("assign_1"), var_handle, begin, end, strides, to_assign);
ops::ResourceStridedSliceAssign assign_2(
root.WithOpName("assign_2"), var_handle, begin, end, strides, to_assign);
root.graph()->AddControlEdge(assign_1.operation.node(),
assign_2.operation.node());
grappler::GrapplerItem item;
root.graph()->ToGraphDef(&item.graph);
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_NE(clusters["assign_1"], clusters["assign_2"]);
}
} // namespace
} // namespace tensorflow

View File

@ -55,10 +55,32 @@ static xla::StatusOr<absl::optional<std::set<int>>> ParseVisibleDeviceList(
class XlaGpuDeviceFactory : public DeviceFactory {
public:
Status ListPhysicalDevices(std::vector<string>* devices) override;
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) override;
};
Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
if (!platform.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
return Status::OK();
}
int device_count = platform.ValueOrDie()->VisibleDeviceCount();
if (device_count <= 0) {
return Status::OK();
}
for (int i = 0; i < device_count; ++i) {
devices->push_back(
absl::StrCat("/physical_device:", DEVICE_XLA_GPU, ":", i));
}
return Status::OK();
}
Status XlaGpuDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {

View File

@ -32,10 +32,19 @@ constexpr std::array<DataType, 10> kExecAllTypes = {
class XlaInterpreterDeviceFactory : public DeviceFactory {
public:
Status ListPhysicalDevices(std::vector<string>* devices) override;
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) override;
};
Status XlaInterpreterDeviceFactory::ListPhysicalDevices(
std::vector<string>* devices) {
devices->push_back(
absl::StrCat("/physical_device:", DEVICE_XLA_INTERPRETER, ":0"));
return Status::OK();
}
Status XlaInterpreterDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {

View File

@ -23,6 +23,7 @@ import itertools
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.compat import compat
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
@ -1041,6 +1042,62 @@ class BinaryOpsTest(xla_test.XLATestCase):
np.array([2], dtype=np.int64),
expected=np.array([[[[1, 2]], [[3, 4]]]], dtype=dtype))
def testBatchMatMulBroadcast(self):
"""Tests broadcasting behavior of BatchMatMul."""
with compat.forward_compatibility_horizon(2019, 4, 26):
# [2, 3] @ [1, 3, 4] -> [1, 2, 4]
self._testBinary(
math_ops.matmul,
np.array([[10, 20, 30], [11, 21, 31]], dtype=np.float32),
np.array([[[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12]]],
dtype=np.float32),
expected=np.array([[[140, 280, 420, 560], [146, 292, 438, 584]]],
dtype=np.float32))
# [1, 2, 3] @ [3, 4] -> [1, 2, 4]
self._testBinary(
math_ops.matmul,
np.array([[[10, 20, 30], [11, 21, 31]]], dtype=np.float32),
np.array([[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12]],
dtype=np.float32),
expected=np.array([[[140, 280, 420, 560], [146, 292, 438, 584]]],
dtype=np.float32))
# [2, 1, 3] @ [3, 1] -> [2, 1, 1]
self._testBinary(
math_ops.matmul,
np.array([[[10, 20, 30]], [[11, 21, 31]]], dtype=np.float32),
np.array([[1], [2], [3]], dtype=np.float32),
expected=np.array([[[140]], [[146]]], dtype=np.float32))
# [2, 1, 3] @ [1, 3] -> [2, 1, 1] (adjoint_b)
self._testBinary(
lambda x, y: math_ops.matmul(x, y, adjoint_b=True),
np.array([[[10, 20, 30]], [[11, 21, 31]]], dtype=np.float32),
np.array([[1, 2, 3]], dtype=np.float32),
expected=np.array([[[140]], [[146]]], dtype=np.float32))
# [2, 3, 1] @ [3, 1] -> [2, 1, 1] (adjoint_a)
self._testBinary(
lambda x, y: math_ops.matmul(x, y, adjoint_a=True),
np.array([[[10], [20], [30]], [[11], [21], [31]]], dtype=np.float32),
np.array([[1], [2], [3]], dtype=np.float32),
expected=np.array([[[140]], [[146]]], dtype=np.float32))
# [2, 3, 1] @ [1, 3] -> [2, 1, 1] (adjoint_a and adjoint_b)
self._testBinary(
lambda x, y: math_ops.matmul(x, y, adjoint_a=True, adjoint_b=True),
np.array([[[10], [20], [30]], [[11], [21], [31]]], dtype=np.float32),
np.array([[1, 2, 3]], dtype=np.float32),
expected=np.array([[[140]], [[146]]], dtype=np.float32))
# [5, 1, 2, 3] @ [1, 7, 3, 4] -> [5, 7, 2, 4]
self._testBinary(
math_ops.matmul,
np.ones([5, 1, 2, 3], dtype=np.float32),
np.ones([1, 7, 3, 4], dtype=np.float32),
expected=np.full([5, 7, 2, 4], 3, dtype=np.float32))
# [4, 5, 1, 2, 3] @ [1, 1, 3, 5] -> [4, 5, 1, 2, 5]
self._testBinary(
math_ops.matmul,
np.full([4, 5, 1, 2, 3], 2., dtype=np.float32),
np.full([1, 1, 3, 5], 3., dtype=np.float32),
expected=np.full([4, 5, 1, 2, 5], 18., dtype=np.float32))
def testPad(self):
for dtype, pad_type in itertools.product(
self.numeric_types, [np.int32, np.int64]):

View File

@ -113,12 +113,6 @@ class DenseLayerTest(test.TestCase):
def testDenseLayerJitScopeUndefinedShape(self):
"""Tests that the dense layer node is properly compiled in jit scope.
Dense layer uses shape op to get shape of input tensor if its shape is not
fully defined. XLA does not cluster shape op with other operators. But in
experimental_jit_scope, XLA is forced to compile shape op into its own
cluster, causing dense layer to be split into TWO XlaCompile/XlaRun op
pairs.
"""
with self.cached_session() as sess:
@ -136,7 +130,7 @@ class DenseLayerTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = GetRunMetadataLabels(run_metadata)
self.assertEqual(2, self.countXlaOps(labels))
self.assertEqual(1, self.countXlaOps(labels))
self.assertFalse(InLabels(labels, "MatMult"))

View File

@ -454,7 +454,7 @@ class PoolGradTest(xla_test.XLATestCase):
"""Verifies the output values of the pooling function.
Args:
pool_func: Pooling function to be called, e.g., tf.nn.max_pool
pool_func: Pooling function to be called, e.g., tf.nn.max_pool2d
pool_grad_func: Corresponding pooling gradient function.
input_sizes: Input tensor dimensions.
ksize: The kernel size dimensions

View File

@ -72,6 +72,30 @@ class RandomOpsTest(xla_test.XLATestCase):
for dtype in self._random_types() & self.float_types:
self._testRngIsNotConstant(rng, dtype)
def testRandomNormalMean(self):
for dtype in self._random_types() & self.float_types:
with self.cached_session():
with self.test_scope():
normal = random_ops.random_normal([1024],
dtype=dtype,
mean=1.4,
stddev=1.2)
mean = math_ops.reduce_mean(normal)
x = self.evaluate(mean)
self.assertAllClose(x, 1.4, rtol=1e-1, atol=1e-1)
def testRandomNormalVariance(self):
for dtype in self._random_types() & self.float_types:
with self.cached_session():
with self.test_scope():
normal = random_ops.random_normal([1024],
dtype=dtype,
mean=2.3,
stddev=2.0)
variance = math_ops.reduce_variance(normal)
x = self.evaluate(variance)
self.assertAllClose(x, 4.0, rtol=1e-1, atol=1e-1)
def testRandomUniformIsInRange(self):
for dtype in self._random_types():
# TODO (b/112272078): enable bfloat16 for CPU and GPU when the bug is

View File

@ -387,11 +387,18 @@ class TensorArrayTest(xla_test.XLATestCase):
def fn():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
return ta.write(-1, np.int32(7)).flow
return ta.write(-1, constant_op.constant(7)).flow
# Test writing the wrong datatype.
with self.assertRaisesOpError(
"TensorArray dtype is float but op has dtype int32"):
# TODO(b/129870929): Remove InvalidArgumentError/second regexp after all
# callers provide proper init dtype.
with self.assertRaisesRegexp(
(ValueError, errors.InvalidArgumentError),
r"("
r"conversion requested dtype float32 for Tensor with dtype int32"
r"|"
r"TensorArray dtype is float but op has dtype int32"
r")"):
xla.compile(fn)[0].eval()
@test_util.disable_control_flow_v2("b/124334096 verify dtype")

View File

@ -125,7 +125,7 @@ class ListOpsTest(xla_test.XLATestCase):
self.assertAllEqual(e0, 2.0)
l, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
self.assertAllEqual(e1, 1.0)
self.assertAllEqual(list_ops.tensor_list_length(l), 0)
self.assertAllEqual(list_ops.tensor_list_length(l), 2)
def testGetSet(self):
with self.cached_session(), self.test_scope():
@ -211,6 +211,18 @@ class ListOpsTest(xla_test.XLATestCase):
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
self.assertAllEqual(t, [0., 0., 0.])
def testZerosLikeForTensorList(self):
with self.cached_session(), self.test_scope():
l = list_ops.empty_tensor_list(
element_dtype=dtypes.float32,
element_shape=[],
max_num_elements=2)
l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
z = array_ops.zeros_like(l)
z = list_ops.tensor_list_stack(z, element_dtype=dtypes.float32)
self.assertAllEqual(z.shape.as_list(), [None])
self.assertAllEqual(z, [0.0, 0.0])
if __name__ == "__main__":
os.environ['TF_XLA_FLAGS'] = ('--tf_xla_min_cluster_size=2 ' +
os.environ.get('TF_XLA_FLAGS', ''))

View File

@ -23,9 +23,13 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
load(
"@local_config_tensorrt//:build_defs.bzl",
"if_tensorrt",
"//tensorflow/core:platform/default/build_config.bzl",
"tf_additional_all_protos",
"tf_proto_library",
)
load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
# Google-internal targets go here (must be at the end).
tf_cuda_cc_test(
name = "tensorrt_test_cc",
@ -74,13 +78,67 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "trt_engine_resource_op_kernels",
srcs = ["kernels/trt_engine_resource_ops.cc"],
copts = tf_copts(),
visibility = ["//visibility:private"],
deps = [
":trt_allocator",
":trt_engine_instance_proto_cc",
":trt_logging",
":trt_plugins",
":trt_resources",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:lib_proto_parsing",
] + if_tensorrt([
"@local_config_tensorrt//:tensorrt",
]) + tf_custom_op_library_additional_deps(),
alwayslink = 1,
)
tf_cuda_cc_test(
name = "trt_engine_resource_ops_test",
size = "small",
srcs = ["kernels/trt_engine_resource_ops_test.cc"],
tags = [
"no_cuda_on_cpu_tap",
"no_windows",
"nomac",
],
deps = [
":trt_engine_instance_proto_cc",
":trt_engine_resource_op_kernels",
":trt_engine_resource_ops_op_lib",
":trt_logging",
":trt_resources",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:ops_testutil",
"//tensorflow/core/kernels:resource_variable_ops",
"@com_google_absl//absl/memory",
],
)
tf_cc_shared_object(
name = "python/ops/libtftrt.so",
copts = tf_copts(is_external = True),
linkopts = ["-lm"],
deps = [
":trt_op_kernels",
":trt_engine_resource_op_kernels",
":trt_op_libs",
":trt_engine_resource_ops_op_lib",
"//tensorflow/core:lib_proto_parsing",
] + if_tensorrt([
"@local_config_tensorrt//:tensorrt",
@ -112,10 +170,40 @@ tf_cuda_cc_test(
],
)
tf_cuda_cc_test(
name = "trt_engine_op_test",
size = "small",
srcs = ["kernels/trt_engine_op_test.cc"],
tags = [
"no_cuda_on_cpu_tap",
"no_windows",
"nomac",
],
deps = [
":trt_op_kernels",
":trt_op_libs",
":trt_resources",
"@com_google_googletest//:gtest",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:ops_testutil",
] + if_tensorrt([
"@local_config_cuda//cuda:cuda_headers",
]),
)
tf_gen_op_libs(
op_lib_names = [
"trt_engine_op",
"get_serialized_resource_op",
"trt_engine_resource_ops",
],
)
@ -142,6 +230,7 @@ tf_cuda_library(
tf_gen_op_wrapper_py(
name = "trt_ops",
deps = [
":trt_engine_resource_ops_op_lib",
":trt_op_libs",
],
)
@ -156,7 +245,9 @@ tf_custom_op_py_library(
]),
kernels = [
":trt_op_kernels",
":trt_engine_resource_op_kernels",
":trt_op_libs",
":trt_engine_resource_ops_op_lib",
],
srcs_version = "PY2AND3",
deps = [
@ -173,6 +264,7 @@ tf_cuda_library(
name = "trt_resources",
srcs = [
"utils/trt_int8_calibrator.cc",
"utils/trt_lru_cache.cc",
"utils/trt_resources.cc",
],
hdrs = [
@ -440,6 +532,13 @@ cc_library(
],
)
tf_proto_library(
name = "trt_engine_instance_proto",
srcs = ["utils/trt_engine_instance.proto"],
cc_api_version = 2,
protodeps = tf_additional_all_protos(),
)
cc_library(
name = "py_utils",
srcs = ["utils/py_utils.cc"],

View File

@ -1351,11 +1351,19 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input,
// the dims are unknown or need to be inferred. And we don't do further checks
// but rely on the caller to not make mistakes.
// Otherwise we do simple check to make sure the total sizes are the same.
if (AreDimsStaticWithDifferentSize(input_dims, dims, input.is_tensor())) {
// If an input is a weight, it is going to become a tensor via
// CreateConstantLayer. So we can treat it as a tensor for
// AreDimsStaticWithDifferentSize(). This really only matters for 0-D tensors.
if (AreDimsStaticWithDifferentSize(input_dims, dims, /*is_tensor=*/true)) {
return errors::InvalidArgument(
"Incompatible shapes: ", DebugString(input_dims), " vs. ",
DebugString(dims));
}
// ConstantLayer requires static shapes (cannot infer -1).
if (input.is_weights() && !HasStaticShape(dims)) {
return errors::InvalidArgument("Shape is not fully defined: ",
DebugString(dims));
}
if (validation_only) {
*tensor = nullptr;
return Status::OK();
@ -1590,18 +1598,6 @@ Status AllowDataTypes(const OpConverterParams& params,
return Status::OK();
}
TRT_ShapedWeights ConvertFP32ToFP16(TrtWeightStore* store,
const TRT_ShapedWeights& weights_src) {
TRT_ShapedWeights weights =
store->GetTempWeights(nvinfer1::DataType::kHALF, weights_src.shape_);
const float* src = static_cast<const float*>(weights_src.GetValues());
Eigen::half* dst = static_cast<Eigen::half*>(weights.GetValues());
for (int64_t i = 0; i < weights_src.count(); i++) {
dst[i] = Eigen::half_impl::float_to_half_rtne(src[i]);
}
return weights;
}
// ****************************************************************************
// Constant folding functions for weights.
// TODO(laigd): we should probably use eigen directly.
@ -1773,10 +1769,6 @@ Status BinaryTensorOpWeight(OpConverterParams* params,
params->converter->TransposeTensor(tensor, permutation, &tensor));
}
if (params->converter->precision_mode() == TrtPrecisionMode::FP16) {
weights = ConvertFP32ToFP16(params->weight_store, weights);
}
// Prepare weights
TRT_ShapedWeights shift_weights(weights.TrtDType());
TRT_ShapedWeights scale_weights(weights.TrtDType());
@ -1938,9 +1930,6 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
// num_groups will be 1.
const int num_groups = (group == 0) ? tensor_dim.d[0] : group;
if (params->converter->precision_mode() == TrtPrecisionMode::FP16) {
weights_rsck = ConvertFP32ToFP16(params->weight_store, weights_rsck);
}
// For conv, TF weights are RSCK, and TRT expects KCRS.
// For backprop, TF weights are RSKC, and TRT expects CKRS.
// Therefore, this reorder will work for both cases.
@ -3039,9 +3028,6 @@ Status ConvertBiasAdd(OpConverterParams* params) {
}
TRT_ShapedWeights weights = inputs.at(1).weights();
if (params->converter->precision_mode() == TrtPrecisionMode::FP16) {
weights = ConvertFP32ToFP16(params->weight_store, weights);
}
nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL;
if (weights.shape_.d[0] == 1) {
mode = nvinfer1::ScaleMode::kUNIFORM;
@ -4238,6 +4224,95 @@ Status ConvertTopK(OpConverterParams* params) {
return Status::OK();
}
Status ConvertDepthSpaceShuffle(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
TF_RETURN_IF_ERROR(AllowDataTypes(
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
TFAttrs attrs(node_def);
const int block_size = attrs.get<int64>("block_size");
if (block_size < 2) {
return errors::InvalidArgument("Block size must be 2 or greater, at ",
node_def.name());
}
const string data_format = attrs.get<string>("data_format");
if (data_format != "NCHW" && data_format != "NHWC") {
return errors::Unimplemented("Data format ", data_format,
" is not supported, at ", node_def.name());
}
nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
if (dims.nbDims != 3) {
return errors::InvalidArgument("The input to ", node_def.op(),
" must be rank 4, at ", node_def.name());
}
const int num_channels = data_format == "NCHW" ? dims.d[0] : dims.d[2];
const int h = data_format == "NCHW" ? dims.d[1] : dims.d[0];
const int w = data_format == "NCHW" ? dims.d[2] : dims.d[1];
// Get shuffle parameters.
nvinfer1::Dims first_shuffle_shape;
nvinfer1::Permutation transpose_perm;
nvinfer1::Dims second_shuffle_shape;
if (node_def.op() == "DepthToSpace") {
if (num_channels % (block_size * block_size) != 0) {
return errors::InvalidArgument(
"Number of channels must be divisible by block_size*block_size, at ",
node_def.name());
}
// First Reshape [C, H, W] - > [r, r, C/(r*r), H, W]
first_shuffle_shape = {
/*nbDims=*/5,
/*d=*/{block_size, block_size, num_channels / (block_size * block_size),
h, w}};
// Transpose [r, r, C/(r*r), H, W] -> [C/(r*r), H, r, W, r]
transpose_perm = {2, 3, 0, 4, 1};
// Second Reshape [C/(r*r), H, r, W, r] -> [C/(r*r), H * r, W * r]
second_shuffle_shape =
nvinfer1::DimsCHW(num_channels / (block_size * block_size),
h * block_size, w * block_size);
} else if (node_def.op() == "SpaceToDepth") {
if (h % block_size != 0 || w % block_size != 0) {
return errors::InvalidArgument(
"Width and height must be divisible by block_size, at ",
node_def.name());
}
// First Reshape [C, H, W] -> [C, H/r, r, W/r, r]
first_shuffle_shape = {/*nbDims=*/5,
/*d=*/{num_channels, h / block_size, block_size,
w / block_size, block_size}};
// Transpose [C, H/r, r, W/r, r] -> [r, r, C, H/r, W/r]
transpose_perm = {2, 4, 0, 1, 3};
// Second Reshape [r, r, C, H/r, W/r] -> [C*r*r, H/r, W/r]
second_shuffle_shape = nvinfer1::DimsCHW(
num_channels * block_size * block_size, h / block_size, w / block_size);
}
if (params->validation_only) return Status::OK();
nvinfer1::IShuffleLayer* first_shuffle =
params->converter->network()->addShuffle(*inputs.at(0).tensor());
TFTRT_RETURN_ERROR_IF_NULLPTR(first_shuffle, node_def.name());
if (data_format == "NHWC") {
first_shuffle->setFirstTranspose({2, 0, 1});
}
first_shuffle->setReshapeDimensions(first_shuffle_shape);
first_shuffle->setSecondTranspose(transpose_perm);
nvinfer1::IShuffleLayer* second_shuffle =
params->converter->network()->addShuffle(*first_shuffle->getOutput(0));
TFTRT_RETURN_ERROR_IF_NULLPTR(second_shuffle, node_def.name());
second_shuffle->setReshapeDimensions(second_shuffle_shape);
if (data_format == "NHWC") {
second_shuffle->setSecondTranspose({1, 2, 0});
}
params->converter->MarkQuantizationRangesAsInferrable(
inputs.at(0).tensor(), first_shuffle->getOutput(0));
params->converter->MarkQuantizationRangesAsInferrable(
first_shuffle->getOutput(0), second_shuffle->getOutput(0));
params->outputs->push_back(TRT_TensorOrWeights(second_shuffle->getOutput(0)));
return Status::OK();
}
#if IS_TRT_VERSION_GE(5, 1, 0, 0)
Status ConvertCombinedNMS(OpConverterParams* params) {
TF_RETURN_IF_ERROR(
@ -4417,6 +4492,7 @@ static void RegisterValidatableOpConverters(
(*registration)["Const"] = ConvertConst;
(*registration)["Conv2D"] = ConvertConv2D;
(*registration)["Conv2DBackpropInput"] = ConvertConv2DBackpropInput;
(*registration)["DepthToSpace"] = ConvertDepthSpaceShuffle;
(*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
(*registration)["ExpandDims"] = ConvertExpandDims;
(*registration)["GatherV2"] = ConvertGather;
@ -4431,6 +4507,7 @@ static void RegisterValidatableOpConverters(
(*registration)["Slice"] = ConvertSlice;
(*registration)["Snapshot"] = ConvertIdentity; // Snapshot should be removed
(*registration)["Softmax"] = ConvertSoftmax;
(*registration)["SpaceToDepth"] = ConvertDepthSpaceShuffle;
(*registration)["Split"] = ConvertSplit;
(*registration)["Square"] = ConvertSquare;
(*registration)["Squeeze"] = ConvertSqueeze;

View File

@ -212,6 +212,19 @@ std::vector<CType> InitTestVector(int size, CType start_value = CType(0)) {
return res;
}
template <typename InCType, typename OutCType>
struct StaticCaster {
OutCType operator()(InCType in) const { return static_cast<OutCType>(in); }
};
template <typename InCType, typename OutCType>
std::vector<OutCType> CastTestVector(const std::vector<InCType>& vals) {
std::vector<OutCType> res(vals.size());
std::transform(vals.begin(), vals.end(), res.begin(),
StaticCaster<InCType, OutCType>());
return res;
}
// Fake ITensor implementation for testing purposes.
class FakeITensor : public nvinfer1::ITensor {
public:
@ -721,19 +734,25 @@ TEST_F(ConverterTest, TransposeTensor) {
ExpectTrtDimsEqualsArray({5, 2, 3}, output_tensor->getDimensions());
}
void TestPrepareTensorForShape_Tensor(
const std::vector<int>& tensor_dims, const std::vector<int>& reshape_dims,
const std::vector<int>& expected_tensor_dims, Converter* converter,
void TestPrepareTensorForShape(
const std::vector<int>& input_dims, const std::vector<int>& reshape_dims,
const std::vector<int>& expected_tensor_dims, bool input_is_tensor,
Converter* converter, TrtWeightStore* weight_store,
error::Code expected_code = error::OK,
const char* expected_error_msg_substr = nullptr) {
nvinfer1::ITensor* input_tensor = converter->network()->addInput(
"", nvinfer1::DataType::kFLOAT, GetTestDims(tensor_dims));
TRT_TensorOrWeights input;
if (input_is_tensor) {
input = TRT_TensorOrWeights(converter->network()->addInput(
"", nvinfer1::DataType::kFLOAT, GetTestDims(input_dims)));
} else {
input = TRT_TensorOrWeights(weight_store->GetTempWeights(
nvinfer1::DataType::kFLOAT, GetTestDims(input_dims)));
}
nvinfer1::ITensor* output_tensor = nullptr;
for (bool validation_only : {false, true}) {
const Status status = converter->PrepareTensorForShape(
TRT_TensorOrWeights(input_tensor), GetTestDims(reshape_dims),
validation_only, &output_tensor);
input, GetTestDims(reshape_dims), validation_only, &output_tensor);
if (expected_code == error::OK) {
TF_EXPECT_OK(status);
if (validation_only) {
@ -748,49 +767,45 @@ void TestPrepareTensorForShape_Tensor(
}
}
TEST_F(ConverterTest, PrepareTensorForShape_Tensor) {
TEST_F(ConverterTest, PrepareTensorForShape) {
for (bool input_is_tensor : {true, false}) {
// Shape size doesn't match.
Reset();
TestPrepareTensorForShape_Tensor({2, 3, 5}, {2, 3, 6}, {}, converter_.get(),
error::INVALID_ARGUMENT,
"Incompatible shapes");
TestPrepareTensorForShape({2, 3, 5}, {2, 3, 6}, {}, input_is_tensor,
converter_.get(), weight_store_,
error::INVALID_ARGUMENT, "Incompatible shapes");
// Regular shape.
Reset();
TestPrepareTensorForShape({2, 3, 5}, {10, 3}, {10, 3}, input_is_tensor,
converter_.get(), weight_store_);
// Reshape to zero rank.
Reset();
TestPrepareTensorForShape({1, 1}, {}, {}, input_is_tensor, converter_.get(),
weight_store_);
}
// Tensor input with zero rank.
Reset();
TestPrepareTensorForShape({}, {1, 1}, {1, 1}, /*input_is_tensor=*/true,
converter_.get(), weight_store_);
// TODO(aaroey): we should check the case where uninferred dimensions are
// not an exact divisor of input dim ensions, e.g. for dims {-1, 7}.
// Infer shape, ok.
// Infer tensor shape, ok.
Reset();
TestPrepareTensorForShape_Tensor({2, 3, 5}, {-1, 2}, {15, 2},
converter_.get());
TestPrepareTensorForShape({2, 3, 5}, {-1, 2}, {15, 2},
/*input_is_tensor=*/true, converter_.get(),
weight_store_);
// Regular shape.
// Infer weight shape, should fail.
Reset();
TestPrepareTensorForShape_Tensor({2, 3, 5}, {10, 3}, {10, 3},
converter_.get());
// Input with zero rank.
Reset();
TestPrepareTensorForShape_Tensor({}, {1, 1}, {1, 1}, converter_.get());
// Reshape to zero rank.
Reset();
TestPrepareTensorForShape_Tensor({1, 1}, {}, {}, converter_.get());
}
TEST_F(ConverterTest, PrepareTensorForShape_Weights) {
TRT_ShapedWeights weights = weight_store_->GetTempWeights(
nvinfer1::DataType::kFLOAT, GetTestDims({2, 3, 5}));
nvinfer1::ITensor* output_tensor = nullptr;
for (bool validation_only : {false, true}) {
TF_EXPECT_OK(converter_->PrepareTensorForShape(
TRT_TensorOrWeights(weights), GetTestDims({10, 3}), validation_only,
&output_tensor));
if (validation_only) {
EXPECT_EQ(nullptr, output_tensor);
} else {
ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions());
}
}
TestPrepareTensorForShape({2, 3, 5}, {-1, 2}, {15, 2},
/*input_is_tensor=*/false, converter_.get(),
weight_store_, error::INVALID_ARGUMENT,
"Shape is not fully defined");
}
TEST_F(ConverterTest, MaybeUpdateBatchSize) {
@ -4910,6 +4925,279 @@ TEST_F(OpConverterTest, ConvertArgMinMax) {
// TestConvertArgMinMax<ops::ArgMax, DT_INT32>(this);
}
// Get the NodeDef for DepthToSpace or SpaceToSpace.
template <typename OpType>
NodeDef GetDepthSpaceShuffleNodeDef(DataType dtype, int block_size,
string data_format) {
Scope s = Scope::NewRootScope();
auto input = ops::Placeholder(s.WithOpName("input"), dtype);
auto attrs = OpType::DataFormat(data_format);
auto shuffle = OpType(s.WithOpName("my_shuffle"), input, block_size, attrs);
return shuffle.operation.node()->def();
}
template <typename CType>
struct DepthSpaceShuffleTestParams {
std::vector<int> input_dims;
std::vector<CType> input_value;
int block_size;
string data_format;
std::vector<int> expected_output_dims;
std::vector<CType> expected_output;
};
template <typename OpType, DataType dtype, typename CType>
void TestConvertDepthSpaceShuffle(
OpConverterTest* test,
const std::vector<DepthSpaceShuffleTestParams<CType>>& params) {
for (int i = 0; i < params.size(); ++i) {
test->Reset();
NodeDef node_def = GetDepthSpaceShuffleNodeDef<OpType>(
dtype, params[i].block_size, params[i].data_format);
test->AddTestTensor("input", params[i].input_dims, 1,
TfDataTypeToTrt(dtype));
test->RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_shuffle", &output));
EXPECT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray(params[i].expected_output_dims,
output.tensor()->getDimensions());
DataVec input_data{{"input", test::AsTensor<CType>(params[i].input_value)}};
DataVec output_data{{"my_shuffle", ConstructTensor<CType>(
params[i].expected_output.size())}};
test->BuildAndRun(
input_data, &output_data,
dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAreArray(params[i].expected_output));
}
}
template <DataType dtype>
void TestConvertDepthToSpace(OpConverterTest* test) {
typedef typename EnumToDataType<dtype>::Type CType;
const std::vector<CType> common_input = InitTestVector<CType>(16);
std::vector<DepthSpaceShuffleTestParams<CType>> params = {
{
/*input_shape=*/{4, 2, 2},
/*input_value=*/common_input,
/*block_size=*/2,
/*data_format=*/"NCHW",
/*expected_output_dims=*/{1, 4, 4},
/*expected_output=*/
CastTestVector<int, CType>(
{0, 4, 1, 5, 8, 12, 9, 13, 2, 6, 3, 7, 10, 14, 11, 15}),
},
{
/*input_shape=*/{2, 2, 4},
/*input_value=*/common_input,
/*block_size=*/2,
/*data_format=*/"NHWC",
/*expected_output_dims=*/{4, 4, 1},
/*expected_output=*/
CastTestVector<int, CType>(
{0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15}),
},
{
/*input_shape=*/{16, 1, 1},
/*input_value=*/common_input,
/*block_size=*/4,
/*data_format=*/"NCHW",
/*expected_output_dims=*/{1, 4, 4},
/*expected_output=*/InitTestVector<CType>(16),
},
{
/*input_shape=*/{2, 2, 8},
/*input_value=*/InitTestVector<CType>(32),
/*block_size=*/2,
/*data_format=*/"NHWC",
/*expected_output_dims=*/{4, 4, 2},
/*expected_output=*/CastTestVector<int, CType>({0, 1, 2, 3, 8,
9, 10, 11, 4, 5,
6, 7, 12, 13, 14,
15, 16, 17, 18, 19,
24, 25, 26, 27, 20,
21, 22, 23, 28, 29,
30, 31}),
},
};
TestConvertDepthSpaceShuffle<ops::DepthToSpace, dtype, CType>(test, params);
}
TEST_F(OpConverterTest, ConvertDepthToSpace) {
{
// Input list is empty, should fail.
NodeDef node_def = MakeNodeDef("my_shuffle", "DepthToSpace", {});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"DepthToSpace got 0 inputs but expected 1, at my_shuffle");
}
{
// Input is a weight, should fail.
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(DT_FLOAT, 2, "NCHW");
AddTestWeights<float>("input", {4, 1, 1}, {1, 2, 3, 4});
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
"The input \"input\" for DepthToSpace must be a "
"tensor, at my_shuffle");
}
{
// Input rank != 4
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(DT_FLOAT, 2, "NCHW");
AddTestTensor("input", {16, 32});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"The input to DepthToSpace must be rank 4, at my_shuffle");
}
{
// Channels not divisible by block_size, should fail.
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(DT_FLOAT, 3, "NCHW");
AddTestTensor("input", {16, 32, 32});
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
"Number of channels must be divisible by "
"block_size*block_size, at my_shuffle");
}
{
// Unsupported format, should fail.
Reset();
NodeDef node_def = GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(
DT_FLOAT, 2, "NCHW_VECT_C");
AddTestTensor("input", {16, 32, 32});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"Data format NCHW_VECT_C is not supported, at my_shuffle");
}
TestConvertDepthToSpace<DT_FLOAT>(this);
TestConvertDepthToSpace<DT_HALF>(this);
TestConvertDepthToSpace<DT_INT32>(this);
}
template <DataType dtype>
void TestConvertSpaceToDepth(OpConverterTest* test) {
typedef typename EnumToDataType<dtype>::Type CType;
const std::vector<CType> common_input = InitTestVector<CType>(16);
std::vector<DepthSpaceShuffleTestParams<CType>> params = {
{
/*input_shape=*/{1, 4, 4},
/*input_value=*/common_input,
/*block_size=*/2,
/*data_format=*/"NCHW",
/*expected_output_dims=*/{4, 2, 2},
/*expected_output=*/
CastTestVector<int, CType>(
{0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15}),
},
{
/*input_shape=*/{4, 4, 1},
/*input_value=*/common_input,
/*block_size=*/2,
/*data_format=*/"NHWC",
/*expected_output_dims=*/{2, 2, 4},
/*expected_output=*/
CastTestVector<int, CType>(
{0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15}),
},
{
/*input_shape=*/{1, 4, 4},
/*input_value=*/common_input,
/*block_size=*/4,
/*data_format=*/"NCHW",
/*expected_output_dims=*/{16, 1, 1},
/*expected_output=*/InitTestVector<CType>(16),
},
{
/*input_shape=*/{4, 4, 2},
/*input_value=*/InitTestVector<CType>(32),
/*block_size=*/2,
/*data_format=*/"NHWC",
/*expected_output_dims=*/{2, 2, 8},
/*expected_output=*/CastTestVector<int, CType>({0, 1, 2, 3, 8,
9, 10, 11, 4, 5,
6, 7, 12, 13, 14,
15, 16, 17, 18, 19,
24, 25, 26, 27, 20,
21, 22, 23, 28, 29,
30, 31}),
},
};
TestConvertDepthSpaceShuffle<ops::SpaceToDepth, dtype, CType>(test, params);
}
TEST_F(OpConverterTest, ConvertSpaceToDepth) {
{
// Input list is empty, should fail.
NodeDef node_def = MakeNodeDef("my_shuffle", "SpaceToDepth", {});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"SpaceToDepth got 0 inputs but expected 1, at my_shuffle");
}
{
// Input is a weight, should fail.
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 2, "NCHW");
AddTestWeights<float>("input", {4, 1, 1}, {1, 2, 3, 4});
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
"The input \"input\" for SpaceToDepth must be a "
"tensor, at my_shuffle");
}
{
// Input rank != 4
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 2, "NCHW");
AddTestTensor("input", {16, 32});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"The input to SpaceToDepth must be rank 4, at my_shuffle");
}
{
// Width not divisble by block_size, should fail.
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 3, "NCHW");
AddTestTensor("input", {16, 9, 32});
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
"Width and height must be divisible by "
"block_size, at my_shuffle");
}
{
// Height not divisble by block_size, should fail.
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 3, "NCHW");
AddTestTensor("input", {16, 32, 9});
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
"Width and height must be divisible by "
"block_size, at my_shuffle");
}
{
// Unsupported format, should fail.
Reset();
NodeDef node_def = GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(
DT_FLOAT, 2, "NCHW_VECT_C");
AddTestTensor("input", {16, 32, 32});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"Data format NCHW_VECT_C is not supported, at my_shuffle");
}
TestConvertSpaceToDepth<DT_FLOAT>(this);
TestConvertSpaceToDepth<DT_HALF>(this);
TestConvertSpaceToDepth<DT_INT32>(this);
}
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
@ -290,17 +291,17 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
VLOG(1) << "Executing TRT calibration: " << name();
helper->Ref();
core::ScopedUnref sc(helper);
auto res_mgr = ctx->resource_manager();
TRTCalibrationResource* calib_res = nullptr;
OP_REQUIRES_OK(ctx,
res_mgr->LookupOrCreate(
"TF_TRT_Calibration", name(),
ctx->resource_manager()->LookupOrCreate(
"TF-TRT-Calibration", name(),
reinterpret_cast<SerializableResourceBase**>(&calib_res),
{[ctx, this](SerializableResourceBase** cr) -> Status {
return this->AllocateCalibrationResources(ctx, cr);
}}));
core::ScopedUnref calib_sc(calib_res);
int num_inputs = ctx->num_inputs();
// TODO(laigd): need to check that input shape matches.
// Pass input data to calibrator
std::unordered_map<string, void*> input_data;
for (int i = 0; i < num_inputs; i++) {
@ -425,8 +426,9 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
const_cast<float*>(input_tensor.flat<float>().data());
break;
case nvinfer1::DataType::kHALF:
LOG(ERROR) << "FP16 inputs are not supported yet!";
return kRetry;
buffers[binding_index] =
const_cast<Eigen::half*>(input_tensor.flat<Eigen::half>().data());
break;
case nvinfer1::DataType::kINT8:
LOG(ERROR) << "INT8 inputs are not supported yet!";
return kRetry;
@ -480,8 +482,9 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
const_cast<float*>(output_tensor->flat<float>().data());
break;
case nvinfer1::DataType::kHALF:
LOG(WARNING) << "half size is not supported yet!";
return kRetry;
buffers[binding_index] =
const_cast<Eigen::half*>(output_tensor->flat<Eigen::half>().data());
break;
case nvinfer1::DataType::kINT8:
LOG(WARNING) << "int8 is not supported yet!";
return kRetry;
@ -522,10 +525,22 @@ EngineContext* TRTEngineOp::GetEngine(
// TODO(tmorris): using first input to get batch size - is this reliable?
const int batch_size = input_shapes[0].dim_size(0);
// Get engine cache
// Canonicalize the op name by removing the scopes if any. This is mainly
// because in TFv2, the function graph can be instantiated in various ways and
// it'll insert scope names to the name of the TRTEngineOps, which will result
// in many different engine caches if we use the instantiated op name
// directly, but we still want all of them share the same cache (if they were
// representing the same subgraph).
absl::string_view resource_name = name();
size_t last_slash = resource_name.find_last_of('/');
if (last_slash != absl::string_view::npos) {
resource_name.remove_prefix(last_slash + 1);
}
// Get engine cache.
TRTEngineCacheResource* cache_res = nullptr;
auto status = ctx->resource_manager()->LookupOrCreate(
"TRTEngineCache", name(), &cache_res,
"TF-TRT-Engine-Cache", string(resource_name), &cache_res,
{[this, ctx](TRTEngineCacheResource** cr) -> Status {
*cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_);
return Status::OK();
@ -632,12 +647,13 @@ EngineContext* TRTEngineOp::GetEngine(
cache.emplace(engine_input_shapes, absl::make_unique<EngineContext>());
return &empty_context;
}
VLOG(1) << "Conversion is done";
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
engine->createExecutionContext());
cache.emplace(engine_input_shapes,
absl::make_unique<EngineContext>(std::move(engine),
std::move(exec_context)));
VLOG(1) << "Added new engine to cache of " << name()
<< ". Cache size: " << cache.size();
}
return cache.at(engine_input_shapes).get();
}

View File

@ -0,0 +1,106 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <dirent.h>
#include <string.h>
#include <fstream>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/platform/test.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda_runtime_api.h"
namespace tensorflow {
namespace tensorrt {
using ::testing::ElementsAre;
template <typename T>
class TRTEngineOpTest : public OpsTestBase {};
using TypeList = ::testing::Types<float, Eigen::half>;
TYPED_TEST_SUITE(TRTEngineOpTest, TypeList);
TYPED_TEST(TRTEngineOpTest, Basic) {
DataType dtype = DataTypeToEnum<TypeParam>::v();
// Create the GPU device.
std::unique_ptr<Device> device(
DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0"));
// Create simple TF graph.
Scope s = Scope::NewRootScope();
auto feed = ops::Placeholder(s.WithOpName("TensorRTInputPH_0"), dtype,
ops::Placeholder::Shape({1, 2}));
auto add = ops::Add(s.WithOpName("add"), feed, feed);
ops::Identity(s.WithOpName("TensorRTOutputPH_0"), add);
// Serialize the graph. TRTEngineOp will convert it using dynamic mode.
GraphDef graph_def;
TF_ASSERT_OK(s.ToGraphDef(&graph_def));
TensorShapeProto shape;
TensorShape({1, 2}).AsProto(&shape);
// Create the op.
OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
TF_ASSERT_OK(NodeDefBuilder("op", "TRTEngineOp")
.Input(FakeInput(1, dtype))
.Attr("input_shapes", {shape})
.Attr("output_shapes", {shape})
.Attr("static_engine", false)
.Attr("segment_funcdef_name", "") // no native fallback
.Attr("serialized_segment", graph_def.SerializeAsString())
.Attr("calibration_data", "")
.Attr("max_cached_engines_count", 1)
.Attr("workspace_size_bytes", 1 << 20)
.Attr("precision_mode", "FP32")
.Attr("use_calibration", false)
.Attr("OutT", {dtype})
.Finalize(OpsTestBase::node_def()));
TF_ASSERT_OK(OpsTestBase::InitOp());
// Execute the op.
OpsTestBase::AddInputFromArray<TypeParam>(TensorShape({1, 2}),
{TypeParam(0.0f), TypeParam(1.0f)});
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
// Verify the result.
// TODO(laigd): OpsTestBase::GetOutput() doesn't work.
Tensor* output = OpsTestBase::context_->mutable_output(0);
const auto& tensor_map = output->flat<TypeParam>();
std::vector<TypeParam> output_data(tensor_map.size());
ASSERT_EQ(0, cudaDeviceSynchronize());
ASSERT_EQ(0, cudaMemcpy(output_data.data(), tensor_map.data(),
sizeof(TypeParam) * tensor_map.size(),
cudaMemcpyDeviceToHost));
EXPECT_THAT(absl::Span<const TypeParam>(output_data),
ElementsAre(TypeParam(0.0f), TypeParam(2.0f)));
}
} // namespace tensorrt
} // namespace tensorflow
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA

View File

@ -0,0 +1,223 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <memory>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.pb.h" // NOLINT
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/platform/logging.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "tensorrt/include/NvInfer.h"
namespace tensorflow {
namespace tensorrt {
using ::nvinfer1::IRuntime;
class CreateTRTEngineCache : public OpKernel {
public:
explicit CreateTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("resource_name", &resource_name_));
OP_REQUIRES_OK(
ctx, ctx->GetAttr("max_cached_engines_count", &max_cached_engines_));
}
void Compute(OpKernelContext* ctx) override {
VLOG(1) << "Creating TRT engine cache resource in container " << container_
<< " for op " << resource_name_ << " on device "
<< ctx->device()->name();
OP_REQUIRES_OK(ctx,
ctx->resource_manager()->Create(
container_, resource_name_,
new TRTEngineCacheResource(ctx, max_cached_engines_)));
Tensor* handle;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
handle->scalar<ResourceHandle>()() =
MakeResourceHandle<TRTEngineCacheResource>(ctx, container_,
resource_name_);
}
private:
string container_;
string resource_name_;
// Maximum number of cached engines
int max_cached_engines_;
TF_DISALLOW_COPY_AND_ASSIGN(CreateTRTEngineCache);
};
REGISTER_KERNEL_BUILDER(Name("CreateTRTEngineCache")
.Device(DEVICE_GPU)
.HostMemory("engine_cache_handle"),
CreateTRTEngineCache);
class PopulateTRTEngineCache : public OpKernel {
public:
explicit PopulateTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
ResourceHandle handle = HandleFromInput(ctx, 0);
TRTEngineCacheResource* resource = nullptr;
OP_REQUIRES_OK(ctx, LookupResource(ctx, handle, &resource));
core::ScopedUnref unref_me(resource);
auto allocator = resource->allocator_.get();
OP_REQUIRES(ctx, allocator != nullptr,
errors::Internal("Not able to initialize TRT engine cache when "
"GPU allocator is empty."));
OP_REQUIRES(ctx, resource->cache_.size() == 0,
errors::Internal("Expect engine cache to be empty, but got ",
resource->cache_.size(), " entries."));
// Get the file name.
const string& filename = ctx->input(1).scalar<string>()();
OP_REQUIRES(ctx, !filename.empty(),
errors::InvalidArgument("filename cannot be empty."));
// Parse the serialized engines and add them to the cache.
std::unique_ptr<RandomAccessFile> file;
OP_REQUIRES_OK(ctx, ctx->env()->NewRandomAccessFile(filename, &file));
auto reader = absl::make_unique<io::RecordReader>(file.get());
uint64 offset = 0;
int num_loaded_engine = 0;
do {
string record;
Status status = reader->ReadRecord(&offset, &record);
if (errors::IsOutOfRange(status)) break;
TRTEngineInstance engine_instance;
engine_instance.ParseFromString(record);
std::vector<TensorShape> engine_input_shapes;
for (const TensorShapeProto& shape : engine_instance.input_shapes()) {
engine_input_shapes.emplace_back(shape);
}
TrtUniquePtrType<IRuntime> infer(
nvinfer1::createInferRuntime(TRTEngineCacheResource::GetLogger()));
infer->setGpuAllocator(allocator);
TrtUniquePtrType<nvinfer1::ICudaEngine> engine(
infer->deserializeCudaEngine(
engine_instance.serialized_engine().c_str(),
engine_instance.serialized_engine().size(),
PluginFactoryTensorRT::GetInstance()));
auto raw_engine = engine.get();
resource->cache_.emplace(
engine_input_shapes,
absl::make_unique<EngineContext>(
std::move(engine), TrtUniquePtrType<nvinfer1::IExecutionContext>(
raw_engine->createExecutionContext())));
++num_loaded_engine;
} while (1);
VLOG(1) << "Loaded " << num_loaded_engine << " TRT engines to container "
<< handle.container() << " for op " << handle.name()
<< " on device " << ctx->device()->name() << " from file "
<< filename;
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(PopulateTRTEngineCache);
};
REGISTER_KERNEL_BUILDER(Name("PopulateTRTEngineCache")
.Device(DEVICE_GPU)
.HostMemory("engine_cache_handle"),
PopulateTRTEngineCache);
class DumpTRTEngineCache : public OpKernel {
public:
explicit DumpTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("delete_cache_after_dump",
&delete_cache_after_dump_));
}
void Compute(OpKernelContext* ctx) override {
const string& container = ctx->input(0).scalar<string>()();
const string& resource_name = ctx->input(1).scalar<string>()();
const string& filename = ctx->input(2).scalar<string>()();
OP_REQUIRES(ctx, !filename.empty(),
errors::InvalidArgument("filename cannot be empty."));
TRTEngineCacheResource* resource = nullptr;
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(
container, resource_name, &resource));
core::ScopedUnref unref_me(resource);
// Serialize the engines and write them to file.
std::unique_ptr<WritableFile> file;
OP_REQUIRES_OK(ctx, ctx->env()->NewWritableFile(filename, &file));
auto writer = absl::make_unique<io::RecordWriter>(file.get());
for (const auto& pair : resource->cache_) {
TRTEngineInstance engine_instance;
// Add input shapes.
const std::vector<TensorShape>& engine_input_shapes = pair.first;
for (const TensorShape& shape : engine_input_shapes) {
shape.AsProto(engine_instance.add_input_shapes());
}
// Add the serialized engine.
const std::unique_ptr<EngineContext>& engine = pair.second;
TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(
engine->cuda_engine->serialize());
engine_instance.set_serialized_engine(engine_data->data(),
engine_data->size());
OP_REQUIRES_OK(ctx,
writer->WriteRecord(engine_instance.SerializeAsString()));
}
VLOG(1) << "Serialized " << resource->cache_.size()
<< " TRT engines in container " << container << " for op "
<< resource_name << " on device " << ctx->device()->name()
<< " to file " << filename;
if (delete_cache_after_dump_) {
VLOG(1) << "Destroying TRT engine cache resource in container "
<< container << " for op " << resource_name << " on device "
<< ctx->device()->name();
OP_REQUIRES_OK(ctx,
ctx->resource_manager()->Delete<TRTEngineCacheResource>(
container, resource_name));
}
}
private:
bool delete_cache_after_dump_ = false;
TF_DISALLOW_COPY_AND_ASSIGN(DumpTRTEngineCache);
};
REGISTER_KERNEL_BUILDER(Name("DumpTRTEngineCache").Device(DEVICE_GPU),
DumpTRTEngineCache);
} // namespace tensorrt
} // namespace tensorflow
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA

View File

@ -0,0 +1,205 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <dirent.h>
#include <string.h>
#include <fstream>
#include <vector>
#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.pb.h" // NOLINT
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
namespace tensorflow {
namespace tensorrt {
class TRTEngineResourceOpsTest : public OpsTestBase {
protected:
void Reset() {
inputs_.clear();
gtl::STLDeleteElements(&tensors_);
gtl::STLDeleteElements(&managed_outputs_);
}
TrtUniquePtrType<nvinfer1::ICudaEngine> CreateTRTEngine() {
Logger logger;
TrtUniquePtrType<nvinfer1::IBuilder> builder(
nvinfer1::createInferBuilder(logger));
TrtUniquePtrType<nvinfer1::INetworkDefinition> network(
builder->createNetwork());
// Add the input.
nvinfer1::Dims dims;
dims.nbDims = 1;
dims.d[0] = 1;
nvinfer1::ITensor* input =
network->addInput("input", nvinfer1::DataType::kFLOAT, dims);
EXPECT_NE(nullptr, input);
// Add a unary layer.
nvinfer1::IUnaryLayer* layer =
network->addUnary(*input, nvinfer1::UnaryOperation::kEXP);
EXPECT_NE(nullptr, layer);
// Mark the output.
nvinfer1::ITensor* output = layer->getOutput(0);
output->setName("output");
network->markOutput(*output);
// Build the engine
builder->setMaxBatchSize(1);
builder->setMaxWorkspaceSize(1 << 10);
TrtUniquePtrType<nvinfer1::ICudaEngine> engine(
builder->buildCudaEngine(*network));
EXPECT_NE(nullptr, engine);
return engine;
}
};
TEST_F(TRTEngineResourceOpsTest, Basic) {
// Create the GPU device.
std::unique_ptr<Device> device(
DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0"));
ResourceMgr* rm = device->resource_manager();
SetDevice(DEVICE_GPU, std::move(device));
// Create the resource.
const string container = "mycontainer";
const string resource_name = "myresource";
Reset();
TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTEngineCache")
.Attr("container", container)
.Attr("resource_name", resource_name)
.Attr("max_cached_engines_count", 1)
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
TF_ASSERT_OK(RunOpKernel());
ResourceHandle handle =
context_->mutable_output(0)->scalar<ResourceHandle>()();
TRTEngineCacheResource* resource = nullptr;
EXPECT_TRUE(rm->Lookup(container, resource_name, &resource).ok());
// Create a serialized TRT engine file.
TrtUniquePtrType<nvinfer1::ICudaEngine> engine = CreateTRTEngine();
TrtUniquePtrType<nvinfer1::IExecutionContext> context(
engine->createExecutionContext());
resource->cache_.emplace(
std::vector<TensorShape>{TensorShape({1, 1})},
absl::make_unique<EngineContext>(std::move(engine), std::move(context)));
resource->Unref();
// Serialize the engine using DumpTRTEngineCache op.
Reset();
TF_ASSERT_OK(NodeDefBuilder("op", "DumpTRTEngineCache")
.Attr("delete_cache_after_dump", true)
.Input(FakeInput(DT_STRING))
.Input(FakeInput(DT_STRING))
.Input(FakeInput(DT_STRING))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<string>(TensorShape({}), {container});
AddInputFromArray<string>(TensorShape({}), {resource_name});
const string filename = io::JoinPath(testing::TmpDir(), "trt_engine_file");
AddInputFromArray<string>(TensorShape({}), {filename});
TF_ASSERT_OK(RunOpKernel());
// Make sure the cache is deleted.
Reset();
TF_ASSERT_OK(NodeDefBuilder("op", "DestroyResourceOp")
.Attr("ignore_lookup_error", false)
.Input(FakeInput(DT_RESOURCE))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<ResourceHandle>(TensorShape({}), {handle});
EXPECT_TRUE(errors::IsNotFound(RunOpKernel()));
// Verify the serialized engine file.
Env* env = Env::Default();
std::unique_ptr<RandomAccessFile> file;
TF_ASSERT_OK(env->NewRandomAccessFile(filename, &file));
auto reader = absl::make_unique<io::RecordReader>(file.get());
uint64 offset = 0;
string record;
TF_ASSERT_OK(reader->ReadRecord(&offset, &record));
TRTEngineInstance engine_instance;
engine_instance.ParseFromString(record);
EXPECT_EQ(1, engine_instance.input_shapes_size());
EXPECT_EQ(2, engine_instance.input_shapes(0).dim_size());
EXPECT_EQ(1, engine_instance.input_shapes(0).dim(0).size());
EXPECT_EQ(1, engine_instance.input_shapes(0).dim(1).size());
EXPECT_TRUE(errors::IsOutOfRange(reader->ReadRecord(&offset, &record)));
// Recreate the cache resource.
Reset();
TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTEngineCache")
.Attr("container", container)
.Attr("resource_name", resource_name)
.Attr("max_cached_engines_count", 1)
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
TF_ASSERT_OK(RunOpKernel());
handle = context_->mutable_output(0)->scalar<ResourceHandle>()();
EXPECT_TRUE(rm->Lookup(container, resource_name, &resource).ok());
EXPECT_EQ(0, resource->cache_.size());
resource->Unref();
// Deserialize the engine using PopulateTRTEngineCache op.
Reset();
TF_ASSERT_OK(NodeDefBuilder("op", "PopulateTRTEngineCache")
.Input(FakeInput(DT_RESOURCE))
.Input(FakeInput(DT_STRING))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<ResourceHandle>(TensorShape({}), {handle});
AddInputFromArray<string>(TensorShape({}), {filename});
TF_ASSERT_OK(RunOpKernel());
EXPECT_TRUE(rm->Lookup(container, resource_name, &resource).ok());
EXPECT_EQ(1, resource->cache_.size());
resource->Unref();
// Destroy the engine cache again.
Reset();
TF_ASSERT_OK(NodeDefBuilder("op", "DestroyResourceOp")
.Attr("ignore_lookup_error", false)
.Input(FakeInput(DT_RESOURCE))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<ResourceHandle>(TensorShape({}), {handle});
TF_ASSERT_OK(RunOpKernel());
EXPECT_TRUE(errors::IsNotFound(RunOpKernel()));
}
} // namespace tensorrt
} // namespace tensorflow
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA

View File

@ -0,0 +1,52 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
REGISTER_OP("CreateTRTEngineCache")
.Attr("container: string")
.Attr("resource_name: string")
.Attr("max_cached_engines_count: int = 1")
.Output("engine_cache_handle: resource")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("PopulateTRTEngineCache")
.Input("engine_cache_handle: resource")
.Input("filename: string")
.SetIsStateful()
.SetShapeFn(shape_inference::NoOutputs);
REGISTER_OP("DumpTRTEngineCache")
.Attr("delete_cache_after_dump: bool = false")
.Input("container: string")
.Input("resource_name: string")
.Input("filename: string")
.SetIsStateful()
.SetShapeFn(shape_inference::NoOutputs);
} // namespace tensorflow
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA

View File

@ -459,7 +459,7 @@ Status SegmentGraph(const Graph* tf_graph,
}
LOG(INFO) << msg << "(For more information see "
<< "https://docs.nvidia.com/deeplearning"
<< "/dgx/integrate-tf-trt/index.html#support-ops).";
<< "/dgx/tf-trt-user-guide/index.html#supported-ops).";
// The segmentation algorithm below visits nodes in reverse topological order
// and attempts to merge nodes along output edges. That means that subgraphs

View File

@ -0,0 +1,19 @@
syntax = "proto3";
package tensorflow.tensorrt;
import "tensorflow/core/framework/tensor_shape.proto";
// Containing information for a serialized TensorRT engine.
message TRTEngineInstance {
// The input shapes of the TRT engine.
repeated TensorShapeProto input_shapes = 1;
// The serialized TRT engine.
//
// TODO(laigd): consider using a more efficient in-memory representation
// instead of string which is the default here.
bytes serialized_engine = 2;
// TODO(laigd): consider adding calibration stats, precision_modes, etc.
}

View File

@ -0,0 +1,79 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include <sstream>
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/mutex.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "tensorrt/include/NvInfer.h"
namespace tensorflow {
namespace tensorrt {
Logger& TRTEngineCacheResource::GetLogger() {
static Logger* logger = new Logger();
return *logger;
}
TRTEngineCacheResource::TRTEngineCacheResource(OpKernelContext* ctx,
size_t capacity)
: cache_(capacity) {
auto device = ctx->device();
auto alloc = device->GetAllocator(AllocatorAttributes());
if (!alloc) {
LOG(ERROR) << "Can't find device allocator for gpu device "
<< device->name();
allocator_ = nullptr;
} else {
allocator_.reset(new TRTDeviceAllocator(alloc));
}
}
TRTEngineCacheResource::~TRTEngineCacheResource() {
VLOG(1) << "Destroying TRTEngineCacheResource...";
}
string TRTEngineCacheResource::DebugString() const {
std::stringstream oss;
using std::dec;
using std::endl;
using std::hex;
oss << "TRTEngineCacheResource: ";
oss << "TRTBaseAllocator = " << hex << allocator_.get() << dec << ", ";
oss << "LRUCache = " << hex << &cache_ << dec << endl;
oss << "Containing " << cache_.size() << " entries: " << endl;
for (const auto& item : cache_) {
mutex_lock lock(item.second->mu);
oss << TensorShapeUtils::ShapeListString(item.first) << ": " << hex
<< "ICudaEngine: " << item.second->cuda_engine.get() << ", "
<< "IExecutionContext: " << item.second->execution_context.get() << dec
<< endl;
}
return oss.str();
}
} // namespace tensorrt
} // namespace tensorflow
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
@ -100,17 +101,14 @@ class LRUCache {
}
// Creates n free positions in cache
Status DiscardOld(size_t n = 0) {
if (n > capacity_) {
return errors::Internal("Insufficient capacity in cache (capacity = ",
capacity_, ", requested ", n, ")");
}
void DiscardOld(size_t n = 0) {
DCHECK(capacity_ >= n) << "Insufficient capacity in cache (capacity = "
<< capacity_ << ", requested " << n << ")";
while (objects_.size() > (capacity_ - n)) {
key_type discard_key = keys_.back();
keys_.pop_back();
objects_.erase(discard_key);
}
return Status::OK();
}
};
@ -141,36 +139,18 @@ struct EngineContext {
class TRTEngineCacheResource : public ResourceBase {
public:
TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity)
: cache_(capacity) {
auto device = ctx->device();
auto alloc = device->GetAllocator(AllocatorAttributes());
if (!alloc) {
LOG(ERROR) << "Can't find device allocator for gpu device "
<< device->name();
allocator_ = nullptr;
} else {
allocator_.reset(new TRTDeviceAllocator(alloc));
}
}
// According to the TensorRT API, the logger is considered a singleton by the
// TensorRT library, and multiple instances of IRuntime and/or IBuilder must
// all use the same logger. So here we make it a singleton.
//
// TODO(laigd): use this logger in all places where conversion happens.
static Logger& GetLogger();
string DebugString() const override {
std::stringstream oss;
using std::dec;
using std::endl;
using std::hex;
oss << "TRTEngineCacheResource: ";
oss << "TRTBaseAllocator = " << hex << allocator_.get() << dec << ", ";
oss << "LRUCache = " << hex << &cache_ << dec << endl;
oss << "Containing " << cache_.size() << " entries: " << endl;
for (const auto& item : cache_) {
oss << TensorShapeUtils::ShapeListString(item.first) << ": " << hex
<< "ICudaEngine: " << item.second.get()->cuda_engine.get() << ", "
<< "IExecutionContext: " << item.second.get()->execution_context.get()
<< dec << endl;
}
return oss.str();
}
TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity);
~TRTEngineCacheResource() override;
string DebugString() const override;
// Keep device allocator for TRT.
std::unique_ptr<TRTBaseAllocator> allocator_;

View File

@ -210,6 +210,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/algorithm:container",
@ -376,6 +377,7 @@ tf_cc_test(
":xla_compiler",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:functional_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@ -390,6 +392,7 @@ tf_cc_test(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
@ -503,6 +506,39 @@ cc_library(
],
)
cc_library(
name = "rearrange_function_argument_pass",
srcs = [
"rearrange_function_argument_pass.cc",
],
hdrs = [
"rearrange_function_argument_pass.h",
],
deps = [
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "rearrange_function_argument_pass_registration",
srcs = [
"rearrange_function_argument_pass_registration.cc",
],
deps = [
":rearrange_function_argument_pass",
],
alwayslink = 1,
)
cc_library(
name = "functionalize_control_flow_pass_registration",
srcs = [

View File

@ -100,9 +100,12 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
arg.name = resource->name();
break;
}
case XlaExpression::Kind::kTensorList:
return errors::Unimplemented(
"TensorList as function argument is not yet implemented.");
case XlaExpression::Kind::kTensorList: {
arg.kind = XlaCompiler::Argument::kTensorList;
const xla::XlaOp& tensor_list = expressions[i]->handle();
arg.shape = tensor_list.builder()->GetShape(tensor_list).ValueOrDie();
break;
}
case XlaExpression::Kind::kInvalid:
return errors::InvalidArgument("Invalid function argument");
}
@ -301,9 +304,14 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
for (int64 i = 0; i < n->num_outputs(); ++i) {
if (result.outputs[i].is_constant) {
xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value);
} else {
if (result.outputs[i].is_tensor_list) {
xla_op_context.SetTensorListOutput(
i, xla::GetTupleElement(output_handle, computation_output));
} else {
xla_op_context.SetOutput(
i, xla::GetTupleElement(output_handle, computation_output));
}
++computation_output;
}
}

View File

@ -122,7 +122,7 @@ tf_kernel_library(
tags = ["optonly"],
deps = [
":case_op",
":conv_op_helpers",
":conv_op_attrs",
":if_op",
":tensor_list_utils",
":while_op",
@ -146,6 +146,7 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:conv_op_helpers",
"//tensorflow/compiler/xla/client/lib:loops",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:matrix",
@ -223,25 +224,6 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "conv_op_helpers",
srcs = ["conv_op_helpers.cc"],
hdrs = ["conv_op_helpers.h"],
deps = [
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/core:framework",
"//tensorflow/core:framework_bounds_check",
"//tensorflow/core/kernels:conv_ops",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "tensor_list_utils",
srcs = ["tensor_list_utils.cc"],
@ -384,3 +366,14 @@ cc_library(
],
alwayslink = 1,
)
cc_library(
name = "conv_op_attrs",
srcs = ["conv_op_attrs.cc"],
hdrs = ["conv_op_attrs.h"],
deps = [
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client/lib:conv_op_helpers",
"//tensorflow/core:framework",
],
)

View File

@ -58,6 +58,7 @@ class AddNOp : public XlaOpKernel {
xla::XlaOp push_index;
OP_REQUIRES_OK(ctx, GetTensorListPushIndex(ctx->Input(0), &push_index));
OP_REQUIRES_OK(ctx, BuildTensorList(sum, push_index, &sum));
ctx->SetTensorListOutput(0, sum);
break;
}
default:
@ -65,10 +66,9 @@ class AddNOp : public XlaOpKernel {
for (int i = 1; i < ctx->num_inputs(); ++i) {
sum = xla::Add(sum, ctx->Input(i));
}
}
ctx->SetOutput(0, sum);
}
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(AddNOp);

View File

@ -44,6 +44,7 @@ class BatchMatMulOp : public XlaOpKernel {
};
REGISTER_XLA_OP(Name("BatchMatMul"), BatchMatMulOp);
REGISTER_XLA_OP(Name("BatchMatMulV2"), BatchMatMulOp);
} // namespace
} // namespace tensorflow

View File

@ -36,7 +36,9 @@ namespace {
class CategoricalOp : public XlaOpKernel {
public:
explicit CategoricalOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
explicit CategoricalOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx),
is_gpu_(ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT) {}
void Compile(XlaOpKernelContext* ctx) override {
// Get the logits
@ -101,8 +103,9 @@ class CategoricalOp : public XlaOpKernel {
xla::PrimitiveType xla_output_type;
OP_REQUIRES_OK(ctx,
DataTypeToPrimitiveType(output_type(0), &xla_output_type));
xla::XlaOp argmax = xla::ArgMax(softmax_entries, xla_output_type,
xla::XlaOp argmax = xla::ArgMaxTwoPass(softmax_entries, xla_output_type,
/*axis=*/class_dimension);
if (num_samples == 1) {
argmax = xla::Reshape(argmax, {batch_size, 1});
}
@ -124,6 +127,7 @@ class CategoricalOp : public XlaOpKernel {
}
private:
bool is_gpu_;
TF_DISALLOW_COPY_AND_ASSIGN(CategoricalOp);
};
@ -134,7 +138,8 @@ REGISTER_XLA_OP(Name("Multinomial").CompileTimeConstantInput("num_samples"),
class StatelessCategoricalOp : public CategoricalOp {
public:
explicit StatelessCategoricalOp(OpKernelConstruction* ctx)
: CategoricalOp(ctx) {
: CategoricalOp(ctx),
device_type_string_(ctx->device_type().type_string()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
}
@ -150,7 +155,7 @@ class StatelessCategoricalOp : public CategoricalOp {
// * log(-log(0)) is ∞.
// * log(-log(1)) is -∞.
xla::XlaOp uniforms = StatelessRngUniform(
seed, uniform_shape,
device_type_string_, seed, uniform_shape,
xla::MinPositiveNormalValue(builder, uniform_shape.element_type()),
xla::One(builder, uniform_shape.element_type()));
return xla::ConvertElementType(xla::Log(-xla::Log(uniforms)), type);
@ -166,6 +171,7 @@ class StatelessCategoricalOp : public CategoricalOp {
private:
DataType dtype_;
string device_type_string_;
TF_DISALLOW_COPY_AND_ASSIGN(StatelessCategoricalOp);
};

View File

@ -0,0 +1,93 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/kernels/conv_op_attrs.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/common_shape_fns.h"
namespace tensorflow {
namespace {
// Converts the tensor data format to the one required by the XLA convolution
// library.
xla::ConvolutionDimensionNumbers MakeConvolutionDimensionNumbers(
TensorFormat data_format, int num_spatial_dims) {
int num_dims = num_spatial_dims + 2;
int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format);
int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format);
xla::ConvolutionDimensionNumbers conv_dim_numbers;
for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
conv_dim_numbers.add_input_spatial_dimensions(
GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim));
}
conv_dim_numbers.set_input_batch_dimension(batch_dimension);
conv_dim_numbers.set_input_feature_dimension(feature_dimension);
return conv_dim_numbers;
}
} // namespace
xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(
int num_spatial_dims, bool depthwise,
tensorflow::OpKernelConstruction* ctx) {
ConvOpAttrs attrs;
attrs.num_spatial_dims = num_spatial_dims;
attrs.depthwise = depthwise;
TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations));
TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides));
TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding));
if (attrs.padding == tensorflow::EXPLICIT) {
TF_RETURN_IF_ERROR(
ctx->GetAttr("explicit_paddings", &attrs.explicit_paddings));
}
string data_format;
TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format));
if (!FormatFromString(data_format, &attrs.data_format)) {
return errors::InvalidArgument("Invalid data format: ", data_format);
}
return attrs;
}
xla::StatusOr<xla::ConvOpAttrs> ConvOpAttrs::ToXla(
const TensorShape& input_shape, const TensorShape& filter_shape) const {
xla::ConvOpAttrs xla_attrs;
xla_attrs.depthwise = depthwise;
xla_attrs.num_spatial_dims = num_spatial_dims;
xla_attrs.dilations = dilations;
xla_attrs.strides = strides;
xla_attrs.data_format =
MakeConvolutionDimensionNumbers(data_format, num_spatial_dims);
if (padding == Padding::EXPLICIT) {
xla_attrs.explicit_paddings = explicit_paddings;
return xla_attrs;
}
int num_dims = num_spatial_dims + 2;
xla_attrs.explicit_paddings.resize(2 * num_dims);
for (int i = 0; i < num_spatial_dims; ++i) {
const int64 dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
int64 unused_output_size;
TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
input_shape.dim_size(dim), filter_shape.dim_size(i), dilations.at(dim),
strides.at(dim), padding, &unused_output_size,
&xla_attrs.explicit_paddings[dim * 2],
&xla_attrs.explicit_paddings[dim * 2 + 1]));
}
return xla_attrs;
}
} // namespace tensorflow

View File

@ -0,0 +1,51 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_ATTRS_H_
#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_ATTRS_H_
#include "tensorflow/compiler/xla/client/lib/conv_op_helpers.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
// ConvOpAttrs contains all of the metadata necessary to specify a TF or XLA
// convolution.
struct ConvOpAttrs {
// Constructs a ConvOpAttrs, reading most of the attributes from `ctx`.
static xla::StatusOr<ConvOpAttrs> Create(
int num_spatial_dims, bool depthwise,
tensorflow::OpKernelConstruction* ctx);
// Converts to the format required by the XLA convolution helpers.
xla::StatusOr<xla::ConvOpAttrs> ToXla(const TensorShape& input_shape,
const TensorShape& filter_shape) const;
bool depthwise;
int num_spatial_dims;
std::vector<int32> dilations;
std::vector<int32> strides;
tensorflow::Padding padding;
std::vector<int64> explicit_paddings;
tensorflow::TensorFormat data_format;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_ATTRS_H_

View File

@ -1,70 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_
#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_
#include <vector>
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
// This header exposes utilities for translating TensorFlow convolution ops into
// XLA ops.
//
// conv_ops.cc contains lowerings for many of these TF convolution ops (e.g.
// Conv2D, Conv3DBackpropFilterV2), but you might want to use the utilities in
// this header to implement a new and exciting convolution op, for example a
// fused TensorFlow op that contains a convolution and other things.
namespace tensorflow {
// ConvOpAttrs contains all of the metadata necessary to specify a TF or XLA
// convolution.
struct ConvOpAttrs {
// Constructs a ConvOpAttrs, reading most of the attributes from `ctx`.
static xla::StatusOr<ConvOpAttrs> Create(int num_spatial_dims, bool depthwise,
OpKernelConstruction* ctx);
bool depthwise;
int num_spatial_dims;
std::vector<int32> dilations;
std::vector<int32> strides;
Padding padding;
std::vector<int64> explicit_paddings;
TensorFormat data_format;
};
// Creates a new XLA forward or backward convolution with the given inputs and
// attributes.
xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece type_string,
xla::XlaOp conv_input,
xla::XlaOp filter,
const ConvOpAttrs& attrs);
xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
xla::XlaOp out_backprop, const ConvOpAttrs& attrs);
xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
StringPiece type_string, xla::XlaOp activations,
const xla::Shape& filter_shape, xla::XlaOp gradients,
const ConvOpAttrs& attrs);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_

View File

@ -15,13 +15,14 @@ limitations under the License.
// XLA-specific Ops for 2D convolution.
#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
#include "tensorflow/compiler/tf2xla/kernels/conv_op_attrs.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/conv_op_helpers.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
@ -51,8 +52,12 @@ class ConvOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
xla::StatusOr<xla::XlaOp> conv = MakeXlaForwardConvOp(
ctx->op_kernel().type_string(), ctx->Input(0), ctx->Input(1), attrs_);
xla::StatusOr<xla::ConvOpAttrs> attrs =
attrs_.ToXla(ctx->InputShape(0), ctx->InputShape(1));
OP_REQUIRES_OK(ctx, attrs.status());
xla::StatusOr<xla::XlaOp> conv =
xla::MakeXlaForwardConvOp(ctx->op_kernel().type_string(), ctx->Input(0),
ctx->Input(1), attrs.ValueOrDie());
OP_REQUIRES_OK(ctx, conv.status());
ctx->SetOutput(0, conv.ValueOrDie());
}
@ -102,10 +107,13 @@ class ConvBackpropInputOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape));
xla::Shape input_shape =
TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape);
xla::StatusOr<xla::ConvOpAttrs> attrs =
attrs_.ToXla(input_tensor_shape, ctx->InputShape(1));
OP_REQUIRES_OK(ctx, attrs.status());
xla::StatusOr<xla::XlaOp> in_backprop =
MakeXlaBackpropInputConvOp(ctx->op_kernel().type_string(), input_shape,
ctx->Input(1), ctx->Input(2), attrs_);
xla::StatusOr<xla::XlaOp> in_backprop = xla::MakeXlaBackpropInputConvOp(
ctx->op_kernel().type_string(), input_shape, ctx->Input(1),
ctx->Input(2), attrs.ValueOrDie());
OP_REQUIRES_OK(ctx, in_backprop.status());
ctx->SetOutput(0, in_backprop.ValueOrDie());
}
@ -160,10 +168,14 @@ class ConvBackpropFilterOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_tensor_shape));
xla::Shape filter_shape =
TensorShapeToXLAShape(ctx->input_xla_type(0), filter_tensor_shape);
xla::StatusOr<xla::ConvOpAttrs> attrs =
attrs_.ToXla(ctx->InputShape(0), filter_tensor_shape);
OP_REQUIRES_OK(ctx, attrs.status());
xla::StatusOr<xla::XlaOp> filter_backprop = MakeXlaBackpropFilterConvOp(
ctx->op_kernel().type_string(), ctx->Input(0), filter_shape,
ctx->Input(2), attrs_);
xla::StatusOr<xla::XlaOp> filter_backprop =
xla::MakeXlaBackpropFilterConvOp(ctx->op_kernel().type_string(),
ctx->Input(0), filter_shape,
ctx->Input(2), attrs.ValueOrDie());
OP_REQUIRES_OK(ctx, filter_backprop.status());
ctx->SetOutput(0, filter_backprop.ValueOrDie());
}

View File

@ -82,33 +82,71 @@ xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input,
return xla::Add(xla::Mul(rounded, input_scale), nudged_input_min);
}
// Builds a custom_call to a method named 'fake_quant_with_min_max_vars'.
// The method will be provided the input, the min/max range from the original
// TensorFlow op, and the num_bits and narrow_range attributes.
xla::StatusOr<xla::XlaOp> BuildFakeQuantCustomCall(
xla::XlaBuilder* b, xla::XlaOp input, xla::XlaOp input_min,
xla::XlaOp input_max, int num_bits, bool narrow_range) {
xla::XlaOp num_bits_arg =
XlaHelpers::IntegerLiteral(b, DataType::DT_INT32, num_bits);
xla::XlaOp narrow_range_arg = narrow_range
? XlaHelpers::One(b, DataType::DT_BOOL)
: XlaHelpers::Zero(b, DataType::DT_BOOL);
std::vector<xla::XlaOp> args = {input, input_min, input_max, num_bits_arg,
narrow_range_arg};
std::vector<xla::Shape> arg_shapes;
for (const xla::XlaOp& arg : args) {
TF_ASSIGN_OR_RETURN(xla::Shape arg_shape, b->GetShape(arg));
*arg_shape.mutable_layout() =
xla::LayoutUtil::MakeDescendingLayout(arg_shape.rank());
arg_shapes.push_back(std::move(arg_shape));
}
// Input and output shapes match exactly.
TF_ASSIGN_OR_RETURN(xla::Shape output_shape, b->GetShape(input));
return xla::CustomCallWithLayout(b, "fake_quant_with_min_max_vars", args,
output_shape, arg_shapes);
}
class FakeQuantWithMinMaxArgsOp : public XlaOpKernel {
public:
explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
int num_bits;
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
OP_REQUIRES(ctx, num_bits_ >= 2 && num_bits_ <= 16,
errors::InvalidArgument("num_bits is out of range, expected "
"between 2 and 16, was: ",
num_bits));
bool narrow_range;
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
quant_min_ = narrow_range ? 1 : 0;
quant_max_ = (1 << num_bits) - 1;
num_bits_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
quant_min_ = narrow_range_ ? 1 : 0;
quant_max_ = (1 << num_bits_) - 1;
float input_min, input_max;
OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min));
OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max));
CpuNudge(input_min, input_max, quant_min_, quant_max_, &nudged_input_min_,
OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max_));
CpuNudge(input_min_, input_max_, quant_min_, quant_max_, &nudged_input_min_,
&nudged_input_max_, &input_scale_);
}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder();
xla::XlaOp input = ctx->Input(0);
const DataType data_type = ctx->input_type(0);
xla::XlaBuilder* b = ctx->builder();
if (ctx->compiler()->options().allow_cpu_custom_calls &&
ctx->compiler()->options().custom_fake_quant_op_calls) {
xla::XlaOp custom_call_output =
b->ReportErrorOrReturn(BuildFakeQuantCustomCall(
b, input,
XlaHelpers::FloatLiteral(b, DataType::DT_FLOAT, input_min_),
XlaHelpers::FloatLiteral(b, DataType::DT_FLOAT, input_max_),
num_bits_, narrow_range_));
ctx->SetOutput(0, custom_call_output);
return;
}
xla::XlaOp nudged_input_min =
XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
xla::XlaOp nudged_input_max =
@ -121,6 +159,10 @@ class FakeQuantWithMinMaxArgsOp : public XlaOpKernel {
}
private:
int num_bits_;
bool narrow_range_;
float input_min_;
float input_max_;
float quant_min_;
float quant_max_;
float nudged_input_min_;
@ -184,25 +226,32 @@ class FakeQuantWithMinMaxVarsOp : public XlaOpKernel {
public:
explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
int num_bits;
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
OP_REQUIRES(ctx, num_bits_ >= 2 && num_bits_ <= 16,
errors::InvalidArgument("num_bits is out of range, expected "
"between 2 and 16, was: ",
num_bits));
bool narrow_range;
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
quant_min_ = narrow_range ? 1 : 0;
quant_max_ = (1 << num_bits) - 1;
num_bits_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
quant_min_ = narrow_range_ ? 1 : 0;
quant_max_ = (1 << num_bits_) - 1;
}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder();
xla::XlaOp input = ctx->Input(0);
const DataType data_type = ctx->input_type(0);
xla::XlaOp input_min = ctx->Input(1);
xla::XlaOp input_max = ctx->Input(2);
xla::XlaBuilder* b = ctx->builder();
if (ctx->compiler()->options().allow_cpu_custom_calls &&
ctx->compiler()->options().custom_fake_quant_op_calls) {
xla::XlaOp custom_call_output =
b->ReportErrorOrReturn(BuildFakeQuantCustomCall(
b, input, input_min, input_max, num_bits_, narrow_range_));
ctx->SetOutput(0, custom_call_output);
return;
}
xla::XlaOp nudged_input_min, nudged_input_max, input_scale;
XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
&nudged_input_min, &nudged_input_max, &input_scale);
@ -213,6 +262,8 @@ class FakeQuantWithMinMaxVarsOp : public XlaOpKernel {
}
private:
int num_bits_;
bool narrow_range_;
float quant_min_;
float quant_max_;
};

View File

@ -31,7 +31,9 @@ limitations under the License.
namespace tensorflow {
XlaArgMinMaxOp::XlaArgMinMaxOp(OpKernelConstruction* ctx, bool is_min)
: XlaOpKernel(ctx), is_min_(is_min) {}
: XlaOpKernel(ctx),
is_min_(is_min),
is_gpu_(ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT) {}
void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
const TensorShape input_shape = ctx->InputShape(0);
@ -64,11 +66,20 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
xla::XlaOp input = ctx->Input(0);
xla::XlaOp output;
// One pass ArgMin/ArgMax is slow on GPUs.
if (is_min_) {
if (is_gpu_) {
output = xla::ArgMinTwoPass(input, index_xla_type, axis);
} else {
output = xla::ArgMin(input, index_xla_type, axis);
}
} else {
if (is_gpu_) {
output = xla::ArgMaxTwoPass(input, index_xla_type, axis);
} else {
output = xla::ArgMax(input, index_xla_type, axis);
}
}
ctx->SetOutput(0, output);
}

View File

@ -30,6 +30,7 @@ class XlaArgMinMaxOp : public XlaOpKernel {
private:
const bool is_min_; // Are we computing ArgMin (true) or ArgMax (false)?
const bool is_gpu_;
};
class XlaArgMaxOp : public XlaArgMinMaxOp {

View File

@ -27,7 +27,8 @@ namespace tensorflow {
// `bit_generator` and converted to the requested data type and range. This
// routine requires 2 32-bit integer seeds and currently only supports 'shape's
// of type F32, S32 and S64.
xla::XlaOp StatelessRngUniform(xla::XlaOp seeds, const xla::Shape& shape,
xla::XlaOp StatelessRngUniform(absl::string_view device_type_string,
xla::XlaOp seeds, const xla::Shape& shape,
xla::XlaOp minval, xla::XlaOp maxval);
// Converts to bfloat16 if `dtype` equals DT_BFLOAT16, no-op otherwise.

View File

@ -51,6 +51,7 @@ class ReshapeOp : public XlaOpKernel {
TensorShape shape;
int64 product = 1;
int unknown_index = -1;
bool shape_has_zero_dim = false;
for (int d = 0; d < num_dims; ++d) {
const int32 size = shape_input[d];
if (size == -1) {
@ -60,6 +61,12 @@ class ReshapeOp : public XlaOpKernel {
unknown_index, " and ", d));
unknown_index = d;
shape.AddDim(1);
} else if (size == 0) {
// We don't include zero-sized dimension in product, so that we can
// still calculate number of elements for non-zero-sized dimensions and
// therefore infer their shapes.
shape.AddDim(size);
shape_has_zero_dim = true;
} else {
OP_REQUIRES(ctx, size >= 0,
errors::InvalidArgument(
@ -69,18 +76,28 @@ class ReshapeOp : public XlaOpKernel {
}
}
if (unknown_index != -1) {
int64 input_num_elements = 1;
bool input_has_zero_dim = false;
for (int dim = 0; dim < input_shape.dims(); dim++) {
// For zero dimension, we don't count it into `input_num_elements`
// unless `sizes` has no zero dimension, so we are still able to
// infer shapes for other dimensions.
if (input_shape.dim_size(dim) > 0 || !shape_has_zero_dim) {
input_num_elements *= input_shape.dim_size(dim);
} else {
input_has_zero_dim = true;
}
}
const int64 missing = input_num_elements / product;
if (!input_has_zero_dim) {
OP_REQUIRES(
ctx, product > 0,
errors::InvalidArgument("Reshape cannot infer the missing input size "
"for an empty tensor unless all specified "
"input sizes are non-zero"));
const int64 missing = input_shape.num_elements() / product;
OP_REQUIRES(
ctx, product * missing == input_shape.num_elements(),
ctx, product * missing == input_num_elements,
errors::InvalidArgument(
"Input to reshape is a tensor with ", input_shape.num_elements(),
"Input to reshape is a tensor with ", input_num_elements,
" values, but the requested shape requires a multiple of ",
product));
}
shape.set_dim(unknown_index, missing);
}
OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(),

View File

@ -16,6 +16,8 @@ limitations under the License.
// XLA-specific Shape Ops.
#include "tensorflow/compiler/tf2xla/kernels/shape_util.h"
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@ -223,14 +225,33 @@ class ZerosLikeOp : public XlaOpKernel {
explicit ZerosLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
if (IsTensorListInput(ctx, 0)) {
// Input is a TensorList.
// TODO(b/124707753): support nested TensorList.
xla::XlaOp tensor_list = ctx->Input(0);
TensorShape shape;
OP_REQUIRES_OK(ctx, GetTensorListBufferShape(tensor_list, &shape));
xla::PrimitiveType type;
OP_REQUIRES_OK(ctx, GetTensorListPrimitiveType(tensor_list, &type));
xla::XlaOp buffer;
OP_REQUIRES_OK(ctx, CreateZerosList(ctx, shape, type, &buffer));
xla::XlaOp push_index;
OP_REQUIRES_OK(ctx, GetTensorListPushIndex(tensor_list, &push_index));
xla::XlaOp output_list;
OP_REQUIRES_OK(ctx, BuildTensorList(buffer, push_index, &output_list));
ctx->SetTensorListOutput(0, output_list);
} else {
const TensorShape input_shape = ctx->InputShape(0);
auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes()));
}
}
};
REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp);
REGISTER_XLA_OP(Name("ZerosLike").AllowVariantTypes(), ZerosLikeOp);
class OnesLikeOp : public XlaOpKernel {
public:

View File

@ -33,6 +33,22 @@ limitations under the License.
namespace tensorflow {
namespace {
xla::BitGeneratorTy GetBitGeneratorForDevice(
absl::string_view device_type_string) {
// The Philox algorithm may cause performance regression on other devices.
// Turn on the Philox algorithm for the CPU and GPU backends only.
if (device_type_string == DEVICE_GPU_XLA_JIT ||
device_type_string == DEVICE_CPU_XLA_JIT) {
return xla::PhiloxBitGenerator;
}
return xla::ThreeFryBitGenerator;
}
} // namespace
xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
if (dtype == DT_BFLOAT16) {
xla::XlaBuilder* builder = input.builder();
@ -45,7 +61,8 @@ xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
}
}
xla::XlaOp StatelessRngUniform(xla::XlaOp seeds, const xla::Shape& shape,
xla::XlaOp StatelessRngUniform(absl::string_view device_type_string,
xla::XlaOp seeds, const xla::Shape& shape,
xla::XlaOp minval, xla::XlaOp maxval) {
xla::XlaBuilder* builder = seeds.builder();
@ -58,14 +75,16 @@ xla::XlaOp StatelessRngUniform(xla::XlaOp seeds, const xla::Shape& shape,
xla::PrimitiveType type = shape.element_type();
switch (type) {
case xla::F32:
return xla::UniformF32Distribution(key, initial_state,
xla::ThreeFryBitGenerator, minval,
maxval, shape)
return xla::UniformF32Distribution(
key, initial_state,
GetBitGeneratorForDevice(device_type_string), minval, maxval,
shape)
.value;
case xla::S32: // fall through
case xla::S64:
return UniformIntDistribution(key, initial_state,
xla::ThreeFryBitGenerator, minval, maxval,
return UniformIntDistribution(
key, initial_state,
GetBitGeneratorForDevice(device_type_string), minval, maxval,
shape)
.value;
break;
@ -82,7 +101,8 @@ namespace {
class StatelessRandomUniformOp : public XlaOpKernel {
public:
explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
: XlaOpKernel(ctx),
device_type_string_(ctx->device_type().type_string()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
}
@ -100,8 +120,9 @@ class StatelessRandomUniformOp : public XlaOpKernel {
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
xla::XlaOp uniform = StatelessRngUniform(
seed, xla_shape, xla::ConstantR0<float>(builder, 0.0),
xla::XlaOp uniform =
StatelessRngUniform(device_type_string_, seed, xla_shape,
xla::ConstantR0<float>(builder, 0.0),
xla::ConstantR0<float>(builder, 1.0));
uniform = MaybeConvertF32ToBF16(uniform, dtype_);
ctx->SetOutput(0, uniform);
@ -109,6 +130,7 @@ class StatelessRandomUniformOp : public XlaOpKernel {
private:
DataType dtype_;
string device_type_string_;
TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp);
};
@ -123,7 +145,8 @@ REGISTER_XLA_OP(Name("StatelessRandomUniform")
class StatelessRandomUniformIntOp : public XlaOpKernel {
public:
explicit StatelessRandomUniformIntOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
: XlaOpKernel(ctx),
device_type_string_(ctx->device_type().type_string()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
}
@ -150,13 +173,15 @@ class StatelessRandomUniformIntOp : public XlaOpKernel {
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
xla::XlaOp uniform = StatelessRngUniform(seed, xla_shape, minval, maxval);
xla::XlaOp uniform = StatelessRngUniform(device_type_string_, seed,
xla_shape, minval, maxval);
ctx->SetOutput(0, uniform);
}
private:
DataType dtype_;
string device_type_string_;
TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformIntOp);
};
@ -171,7 +196,8 @@ REGISTER_XLA_OP(Name("StatelessRandomUniformInt")
class StatelessRandomNormalOp : public XlaOpKernel {
public:
explicit StatelessRandomNormalOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
: XlaOpKernel(ctx),
device_type_string_(ctx->device_type().type_string()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
}
@ -195,8 +221,9 @@ class StatelessRandomNormalOp : public XlaOpKernel {
ShiftLeft(ConvertElementType(seed1, xla::U64),
ConstantR0WithType(builder, xla::U64, 32));
xla::XlaOp normal =
xla::NormalF32Distribution(key, initial_state,
xla::ThreeFryBitGenerator, xla_shape)
xla::NormalF32Distribution(
key, initial_state, GetBitGeneratorForDevice(device_type_string_),
xla_shape)
.value;
normal = MaybeConvertF32ToBF16(normal, dtype_);
ctx->SetOutput(0, normal);
@ -204,6 +231,7 @@ class StatelessRandomNormalOp : public XlaOpKernel {
private:
DataType dtype_;
string device_type_string_;
TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp);
};
@ -218,7 +246,8 @@ REGISTER_XLA_OP(Name("StatelessRandomNormal")
class StatelessTruncatedNormalOp : public XlaOpKernel {
public:
explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
: XlaOpKernel(ctx),
device_type_string_(ctx->device_type().type_string()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
}
@ -236,7 +265,7 @@ class StatelessTruncatedNormalOp : public XlaOpKernel {
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
xla::XlaOp uniform = StatelessRngUniform(
seed, xla_shape,
device_type_string_, seed, xla_shape,
xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
xla::One(builder, xla_shape.element_type()));
xla::XlaOp truncated_normal = TruncatedNormal(uniform);
@ -246,6 +275,7 @@ class StatelessTruncatedNormalOp : public XlaOpKernel {
private:
DataType dtype_;
string device_type_string_;
TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp);
};

View File

@ -47,16 +47,19 @@ class TensorListLengthOp : public XlaOpKernel {
explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaOp index;
OP_REQUIRES_OK(ctx, GetTensorListPushIndex(ctx->Input(0), &index));
ctx->SetOutput(0, index);
TensorShape buffer_shape;
OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &buffer_shape));
Tensor length_tensor(DT_INT32, {});
length_tensor.scalar<int32>()() =
static_cast<int32>(buffer_shape.dim_size(0));
ctx->SetConstantOutput(0, length_tensor);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(TensorListLengthOp);
};
REGISTER_XLA_OP(Name("TensorListLength"), TensorListLengthOp);
REGISTER_XLA_OP(Name("TensorListLength").IsMetadataOp(), TensorListLengthOp);
// Creates an empty list with size (leading_dim, *element_shape) if
// element_shape is known at compile time. Otherwise creates one with size

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape.h"
@ -35,6 +36,17 @@ Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index,
return Status::OK();
}
Status GetTensorListPrimitiveType(const xla::XlaOp& op,
xla::PrimitiveType* type) {
TF_RET_CHECK(op.builder());
TF_ASSIGN_OR_RETURN(const xla::Shape& list_tuple_shape,
op.builder()->GetShape(op));
xla::Shape buffer_shape =
xla::ShapeUtil::GetTupleElementShape(list_tuple_shape, 0);
*type = buffer_shape.element_type();
return Status::OK();
}
Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer) {
TF_RET_CHECK(op.builder());
*buffer = xla::GetTupleElement(op, 0);
@ -97,4 +109,12 @@ Status InitializeTensorList(const xla::XlaOp& uninitialized_list,
return BuildTensorList(new_buffer, push_index, output_list);
}
Status CreateZerosList(XlaOpKernelContext* ctx, const TensorShape& buffer_shape,
xla::PrimitiveType type, xla::XlaOp* list) {
auto zero =
xla::ConstantLiteral(ctx->builder(), xla::LiteralUtil::Zero(type));
*list = xla::Broadcast(zero, buffer_shape.dim_sizes());
return Status::OK();
}
} // namespace tensorflow

View File

@ -35,6 +35,10 @@ bool IsTensorListInput(XlaOpKernelContext* ctx, int index);
Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index,
xla::XlaOp* output_list);
// Returns XLA PrimitiveType for the TensorList.
Status GetTensorListPrimitiveType(const xla::XlaOp& op,
xla::PrimitiveType* type);
// Returns the buffer for the TensorList.
Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer);
@ -62,6 +66,10 @@ Status InitializeTensorList(const xla::XlaOp& uninitialized_list,
const TensorShape& buffer_shape,
xla::XlaOp* output_list);
// Returns a TensorList filled with zero.
Status CreateZerosList(XlaOpKernelContext* ctx, const TensorShape& buffer_shape,
xla::PrimitiveType type, xla::XlaOp* list);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_

View File

@ -529,7 +529,11 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
int resource_index = 0;
for (int i = 0; i < ctx->num_outputs(); ++i) {
if (ctx->input_type(i) != DT_RESOURCE) {
if (IsTensorListInput(ctx, i)) {
ctx->SetTensorListOutput(i, xla::GetTupleElement(while_result, i));
} else {
ctx->SetOutput(i, xla::GetTupleElement(while_result, i));
}
++resource_index;
} else {
break;

View File

@ -1,5 +1,5 @@
package(
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
default_visibility = ["//tensorflow:internal"],
)
licenses(["notice"]) # Apache 2.0

View File

@ -306,26 +306,6 @@ dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
einsum = gen_xla_ops.xla_einsum
@ops.RegisterGradient('XlaEinsum')
def _einsum_grad(op, grad):
equation = op.get_attr('equation')
inputs, output = equation.split('->')
left, right = inputs.split(',')
return [
gen_xla_ops.xla_einsum(
grad,
op.inputs[1],
equation='{},{}->{}'.format(output, right, left),
name=None),
gen_xla_ops.xla_einsum(
grad,
op.inputs[0],
equation='{},{}->{}'.format(output, left, right),
name=None)
]
# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
# the XLA-specific pad operator.
pad = gen_xla_ops.xla_pad

View File

@ -0,0 +1,729 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/rearrange_function_argument_pass.h"
#include <algorithm>
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {
namespace {
// Given original input types and argument index mapping, return the new input
// types.
std::vector<DataType> ShuffleInputDataTypeAttribute(
const std::vector<DataType>& in_types,
const std::vector<int>& index_mapping) {
std::vector<DataType> result(index_mapping.size());
for (int i = 0; i < in_types.size(); i++) {
result[index_mapping.at(i)] = in_types[i];
}
return result;
}
// Given original input types, check if we need to rewrite the function (by
// checking if all DT_RESOURCE inputs are in the end). If the function needs to
// be rewritten, `resource_input_count` will be set to number of DT_RESOURCE
// inputs, and `index_mapping` will hold a mapping for original input index to
// rearranged input index.
Status InputTypesNeedsRearrange(const std::vector<DataType>& in_types,
bool* need_rewrite, int* resource_input_count,
std::vector<int>* index_mapping) {
int first_resource_index = -1;
for (int i = 0; i < in_types.size(); i++) {
DataType type = in_types[i];
if (type == DT_RESOURCE) {
first_resource_index = i;
break;
}
}
if (first_resource_index == -1) {
// No resource input. No need to rewrite.
*need_rewrite = false;
return Status::OK();
}
*need_rewrite = false;
for (int i = first_resource_index + 1; i < in_types.size(); i++) {
if (in_types[i] != DT_RESOURCE) {
*need_rewrite = true;
break;
}
}
if (!*need_rewrite) {
return Status::OK();
}
*resource_input_count = 0;
for (int i = 0; i < in_types.size(); i++) {
DataType type = in_types[i];
if (type == DT_RESOURCE) {
++(*resource_input_count);
}
}
int non_resource_index = 0,
resource_index = in_types.size() - *resource_input_count;
index_mapping->resize(in_types.size());
for (int i = 0; i < in_types.size(); i++) {
if (in_types[i] != DT_RESOURCE) {
(*index_mapping)[i] = non_resource_index;
non_resource_index++;
} else {
(*index_mapping)[i] = resource_index;
resource_index++;
}
}
return Status::OK();
}
// Given mapping between original input index and rearranged input index,
// reorder input edges for the node.
Status ReorderInputEdges(Graph* g, Node* n,
const std::vector<int>& index_mapping) {
std::vector<const Edge*> input_edges;
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) {
continue;
}
input_edges.push_back(e);
}
for (const Edge* e : input_edges) {
Node* src = e->src();
int src_output = e->src_output();
int dst_input = e->dst_input();
int new_dst_input = index_mapping.at(dst_input);
g->RemoveEdge(e);
g->AddEdge(src, src_output, n, new_dst_input)->DebugString();
}
return Status::OK();
}
// For While node, given mapping between original input index and rearranged
// input index, reorder output edges for the node. DT_RESOURCE outputs are
// removed from the node and we will use the node's corresponding input for the
// edge.
Status ReorderOutputEdges(Graph* g, Node* n, int input_count,
int resource_input_count,
const std::vector<int>& index_mapping) {
std::vector<const Edge*> output_edges;
for (const Edge* e : n->out_edges()) {
if (e->IsControlEdge()) {
continue;
}
output_edges.push_back(e);
}
for (const Edge* e : output_edges) {
int src_output = e->src_output();
int new_src_output = index_mapping.at(src_output);
Node* dst = e->dst();
int dst_input = e->dst_input();
g->RemoveEdge(e);
if (new_src_output < input_count - resource_input_count) {
g->AddEdge(n, new_src_output, dst, dst_input);
} else {
const Edge* input_edge;
TF_RETURN_IF_ERROR(n->input_edge(new_src_output, &input_edge));
g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
}
}
return Status::OK();
}
// Given mapping between original input index and rearranged input index, change
// "index" attribute for _Arg nodes.
void RearrangeArgNodes(gtl::InlinedVector<Node*, 4>* arg_nodes, // non-absl ok
const std::vector<int>& index_mapping) {
for (int i = 0; i < arg_nodes->size(); i++) {
Node* n = (*arg_nodes)[i];
int new_index = index_mapping.at(i);
n->ClearAttr("index");
n->AddAttr("index", new_index);
}
}
// Given all _Retval nodes in the function, return if we need to rewrite the
// function (by checking if we have DT_RESOURCE return values). If we need to
// rewrite the function, `retval_index_mapping` will hold the mapping from
// original _Retval to rearranged _Retval, and `resource_retval_to_arg` will
// hold mapping from DT_RESOURCE _Retval index to its input _Arg index. Here we
// assume that all DT_RESOURCE _Retval nodes come from _Arg nodes directly.
Status CalculateRetvalRearrange(
const gtl::InlinedVector<Node*, 4>& ret_nodes, // non-absl ok
std::map<int, int>* retval_index_mapping,
std::map<int, int>* resource_retval_to_arg) {
for (int i = 0; i < ret_nodes.size(); i++) {
Node* n = ret_nodes[i];
DataType t;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &t));
if (t != DT_RESOURCE) {
int new_retval_index = retval_index_mapping->size();
retval_index_mapping->insert(std::make_pair(i, new_retval_index));
continue;
}
const Edge* e;
TF_RETURN_IF_ERROR(n->input_edge(0, &e));
if (!e->src()->IsArg()) {
return errors::Unimplemented(
"Resource _Retval node's input does not come from _Arg "
"directly: ",
e->DebugString());
}
Node* arg = e->src();
int src_index;
TF_RETURN_IF_ERROR(GetNodeAttr(arg->def(), "index", &src_index));
resource_retval_to_arg->insert(std::make_pair(i, src_index));
}
return Status::OK();
}
// Given original output types and return value index mapping, return the new
// output types. Notice that DT_RESOURCE will be removed.
std::vector<DataType> ShuffleOutputDataTypeAttribute(
const std::vector<DataType>& out_types,
const std::map<int, int>& index_mapping) {
std::vector<DataType> result(index_mapping.size());
for (int i = 0; i < out_types.size(); i++) {
auto iter = index_mapping.find(i);
if (iter != index_mapping.end()) {
result[iter->second] = out_types[i];
}
}
return result;
}
// For StatefulPartitionedCall node, given mapping between original input index
// and rearranged input index, reorder output edges for the node. DT_RESOURCE
// outputs are removed from the node and we will use the node's corresponding
// input for the edge.
Status RearrangeOutputEdges(Node* n, Graph* g,
const std::map<int, int>& retval_index_mapping,
const std::map<int, int>& resource_retval_to_arg) {
std::vector<const Edge*> out_edges;
for (const Edge* e : n->out_edges()) {
if (!e->IsControlEdge()) {
out_edges.push_back(e);
}
}
for (const Edge* e : out_edges) {
Node* dst = e->dst();
int dst_input = e->dst_input();
int src_output = e->src_output();
auto iter = retval_index_mapping.find(src_output);
if (iter == retval_index_mapping.end()) {
TF_RET_CHECK(resource_retval_to_arg.find(src_output) !=
resource_retval_to_arg.end());
g->RemoveEdge(e);
const Edge* input_edge;
TF_RETURN_IF_ERROR(
n->input_edge(resource_retval_to_arg.at(src_output), &input_edge));
g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
} else {
g->RemoveEdge(e);
g->AddEdge(n, iter->second, dst, dst_input);
}
}
return Status::OK();
}
// Given mapping between original output index and rearranged output index,
// change "index" attribute for _Retval nodes. Notice that DT_RESOURCE _Retval
// nodes will be removed.
void RearrangeRetvalNodes(
const gtl::InlinedVector<Node*, 4>& ret_nodes, // non-absl ok
Graph* g, const std::map<int, int>& retval_index_mapping) {
for (int i = 0; i < ret_nodes.size(); i++) {
Node* n = ret_nodes[i];
auto iter = retval_index_mapping.find(i);
if (iter == retval_index_mapping.end()) {
g->RemoveNode(n);
} else {
n->ClearAttr("index");
n->AddAttr("index", iter->second);
}
}
}
Status MaybeRewriteWhileNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
bool* node_rewritten) {
// Check if this While node needs rewrite.
std::vector<DataType> types;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &types));
bool input_need_rearrange;
int resource_input_count;
std::vector<int> index_mapping;
TF_RETURN_IF_ERROR(InputTypesNeedsRearrange(
types, &input_need_rearrange, &resource_input_count, &index_mapping));
if (!input_need_rearrange) {
*node_rewritten = false;
return Status::OK();
}
*node_rewritten = true;
// Modify "T" attribute for this While node.
std::vector<DataType> new_types =
ShuffleInputDataTypeAttribute(types, index_mapping);
n->ClearAttr("T");
n->AddAttr("T", new_types);
// Reorder input and output edges.
TF_RETURN_IF_ERROR(ReorderInputEdges(g, n, index_mapping));
TF_RETURN_IF_ERROR(ReorderOutputEdges(g, n, types.size(),
resource_input_count, index_mapping));
// Modify cond and body functions.
for (auto const& attr_name : std::vector<string>{"cond", "body"}) {
NameAttrList attr_value;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &attr_value));
const FunctionDef* fdef = fld->Find(attr_value.name());
TF_RET_CHECK(fdef != nullptr);
std::unique_ptr<FunctionBody> fbody;
TF_RETURN_IF_ERROR(
FunctionDefToBodyHelper(*fdef, AttrSlice(), fld, &fbody));
RearrangeArgNodes(&fbody->arg_nodes, index_mapping);
if (attr_name == "body") {
for (int i = 0; i < fbody->ret_nodes.size(); i++) {
Node* n = fbody->ret_nodes[i];
int new_index = index_mapping.at(i);
if (new_index < types.size() - resource_input_count) {
n->ClearAttr("index");
n->AddAttr("index", new_index);
} else {
fbody->graph->RemoveNode(n);
}
}
}
// Save the new FunctionDef.
FunctionDef new_fdef;
string new_name =
fld->UniqueFunctionName(absl::StrCat(attr_value.name(), "_rearrange_"));
TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef));
TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef));
// Change node to use rewritten function.
attr_value.set_name(new_name);
n->ClearAttr(attr_name);
n->AddAttr(attr_name, attr_value);
}
return Status::OK();
}
Status MaybeRewriteCallNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
bool* node_rewritten) {
// This node needs rewrite when either of these is true:
// 1) Tin has DT_RESOURCE which requires rearrange;
// 2) Tout has DT_RESOURCE.
std::vector<DataType> in_types;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tin", &in_types));
bool input_need_rearrange;
int resource_input_count;
std::vector<int> index_mapping;
TF_RETURN_IF_ERROR(InputTypesNeedsRearrange(
in_types, &input_need_rearrange, &resource_input_count, &index_mapping));
std::vector<DataType> out_types;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tout", &out_types));
bool has_resource_output = std::find(out_types.begin(), out_types.end(),
DT_RESOURCE) != out_types.end();
if (!resource_input_count && !has_resource_output) {
*node_rewritten = false;
return Status::OK();
}
*node_rewritten = true;
string attr_name = "f";
NameAttrList f;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &f));
const FunctionDef* fdef = fld->Find(f.name());
TF_RET_CHECK(fdef != nullptr);
std::unique_ptr<FunctionBody> fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(), fld, &fbody));
if (input_need_rearrange) {
// Reorder input edges.
TF_RETURN_IF_ERROR(ReorderInputEdges(g, n, index_mapping));
// Change Tin attribute.
std::vector<DataType> new_in_types =
ShuffleInputDataTypeAttribute(in_types, index_mapping);
n->ClearAttr("Tin");
n->AddAttr("Tin", new_in_types);
// Change _Arg node index.
RearrangeArgNodes(&fbody->arg_nodes, index_mapping);
}
if (has_resource_output) {
// Resource _Retval must come from resource _Arg directly, or we do not
// support it.
std::map<int, int> resource_retval_to_arg, retval_index_mapping;
TF_RETURN_IF_ERROR(CalculateRetvalRearrange(
fbody->ret_nodes, &retval_index_mapping, &resource_retval_to_arg));
// Rearrange output edges.
TF_RETURN_IF_ERROR(RearrangeOutputEdges(n, g, retval_index_mapping,
resource_retval_to_arg));
// Change Tout attribute for the node.
std::vector<DataType> new_out_types =
ShuffleOutputDataTypeAttribute(out_types, retval_index_mapping);
n->ClearAttr("Tout");
n->AddAttr("Tout", new_out_types);
// Change index for _Retval nodes.
RearrangeRetvalNodes(fbody->ret_nodes, fbody->graph, retval_index_mapping);
}
// Save the new FunctionDef.
FunctionDef new_fdef;
string new_name =
fld->UniqueFunctionName(absl::StrCat(f.name(), "_rearrange_"));
TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef));
TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef));
// Change node to use rewritten function.
f.set_name(new_name);
n->ClearAttr(attr_name);
n->AddAttr(attr_name, f);
return Status::OK();
}
Status MaybeRewriteIfNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
bool* node_rewritten) {
// This node needs rewrite when either of these is true:
// 1) Tin has DT_RESOURCE which requires rearrange;
// 2) Tout has DT_RESOURCE.
std::vector<DataType> in_types;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tin", &in_types));
bool input_need_rearrange;
int resource_input_count;
std::vector<int> index_mapping;
TF_RETURN_IF_ERROR(InputTypesNeedsRearrange(
in_types, &input_need_rearrange, &resource_input_count, &index_mapping));
std::vector<DataType> out_types;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tout", &out_types));
bool has_resource_output = std::find(out_types.begin(), out_types.end(),
DT_RESOURCE) != out_types.end();
if (!input_need_rearrange && !has_resource_output) {
*node_rewritten = false;
return Status::OK();
}
*node_rewritten = true;
if (input_need_rearrange) {
// Reorder input edges.
std::vector<const Edge*> input_edges;
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge() || e->dst_input() == 0) {
continue;
}
input_edges.push_back(e);
}
for (const Edge* e : input_edges) {
Node* src = e->src();
int src_output = e->src_output();
int dst_input = e->dst_input();
int new_dst_input = index_mapping.at(dst_input - 1) + 1;
g->RemoveEdge(e);
g->AddEdge(src, src_output, n, new_dst_input)->DebugString();
}
// Change Tin attribute.
std::vector<DataType> new_in_types =
ShuffleInputDataTypeAttribute(in_types, index_mapping);
n->ClearAttr("Tin");
n->AddAttr("Tin", new_in_types);
}
std::map<int, int> resource_retval_to_arg, retval_index_mapping;
for (auto const& attr_name :
std::vector<string>{"then_branch", "else_branch"}) {
NameAttrList f;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &f));
const FunctionDef* fdef = fld->Find(f.name());
TF_RET_CHECK(fdef != nullptr);
std::unique_ptr<FunctionBody> fbody;
TF_RETURN_IF_ERROR(
FunctionDefToBodyHelper(*fdef, AttrSlice(), fld, &fbody));
if (input_need_rearrange) {
// Change _Arg node index.
RearrangeArgNodes(&fbody->arg_nodes, index_mapping);
}
if (has_resource_output) {
// Resource _Retval must come from resource _Arg directly, or we do
// not support it.
TF_RETURN_IF_ERROR(CalculateRetvalRearrange(
fbody->ret_nodes, &retval_index_mapping, &resource_retval_to_arg));
// Change index for _Retval nodes.
RearrangeRetvalNodes(fbody->ret_nodes, fbody->graph,
retval_index_mapping);
}
// Save the new FunctionDef.
FunctionDef new_fdef;
string new_name =
fld->UniqueFunctionName(absl::StrCat(f.name(), "_rearrange_"));
TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef));
TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef));
// Change node to use rewritten function.
f.set_name(new_name);
n->ClearAttr(attr_name);
n->AddAttr(attr_name, f);
}
if (has_resource_output) {
// Rearrange output edges.
std::vector<const Edge*> out_edges;
for (const Edge* e : n->out_edges()) {
if (!e->IsControlEdge()) {
out_edges.push_back(e);
}
}
for (const Edge* e : out_edges) {
Node* dst = e->dst();
int dst_input = e->dst_input();
int src_output = e->src_output();
auto iter = retval_index_mapping.find(src_output);
if (iter == retval_index_mapping.end()) {
TF_RET_CHECK(resource_retval_to_arg.find(src_output) !=
resource_retval_to_arg.end());
g->RemoveEdge(e);
const Edge* input_edge;
TF_RETURN_IF_ERROR(n->input_edge(
resource_retval_to_arg.at(src_output) + 1, &input_edge));
g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
} else {
g->RemoveEdge(e);
g->AddEdge(n, iter->second, dst, dst_input);
}
}
// Change Tout attribute for the node.
std::vector<DataType> new_out_types =
ShuffleOutputDataTypeAttribute(out_types, retval_index_mapping);
n->ClearAttr("Tout");
n->AddAttr("Tout", new_out_types);
}
return Status::OK();
}
} // namespace
Status RearrangeFunctionArgumentForFunction(
const string& func_name, const string& new_func_name,
const protobuf::Map<string, tensorflow::AttrValue>& attrs,
FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
std::map<string, absl::optional<string>>* canonicalized_name_to_new_name,
bool* modified) {
*modified = false;
// Convert the function to Graph.
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle));
Status ret_status = Status::OK();
auto cleanup_handle = gtl::MakeCleanup([&]() {
auto s = flr->ReleaseHandle(handle);
if (!s.ok()) {
ret_status.Update(s);
}
});
const FunctionBody* body = flr->GetFunctionBody(handle);
Graph* g = body->graph;
// If any node has associated functions, rewrite them first.
// Gather nodes with associated functions first, because rewriting those nodes
// might involve node deletion/addition. Avoid modifying nodes while iterating
// it.
std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
nodes_to_associated_functions;
for (auto* n : g->nodes()) {
auto associated_functions = GetAssociatedFunctions(*n, fld);
if (!associated_functions.empty()) {
nodes_to_associated_functions.push_back({n, associated_functions});
}
}
for (auto iter : nodes_to_associated_functions) {
Node* n = iter.first;
auto associated_functions = iter.second;
for (auto& associated_function : associated_functions) {
string name = associated_function.func_name();
string canonicalized_name =
Canonicalize(name, AttrSlice(&associated_function.attrs()));
auto iter = canonicalized_name_to_new_name->find(canonicalized_name);
string new_name;
bool function_modified;
if (iter != canonicalized_name_to_new_name->end()) {
// If we already processed this function, check if it was rewritten. If
// the function was rewritten, the entry will be non-empty. Otherwise
// the entry will be empty.
function_modified = iter->second.has_value();
if (function_modified) {
new_name = iter->second.value();
}
} else {
if (associated_function.type() ==
AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) {
// For SymbolicGradient, `name` is always "SymbolicGradient",
// which is not very informative. Use node name instead.
new_name =
fld->UniqueFunctionName(absl::StrCat(n->name(), "_rearrange_"));
} else {
new_name = fld->UniqueFunctionName(absl::StrCat(name, "_rearrange_"));
}
TF_RETURN_IF_ERROR(RearrangeFunctionArgumentForFunction(
name, new_name, associated_function.attrs(), fld, flr,
canonicalized_name_to_new_name, &function_modified));
if (function_modified) {
// If the function was rewritten, add an non-empty entry. So later we
// know we have processed this function, and it was rewritten into
// another function.
(*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
} else {
// If the function was not rewritten, add an empty entry. So later
// we know we have processed this function, and it does not need to be
// rewritten.
(*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt;
}
}
if (function_modified) {
*modified = true;
// Notice that if "n" is a function call, RewriteAssociatedFunction()
// will delete it and create a new node instead, making "n" an invalid
// pointer. That's fine because in that case, associated_functions will
// only have one member and the loop will only run once.
TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
g, n, fld, associated_function, new_name));
}
}
}
for (Node* n : g->nodes()) {
if (n->type_string() == "While") {
bool node_rewritten;
TF_RETURN_IF_ERROR(MaybeRewriteWhileNode(g, n, fld, &node_rewritten));
if (node_rewritten) {
*modified = true;
}
} else if (n->type_string() == "StatefulPartitionedCall") {
bool node_rewritten;
TF_RETURN_IF_ERROR(MaybeRewriteCallNode(g, n, fld, &node_rewritten));
if (node_rewritten) {
*modified = true;
}
} else if (n->type_string() == "If") {
bool node_rewritten;
TF_RETURN_IF_ERROR(MaybeRewriteIfNode(g, n, fld, &node_rewritten));
if (node_rewritten) {
*modified = true;
}
}
}
if (*modified) {
// Add rewritten FunctionDef into library.
FunctionDef functionalized_fdef;
TF_RETURN_IF_ERROR(
GraphToFunctionDef(*g, new_func_name, &functionalized_fdef));
if (func_name == new_func_name) {
VLOG(2) << "Replacing function " << func_name;
TF_RETURN_IF_ERROR(
fld->ReplaceFunction(new_func_name, functionalized_fdef));
} else {
VLOG(2) << "Adding function " << new_func_name;
TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
}
}
return ret_status;
} // namespace tensorflow
Status RearrangeFunctionArgumentPass::Run(
const GraphOptimizationPassOptions& options) {
Graph* graph = options.graph->get();
if (VLOG_IS_ON(4)) {
DumpGraphToFile("rearrange_function_argument_before", *graph,
options.flib_def);
}
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(
/*device_mgr=*/nullptr, options.session_options->env,
TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions()));
FunctionLibraryRuntime* flr =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
// Find XLA compile ops and its corresponding FunctionDef.
static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
new std::map<string, string>{
// TPUReplicate ops are generated by EncapsulateTPUComputationsPass.
{"TPUReplicate", "computation"},
// XlaLaunch ops are generated by EncapsulateXlaComputationsPass.
{"XlaLaunch", "function"},
};
std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
for (Node* n : graph->nodes()) {
auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
if (it == kNodeTypeToFunctionAttrMapping->end()) {
continue;
}
const string func_attr = it->second;
NameAttrList func;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func));
VLOG(2) << "Graph has node " << n->type_string()
<< ". Corresponding function: " << func.name();
string new_func_name = options.flib_def->UniqueFunctionName(
absl::StrCat(func.name(), "_rearrange_"));
bool modified = false;
TF_RETURN_IF_ERROR(RearrangeFunctionArgumentForFunction(
func.name(), new_func_name, func.attr(), options.flib_def, flr,
&canonicalized_name_to_new_name, &modified));
if (modified) {
n->ClearAttr(func_attr);
func.set_name(new_func_name);
n->AddAttr(func_attr, func);
}
}
if (VLOG_IS_ON(4)) {
DumpGraphToFile("rearrange_function_argument_after", *graph,
options.flib_def);
}
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,50 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_PASS_H_
#define TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_PASS_H_
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
// For the function with `func_name`, rewrite any
// StatefulPartitionedCall/If/While node that does not satisfy the rules.
// We will rewrite related FunctionDef to rearrange arguments and return values,
// also adjust node's input/output edges accordingly.
Status RearrangeFunctionArgumentForFunction(
const string& func_name, const string& new_func_name,
const protobuf::Map<string, tensorflow::AttrValue>& attrs,
FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
std::map<string, absl::optional<string>>* canonicalized_name_to_new_name,
bool* modified);
// TF/XLA bridge expects FunctionDef to satisfy the following rules:
// 1. DT_RESOURCE arguments are always in the last;
// 2. Do not return DT_RESOURCE as return values.
// But functions defined by Tensorflow might not satisfy them.
// This rewrite pass rewrites the function for TPUCompile/XlaLaunch node
// to follow the rules, using RearrangeFunctionArgumentForFunction() above.
class RearrangeFunctionArgumentPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_PASS_H_

View File

@ -1,4 +1,4 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,17 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/accuracy/file_reader_stage.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/rearrange_function_argument_pass.h"
namespace tensorflow {
namespace metrics {
void FileReaderStage::AddToGraph(const Scope& scope, const Input& input) {
if (!scope.ok()) return;
Scope s = scope.WithOpName(name());
this->stage_output_ = ops::ReadFile(s.WithOpName(output_name()), input);
}
} // namespace metrics
// This pass is required for some AOT backends and all JIT backends, so this
// file exists as a separate lib and will be linked to both AOT and JIT.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 28,
RearrangeFunctionArgumentPass);
} // namespace tensorflow

View File

@ -295,6 +295,8 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
compiler_options.flib_def = &graph->flib_def();
compiler_options.graph_def_version = graph->versions().producer();
compiler_options.allow_cpu_custom_calls = true;
compiler_options.custom_fake_quant_op_calls =
config.conversion_options().custom_fake_quant_op_calls();
XlaCompiler compiler(compiler_options);
XlaCompiler::CompilationResult result;

View File

@ -53,6 +53,14 @@ message Variable {
bool readonly = 5;
}
// Options used during the conversion and compilation process.
message ConversionOptions {
// When true tf.fake_quant_* ops will be emitted as custom calls to a
// 'fake_quant_with_min_max_vars' function accepting the input, min, max,
// num_bits, and narrow_range values as runtime arguments.
bool custom_fake_quant_op_calls = 1;
}
// Config represents configuration information for tf2xla conversion.
message Config {
// Each feed is a positional input argument for the generated computation.
@ -63,4 +71,6 @@ message Config {
repeated Fetch fetch = 2;
// Each variable is a named input and output of the generated computation.
repeated Variable variable = 3;
// Optional conversion options.
ConversionOptions conversion_options = 4;
}

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
@ -200,8 +201,13 @@ Status BuildComputation(
output.shape = output.constant_value.shape();
break;
case XlaExpression::Kind::kTensorList:
TF_FALLTHROUGH_INTENDED;
case XlaExpression::Kind::kTensorList: {
output.is_tensor_list = true;
xla::XlaOp value = retval.handle();
elems.push_back(value);
break;
}
case XlaExpression::Kind::kXlaOp: {
output.is_constant = false;
TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
@ -403,6 +409,8 @@ string XlaCompiler::Argument::HumanString() const {
}
case kParameter:
return absl::StrCat("kind=parameter", common);
case kTensorList:
return absl::StrCat("kind=tensorlist", common);
case kToken:
return absl::StrCat("token", common);
}
@ -641,6 +649,11 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
}
return Status::OK();
}
case XlaCompiler::Argument::kTensorList: {
TF_RET_CHECK(absl::holds_alternative<xla::Shape>(arg.shape));
*xla_shape = absl::get<xla::Shape>(arg.shape);
return Status::OK();
}
case XlaCompiler::Argument::kResource: {
TF_RET_CHECK(arg.initialized);
@ -744,6 +757,7 @@ Status XlaCompiler::BuildArguments(
break;
}
case XlaCompiler::Argument::kParameter:
case XlaCompiler::Argument::kTensorList:
case XlaCompiler::Argument::kToken: {
input_to_args->push_back(i);
break;
@ -902,6 +916,10 @@ Status XlaCompiler::BuildArguments(
arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
}
break;
case XlaCompiler::Argument::kTensorList: {
arg_expression = XlaExpression::TensorList(arg_handles[i]);
break;
}
case XlaCompiler::Argument::kToken: {
arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
break;

View File

@ -116,6 +116,9 @@ class XlaCompiler {
// Argument is an XLA token.
kToken,
// Argument is a TensorList.
kTensorList,
};
Kind kind = kInvalid;
@ -226,6 +229,9 @@ class XlaCompiler {
// When this output is a resource, i.e. `type == DT_RESOURCE`, this is
// the index of the input that contains the resource.
int input_index;
// Whether this output is a TensorList.
bool is_tensor_list = false;
};
// Describes a variable write side effect of the computation.
@ -305,6 +311,12 @@ class XlaCompiler {
// for CPU.
bool allow_cpu_custom_calls = false;
// If both this and 'allow_cpu_custom_calls' are true then tf.fake_quant_*
// ops will be emitted as custom calls to a 'fake_quant_with_min_max_vars'
// function accepting the input, min, max, num_bits, and narrow_range values
// as runtime arguments.
bool custom_fake_quant_op_calls = false;
// If set, the XLA representation of variables represented to XLA as the
// shape given by this shape function. Variables are reshaped to this shape
// on write, and reshaped to their original shape on read.

View File

@ -14,10 +14,15 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
@ -35,7 +40,9 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
@ -1498,5 +1505,76 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) {
}
}
TEST_F(XlaCompilerTest, OpsWithTensorListInput) {
FunctionDefLibrary fdef_lib;
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
// Build cond fn for While.
{
Scope scope = Scope::NewRootScope().ExitOnError();
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
auto result = ops::Const<bool>(scope, {true}, {});
ops::_Retval(scope.WithOpName("ret"), result, 0);
TF_ASSERT_OK(scope.ToGraph(graph.get()));
FunctionDef fdef;
TF_ASSERT_OK(GraphToFunctionDef(*graph, "cond", &fdef));
TF_ASSERT_OK(flib_def.AddFunctionDef(fdef));
}
// Build body fn for While.
{
Scope scope = Scope::NewRootScope().ExitOnError();
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
auto arg = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
ops::_Retval(scope.WithOpName("ret"), arg, 0);
TF_ASSERT_OK(scope.ToGraph(graph.get()));
FunctionDef fdef;
TF_ASSERT_OK(GraphToFunctionDef(*graph, "body", &fdef));
TF_ASSERT_OK(flib_def.AddFunctionDef(fdef));
}
Scope scope = Scope::NewRootScope().ExitOnError();
auto element_shape = ops::Const<int32>(scope, {1}, {1});
auto max_elements = ops::Const<int32>(scope, {10}, {});
auto arg = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
std::initializer_list<Output> out = {arg, arg};
auto add_n = ops::AddN(scope, out);
NameAttrList cond_fn, body_fn;
cond_fn.set_name("cond");
body_fn.set_name("body");
auto while_op =
ops::While(scope, std::initializer_list<Input>{arg}, cond_fn, body_fn);
auto ret0 = ops::_Retval(scope.WithOpName("ret0"), add_n, 0);
auto ret1 = ops::_Retval(scope.WithOpName("ret1"), while_op.output[0], 1);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(scope.ToGraph(graph.get()));
// Builds a description of the arguments.
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kTensorList;
xla::Shape tensor_list_element_shape;
TF_ASSERT_OK(TensorShapeToXLAShape(DT_INT32, TensorShape{1},
&tensor_list_element_shape));
xla::Shape index_shape;
TF_ASSERT_OK(TensorShapeToXLAShape(DT_INT32, TensorShape{}, &index_shape));
std::vector<xla::Shape> shapes{tensor_list_element_shape, index_shape};
xla::Shape arg_shape = xla::ShapeUtil::MakeTupleShape(shapes);
args[0].shape = arg_shape;
// Compiles the graph.
XlaCompiler::Options options = DefaultOptions();
options.flib_def = &flib_def;
XlaCompiler compiler(options);
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args,
/*user_aliases=*/{}, &result));
ASSERT_EQ(result.outputs.size(), 2);
const XlaCompiler::OutputDescription& output0 = result.outputs[0];
ASSERT_TRUE(output0.is_tensor_list);
const XlaCompiler::OutputDescription& output1 = result.outputs[1];
ASSERT_TRUE(output1.is_tensor_list);
}
} // namespace
} // namespace tensorflow

View File

@ -18,7 +18,8 @@ package_group(
],
)
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
# Required for open-source build.
load("//tensorflow:tensorflow.bzl", "cc_header_only_library") # @unused
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
load(
@ -85,8 +86,8 @@ cc_library(
],
visibility = [":friends"],
deps = [
":debug_options_flags",
":xla_proto",
"//tensorflow/compiler/xla:debug_options_flags",
],
)
@ -190,6 +191,7 @@ cc_library(
":xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@ -812,7 +814,7 @@ cc_library(
hdrs = ["parse_flags_from_env.h"],
deps =
[
"//tensorflow/compiler/xla:types",
":types",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
@ -827,7 +829,7 @@ tf_cc_test(
deps =
[
":parse_flags_from_env",
"//tensorflow/compiler/xla:types",
":types",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@ -847,7 +849,7 @@ cc_library(
[
":parse_flags_from_env",
":status",
"//tensorflow/compiler/xla:xla_proto",
":xla_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:framework_internal",
"@com_google_absl//absl/strings",
@ -871,7 +873,7 @@ tf_cc_test(
],
deps =
[
"//tensorflow/compiler/xla:xla_proto",
":xla_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",

View File

@ -121,6 +121,21 @@ cc_library(
],
)
cc_library(
name = "conv_op_helpers",
srcs = ["conv_op_helpers.cc"],
hdrs = ["conv_op_helpers.h"],
deps = [
":arithmetic",
":constants",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:xla_builder",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "loops",
srcs = ["loops.cc"],

View File

@ -125,8 +125,60 @@ XlaOp Any(XlaOp predicates) {
namespace {
XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder,
PrimitiveType value_type,
PrimitiveType index_type, bool is_min) {
auto sub_builder = outer_builder->CreateSubBuilder("minmax_func");
XlaBuilder* b = sub_builder.get();
XlaOp lhs_value =
Parameter(b, 0, ShapeUtil::MakeShape(value_type, {}), "lhs_value");
XlaOp lhs_index =
Parameter(b, 1, ShapeUtil::MakeShape(index_type, {}), "lhs_index");
XlaOp rhs_value =
Parameter(b, 2, ShapeUtil::MakeShape(value_type, {}), "rhs_value");
XlaOp rhs_index =
Parameter(b, 3, ShapeUtil::MakeShape(index_type, {}), "rhs_index");
auto cmp = is_min ? Lt(lhs_value, rhs_value) : Gt(lhs_value, rhs_value);
XlaOp max = Select(cmp, lhs_value, rhs_value);
XlaOp arg_max = Select(cmp, lhs_index, rhs_index);
Tuple(b, {max, arg_max});
return b->Build().ConsumeValueOrDie();
}
XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
XlaBuilder* builder = input.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
XlaOp value_init_value;
if (is_min) {
value_init_value = MaxValue(builder, input_shape.element_type());
} else {
value_init_value = MinValue(builder, input_shape.element_type());
}
int64 dimension_size = input_shape.dimensions(axis);
auto index_type = dimension_size <= INT32_MAX ? S32 : output_type;
XlaOp index_init_value = Zero(builder, index_type);
auto iota_shape = input_shape;
iota_shape.set_element_type(index_type);
XlaOp iota = Iota(builder, iota_shape, axis);
XlaComputation reducer = CreateMinMaxComputation(
builder, input_shape.element_type(), index_type, is_min);
XlaOp max_argmax = Reduce(builder, {input, iota},
{value_init_value, index_init_value}, reducer,
/*dimensions_to_reduce=*/{axis});
XlaOp argmax = GetTupleElement(max_argmax, 1);
if (index_type != output_type) {
argmax = ConvertElementType(argmax, output_type);
}
return argmax;
});
}
XlaOp ArgMinMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
bool is_min) {
XlaBuilder* builder = input.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
XlaOp init_value;
@ -172,7 +224,6 @@ XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
/*dimensions_to_reduce=*/{axis});
});
}
} // namespace
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis) {
@ -183,4 +234,11 @@ XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis) {
return ArgMinMax(input, output_type, axis, /*is_min=*/true);
}
XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis) {
return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/false);
}
XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis) {
return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/true);
}
} // namespace xla

View File

@ -60,10 +60,12 @@ XlaOp Any(XlaOp predicates);
// Returns the argmax of `input` along `axis`. `output_type` is the type to
// use for the output.
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis);
XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis);
// Returns the argmin of `input` along `axis`. `output_type` is the type to
// use for the output.
XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis);
XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis);
} // namespace xla

View File

@ -15,38 +15,24 @@ limitations under the License.
// XLA-specific Ops for 2D convolution.
#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
#include "tensorflow/compiler/xla/client/lib/conv_op_helpers.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
namespace xla {
namespace {
// Returns the expanded size of a filter used for depthwise convolution.
// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N].
xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) {
Shape ExpandedFilterShapeForDepthwiseConvolution(const Shape& shape) {
int num_dims = shape.dimensions_size();
CHECK_GE(num_dims, 2); // Crash OK
xla::Shape expanded_shape = shape;
Shape expanded_shape = shape;
expanded_shape.set_dimensions(
num_dims - 1,
shape.dimensions(num_dims - 2) * shape.dimensions(num_dims - 1));
@ -54,40 +40,43 @@ xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) {
}
// Returns the transposed filter for use in BackpropInput of group convolution.
xla::XlaOp TransposeFilterForGroupConvolutionBackpropInput(
const xla::XlaOp& filter, const xla::Shape& filter_shape, int64 num_groups,
XlaOp TransposeFilterForGroupConvolutionBackpropInput(const XlaOp& filter,
const Shape& filter_shape,
int64 num_groups,
int num_spatial_dims) {
// 1. Reshape from [H, W, ..., filter_in_depth, out_depth] to [H, W, ...,
// filter_in_depth, G, out_depth / G]
int num_dims = filter_shape.dimensions_size();
CHECK_GE(num_dims, 2); // Crash OK
xla::Shape new_shape = filter_shape;
Shape new_shape = filter_shape;
new_shape.set_dimensions(num_dims - 1, num_groups);
new_shape.add_dimensions(filter_shape.dimensions(num_dims - 1) / num_groups);
xla::XlaOp result = xla::Reshape(filter, new_shape.dimensions());
XlaOp result = Reshape(filter, new_shape.dimensions());
// 2. Transpose to [H, W, ..., G, filter_in_depth, out_depth / G]
std::vector<int64> transpose_dims(num_dims + 1);
std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
std::swap(transpose_dims[num_spatial_dims],
transpose_dims[num_spatial_dims + 1]);
result = xla::Transpose(result, transpose_dims);
result = Transpose(result, transpose_dims);
// 3. Reshape to [H, W, ..., in_depth, out_depth / G]
result = xla::Collapse(result, {num_spatial_dims, num_spatial_dims + 1});
result = Collapse(result, {num_spatial_dims, num_spatial_dims + 1});
return result;
}
// Returns the transposed input for use in BackpropFilter of group convolution.
xla::XlaOp TransposeInputForGroupConvolutionBackpropFilter(
const xla::XlaOp& input, const xla::Shape& input_shape, int64 num_groups,
int batch_dim, int depth_dim) {
XlaOp TransposeInputForGroupConvolutionBackpropFilter(const XlaOp& input,
const Shape& input_shape,
int64 num_groups,
int batch_dim,
int depth_dim) {
// 1. Reshape the depth_dim C into [G, C/G]
int num_dims = input_shape.dimensions_size();
std::vector<int64> reshape_dims = input_shape.dimensions();
reshape_dims[depth_dim] = reshape_dims[depth_dim] / num_groups;
reshape_dims.insert(reshape_dims.begin() + depth_dim, num_groups);
xla::XlaOp result = xla::Reshape(input, reshape_dims);
XlaOp result = Reshape(input, reshape_dims);
// 2. Transpose G to the axis before N, e.g.: [G, N, H, W, C/G]
std::vector<int64> transpose_dims(num_dims + 1);
@ -97,10 +86,10 @@ xla::XlaOp TransposeInputForGroupConvolutionBackpropFilter(
transpose_dims.insert(
transpose_dims.begin() + batch_dim,
depth_dim); // e.g.: [3, 0, 1, 2, 4] -> [G, N, H, W, C/G]
result = xla::Transpose(result, transpose_dims);
result = Transpose(result, transpose_dims);
// 3. Merge [G, N] to [G*N]
result = xla::Collapse(result, {batch_dim, batch_dim + 1});
result = Collapse(result, {batch_dim, batch_dim + 1});
return result;
}
@ -144,9 +133,8 @@ xla::XlaOp TransposeInputForGroupConvolutionBackpropFilter(
//
// Finally compare A and B and return the result at the beginning of the
// comment.
xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape,
xla::XlaBuilder* builder) {
xla::Shape expanded_filter_shape =
XlaOp CreateExpandedFilterMask(const Shape& filter_shape, XlaBuilder* builder) {
Shape expanded_filter_shape =
ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
int64 depthwise_multiplier =
filter_shape.dimensions(filter_shape.dimensions_size() - 1);
@ -156,61 +144,60 @@ xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape,
// with the iota dimension chosen as the expanded output feature dimension.
std::vector<int64> iota_dimensions(expanded_filter_shape.dimensions().begin(),
expanded_filter_shape.dimensions().end());
xla::Shape iota_shape = xla::ShapeUtil::MakeShape(xla::S32, iota_dimensions);
xla::XlaOp input_feature_iota = xla::Iota(
builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 2);
xla::XlaOp expanded_feature_iota = xla::Iota(
builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 1);
Shape iota_shape = ShapeUtil::MakeShape(S32, iota_dimensions);
XlaOp input_feature_iota =
Iota(builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 2);
XlaOp expanded_feature_iota =
Iota(builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 1);
// Divide 'expanded_feature_iota' by the depthwise_multiplier to create
// [0 0 1 1 2 2] ... in the example in the function comment.
expanded_feature_iota =
xla::Div(expanded_feature_iota,
XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
depthwise_multiplier));
expanded_feature_iota = Div(
expanded_feature_iota,
ConstantR0WithType(builder, PrimitiveType::S32, depthwise_multiplier));
// Compare 'input_feature_iota' with 'expanded_feature_iota' to create a
// diagonal predicate.
return xla::Eq(expanded_feature_iota, input_feature_iota);
return Eq(expanded_feature_iota, input_feature_iota);
}
// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
// build a depthwise convolution.
xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape,
const xla::XlaOp& filter) {
XlaOp ReshapeFilterForDepthwiseConvolution(const Shape& filter_shape,
const XlaOp& filter) {
int64 input_feature_dim = filter_shape.dimensions_size() - 2;
int64 output_feature_dim = filter_shape.dimensions_size() - 1;
int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim);
int64 input_feature = filter_shape.dimensions(input_feature_dim);
// Create a [H, W, ..., 1, N*M] reshape of the filter.
xla::Shape implicit_broadcast_filter_shape = filter_shape;
Shape implicit_broadcast_filter_shape = filter_shape;
implicit_broadcast_filter_shape.set_dimensions(input_feature_dim, 1);
implicit_broadcast_filter_shape.set_dimensions(
output_feature_dim, depthwise_multiplier * input_feature);
return xla::Reshape(
filter, xla::AsInt64Slice(implicit_broadcast_filter_shape.dimensions()));
return Reshape(filter,
AsInt64Slice(implicit_broadcast_filter_shape.dimensions()));
}
// Reduces the results of the convolution with an expanded filter to the
// non-expanded filter.
xla::XlaOp ContractFilterForDepthwiseBackprop(const xla::Shape& filter_shape,
const xla::XlaOp& filter_backprop,
xla::XlaBuilder* builder) {
XlaOp ContractFilterForDepthwiseBackprop(const Shape& filter_shape,
const XlaOp& filter_backprop,
XlaBuilder* builder) {
auto masked_expanded_filter =
xla::Select(CreateExpandedFilterMask(filter_shape, builder),
filter_backprop, xla::ZerosLike(filter_backprop));
Select(CreateExpandedFilterMask(filter_shape, builder), filter_backprop,
ZerosLike(filter_backprop));
auto elem_type = filter_shape.element_type();
return xla::Reshape(
return Reshape(
// This reduce does not need inputs to be converted with
// XlaHelpers::SumAccumulationType() since the select above guarantees
// that only one element is non zero, so there cannot be accumulated
// precision error.
xla::Reduce(masked_expanded_filter, xla::Zero(builder, elem_type),
Reduce(masked_expanded_filter, Zero(builder, elem_type),
CreateScalarAddComputation(elem_type, builder),
{filter_shape.dimensions_size() - 2}),
xla::AsInt64Slice(filter_shape.dimensions()));
AsInt64Slice(filter_shape.dimensions()));
}
// Performs some basic checks on ConvOpAttrs that are true for all kinds of XLA
@ -218,107 +205,189 @@ xla::XlaOp ContractFilterForDepthwiseBackprop(const xla::Shape& filter_shape,
Status CheckConvAttrs(const ConvOpAttrs& attrs) {
const int num_dims = attrs.num_spatial_dims + 2;
if (attrs.strides.size() != num_dims) {
return errors::InvalidArgument("Sliding window strides field must specify ",
num_dims, " dimensions");
return InvalidArgument(
"Sliding window strides field must specify %d dimensions", num_dims);
}
int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
int batch_dim = attrs.data_format.input_batch_dimension();
int feature_dim = attrs.data_format.input_feature_dimension();
if (attrs.strides[batch_dim] != 1 || attrs.strides[feature_dim] != 1) {
return errors::Unimplemented(
return Unimplemented(
"Current implementation does not yet support strides in the batch and "
"depth dimensions.");
}
if (attrs.dilations.size() != num_dims) {
return errors::InvalidArgument("Dilations field must specify ", num_dims,
" dimensions");
return InvalidArgument("Dilations field must specify %d dimensions",
num_dims);
}
if (attrs.dilations[batch_dim] != 1 || attrs.dilations[feature_dim] != 1) {
return errors::Unimplemented(
return Unimplemented(
"Current implementation does not support dilations in the batch and "
"depth dimensions.");
}
for (int i = 0; i < attrs.num_spatial_dims; ++i) {
int input_dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
int input_dim = attrs.data_format.input_spatial_dimensions(i);
if (attrs.dilations[input_dim] < 1) {
return errors::Unimplemented("Dilation values must be positive; ", i,
"th spatial dimension had dilation ",
attrs.dilations[input_dim]);
return Unimplemented(
"Dilation values must be positive; %dth spatial dimension had "
"dilation %d",
i, attrs.dilations[input_dim]);
}
}
return Status::OK();
}
// Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes
// to TensorShapes.
Status ConvBackpropComputeDimensionsV2XlaShapes(
StringPiece label, int num_spatial_dims, const xla::Shape& input_shape,
const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape,
// Information about a single spatial dimension for a convolution
// backpropagation.
struct ConvBackpropSpatialDimension {
int64 input_size;
int64 filter_size;
int64 output_size;
int64 stride;
int64 dilation;
// Output size after scaling by the stride.
int64 expanded_output_size;
// Number of padding elements to be added before/after this dimension of
// the input when computing Conv?DBackpropInput.
int64 pad_before, pad_after;
};
// Computed dimensions for a backwards convolution.
struct ConvBackpropDimensions {
// Information about each spatial dimension.
std::vector<ConvBackpropSpatialDimension> spatial_dims;
// Batch size.
int64 batch_size;
// Input and output feature depth.
int64 in_depth, out_depth;
};
Status ConvBackpropExtractAndVerifyDimension(
absl::Span<const int64> input_shape, absl::Span<const int64> filter_shape,
absl::Span<const int64> output_shape, absl::Span<const int32> dilations,
const std::vector<int32>& strides, int64 padding_before,
int64 padding_after, int spatial_dim, int filter_spatial_dim,
ConvBackpropSpatialDimension* dim) {
dim->input_size = input_shape.at(spatial_dim);
dim->filter_size = filter_shape.at(filter_spatial_dim);
dim->output_size = output_shape.at(spatial_dim);
dim->stride = strides[spatial_dim];
dim->dilation = dilations[spatial_dim];
int64 effective_filter_size = (dim->filter_size - 1) * dim->dilation + 1;
int64 out_size = (dim->input_size + padding_before + padding_after -
effective_filter_size + dim->stride) /
dim->stride;
if (dim->output_size != out_size) {
return InvalidArgument(
"ConvBackpropExtractAndVerifyDimension: Size of out_backprop doesn't "
"match computed: actual = %ld, "
"computed = %ld, spatial_dim: %d, input: %ld, filter: %ld, output: "
"%ld, stride: %ld, dilation: %ld",
dim->output_size, out_size, spatial_dim, dim->input_size,
dim->filter_size, dim->output_size, dim->stride, dim->dilation);
}
dim->expanded_output_size = (dim->output_size - 1) * dim->stride + 1;
const auto padded_out_size = dim->input_size + effective_filter_size - 1;
dim->pad_before = effective_filter_size - 1 - padding_before;
dim->pad_after =
padded_out_size - dim->expanded_output_size - dim->pad_before;
VLOG(2) << "ConvBackpropExtractAndVerifyDimension: expanded_out = "
<< dim->expanded_output_size
<< ", effective_filter_size = " << effective_filter_size
<< ", padded_out = " << padded_out_size
<< ", pad_before = " << dim->pad_before
<< ", pad_after = " << dim->pad_after
<< ", dilation = " << dim->dilation << ", strides = " << dim->stride;
return Status::OK();
}
// Verifies that the dimensions all match, and computes sizes/padding for the
// spatial dimensions.
Status ConvBackpropComputeDimensions(
absl::string_view label, int num_spatial_dims,
absl::Span<const int64> input_shape, absl::Span<const int64> filter_shape,
absl::Span<const int64> out_backprop_shape,
absl::Span<const int32> dilations, const std::vector<int32>& strides,
Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims,
absl::Span<const int64> explicit_paddings) {
TensorShape input_tensor_shape, filter_tensor_shape,
out_backprop_tensor_shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
TF_RETURN_IF_ERROR(
XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape));
return ConvBackpropComputeDimensionsV2(
label, num_spatial_dims, input_tensor_shape, filter_tensor_shape,
out_backprop_tensor_shape, dilations, strides, padding, explicit_paddings,
data_format, dims);
absl::Span<const int64> explicit_paddings,
const ConvolutionDimensionNumbers& data_format,
ConvBackpropDimensions* dims) {
// The + 2 in the following line is for the batch and feature dimensions.
const int num_dims = num_spatial_dims + 2;
if (input_shape.size() != num_dims) {
return InvalidArgument("%s: input must be %d-dimensional", label, num_dims);
}
if (filter_shape.size() != num_dims) {
return InvalidArgument("%s: filter must be %d-dimensional", label,
num_dims);
}
if (out_backprop_shape.size() != num_dims) {
return InvalidArgument("%s: out_backprop must be %d-dimensional", label,
num_dims);
}
int batch_dim = data_format.input_batch_dimension();
dims->batch_size = input_shape.at(batch_dim);
if (dims->batch_size != out_backprop_shape.at(batch_dim)) {
return InvalidArgument(
"%s: input and out_backprop must have the same batch size, input "
"batch: %ld outbackprop batch: %ld batch_dim: %d",
label, dims->batch_size, out_backprop_shape.at(batch_dim), batch_dim);
}
int feature_dim = data_format.input_feature_dimension();
dims->in_depth = input_shape.at(feature_dim);
// The input and output feature dimensions are the second last and last
// dimensions of the filter Tensor.
VLOG(2) << "input vs filter_in depth " << dims->in_depth << " "
<< filter_shape.at(num_dims - 2);
if (dims->in_depth % filter_shape.at(num_dims - 2)) {
return InvalidArgument(
"%s: input depth must be evenly divisible by filter depth", label);
}
dims->out_depth = filter_shape.at(num_dims - 1);
if (dims->out_depth != out_backprop_shape.at(feature_dim)) {
return InvalidArgument(
"%s: filter and out_backprop must have the same out_depth", label);
}
dims->spatial_dims.resize(num_spatial_dims);
for (int i = 0; i < num_spatial_dims; ++i) {
int image_dim = data_format.input_spatial_dimensions(i);
int64 padding_before = -1, padding_after = -1;
padding_before = explicit_paddings[2 * image_dim];
padding_after = explicit_paddings[2 * image_dim + 1];
TF_RETURN_IF_ERROR(ConvBackpropExtractAndVerifyDimension(
input_shape, filter_shape, out_backprop_shape, dilations, strides,
padding_before, padding_after, image_dim, i, &dims->spatial_dims[i]));
}
return Status::OK();
}
} // anonymous namespace
xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
bool depthwise,
OpKernelConstruction* ctx) {
ConvOpAttrs attrs;
attrs.num_spatial_dims = num_spatial_dims;
attrs.depthwise = depthwise;
TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations));
TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides));
TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding));
if (attrs.padding == EXPLICIT) {
TF_RETURN_IF_ERROR(
ctx->GetAttr("explicit_paddings", &attrs.explicit_paddings));
}
string data_format;
TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format));
if (!FormatFromString(data_format, &attrs.data_format)) {
return errors::InvalidArgument("Invalid data format: ", data_format);
}
return attrs;
}
xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
xla::XlaOp conv_input,
xla::XlaOp filter,
const ConvOpAttrs& attrs) {
StatusOr<XlaOp> MakeXlaForwardConvOp(absl::string_view /*type_string*/,
XlaOp conv_input, XlaOp filter,
const ConvOpAttrs& attrs,
const PrecisionConfig* precision_config) {
TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
auto* builder = conv_input.builder();
TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(conv_input));
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(conv_input));
// Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth]
TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
TF_ASSIGN_OR_RETURN(Shape filter_shape, builder->GetShape(filter));
// For 2D convolution, there should be 4 dimensions.
int num_dims = attrs.num_spatial_dims + 2;
if (input_shape.dimensions_size() != num_dims) {
return errors::InvalidArgument("input must be ", num_dims, "-dimensional",
return InvalidArgument("input must be %d-dimensional: %s", num_dims,
input_shape.DebugString());
}
if (filter_shape.dimensions_size() != num_dims) {
return errors::InvalidArgument(
"filter must be ", num_dims,
"-dimensional: ", filter_shape.DebugString());
return InvalidArgument("filter must be %d-dimensional: %s", num_dims,
filter_shape.DebugString());
}
// The last two dimensions of the filter are the input and output shapes.
int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
int batch_dim = attrs.data_format.input_batch_dimension();
int feature_dim = attrs.data_format.input_feature_dimension();
int64 filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims),
out_depth = filter_shape.dimensions(attrs.num_spatial_dims + 1),
@ -326,22 +395,22 @@ xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
// The 'C' dimension for input is in_depth.
// It must be a multiple of the filter's in_depth.
if (in_depth % filter_in_depth != 0) {
return errors::InvalidArgument(
"Depth of input must be a multiple of depth of filter: ", in_depth,
" vs ", filter_in_depth);
return InvalidArgument(
"Depth of input must be a multiple of depth of filter: %d vs %d",
in_depth, filter_in_depth);
}
int64 feature_group_count = in_depth / filter_in_depth;
if (out_depth % feature_group_count != 0) {
return errors::InvalidArgument(
"Depth of output must be a multiple of the number of groups: ",
out_depth, " vs ", feature_group_count);
return InvalidArgument(
"Depth of output must be a multiple of the number of groups: %d vs %d",
out_depth, feature_group_count);
}
if (attrs.depthwise) {
filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
}
xla::ConvolutionDimensionNumbers dims;
ConvolutionDimensionNumbers dims;
std::vector<int64> window_strides(attrs.num_spatial_dims);
std::vector<int64> lhs_dilation(attrs.num_spatial_dims, 1);
std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
@ -355,64 +424,57 @@ xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1);
for (int i = 0; i < attrs.num_spatial_dims; ++i) {
const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
const int64 dim = attrs.data_format.input_spatial_dimensions(i);
dims.add_input_spatial_dimensions(dim);
dims.add_kernel_spatial_dimensions(i);
dims.add_output_spatial_dimensions(dim);
window_strides[i] = attrs.strides.at(dim);
rhs_dilation[i] = attrs.dilations.at(dim);
if (attrs.padding == EXPLICIT) {
padding[i] = {attrs.explicit_paddings.at(dim * 2),
attrs.explicit_paddings.at(dim * 2 + 1)};
}
int64 unused_output_size;
TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
input_shape.dimensions(dim), filter_shape.dimensions(i),
rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size,
&padding[i].first, &padding[i].second));
}
return xla::ConvGeneralDilated(
return ConvGeneralDilated(
conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
dims,
/*feature_group_count=*/attrs.depthwise ? in_depth : feature_group_count);
/*feature_group_count=*/attrs.depthwise ? in_depth : feature_group_count,
/*batch_group_count=*/1, precision_config);
}
xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
xla::XlaOp out_backprop, const ConvOpAttrs& attrs) {
StatusOr<XlaOp> MakeXlaBackpropInputConvOp(
absl::string_view type_string, const Shape& input_shape, XlaOp filter,
XlaOp out_backprop, const ConvOpAttrs& attrs,
const PrecisionConfig* precision_config) {
TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
int num_dims = attrs.num_spatial_dims + 2;
int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
int batch_dim = attrs.data_format.input_batch_dimension();
int feature_dim = attrs.data_format.input_feature_dimension();
auto* builder = filter.builder();
TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
TF_ASSIGN_OR_RETURN(Shape filter_shape, builder->GetShape(filter));
TF_ASSIGN_OR_RETURN(Shape out_backprop_shape,
builder->GetShape(out_backprop));
int64 in_depth = input_shape.dimensions(feature_dim),
filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims),
feature_group_count = in_depth / filter_in_depth;
xla::Shape expanded_filter_shape =
Shape expanded_filter_shape =
attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
: filter_shape;
// Reuse dimension computation logic from conv_grad_ops.cc.
ConvBackpropDimensions dims;
TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape,
out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding,
attrs.data_format, &dims, attrs.explicit_paddings));
TF_RETURN_IF_ERROR(ConvBackpropComputeDimensions(
type_string, attrs.num_spatial_dims, input_shape.dimensions(),
expanded_filter_shape.dimensions(), out_backprop_shape.dimensions(),
attrs.dilations, attrs.strides, attrs.explicit_paddings,
attrs.data_format, &dims));
// The input gradients are computed by a convolution of the output
// gradients and the filter, with some appropriate padding. See the
// comment at the top of conv_grad_ops.h for details.
xla::ConvolutionDimensionNumbers dnums;
ConvolutionDimensionNumbers dnums;
dnums.set_input_batch_dimension(batch_dim);
dnums.set_output_batch_dimension(batch_dim);
dnums.set_input_feature_dimension(feature_dim);
@ -429,7 +491,7 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
std::vector<int64> ones(attrs.num_spatial_dims, 1);
for (int i = 0; i < attrs.num_spatial_dims; ++i) {
int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
int64 dim = attrs.data_format.input_spatial_dimensions(i);
dnums.add_input_spatial_dimensions(dim);
dnums.add_kernel_spatial_dimensions(i);
dnums.add_output_spatial_dimensions(dim);
@ -446,41 +508,36 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
filter, filter_shape, feature_group_count, attrs.num_spatial_dims);
}
// Mirror the filter in the spatial dimensions.
filter = xla::Rev(filter, kernel_spatial_dims);
filter = Rev(filter, kernel_spatial_dims);
// activation gradients
// = gradients (with padding and dilation) <conv> mirrored_weights
return xla::ConvGeneralDilated(
return ConvGeneralDilated(
out_backprop, filter, /*window_strides=*/ones, padding, lhs_dilation,
rhs_dilation, dnums,
/*feature_group_count=*/
attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) /
filter_shape.dimensions(attrs.num_spatial_dims + 1)
: feature_group_count);
: feature_group_count,
/*batch_group_count=*/1, precision_config);
}
xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
StringPiece type_string, xla::XlaOp activations,
const xla::Shape& filter_shape, xla::XlaOp gradients,
const ConvOpAttrs& attrs) {
StatusOr<XlaOp> MakeXlaBackpropFilterConvOp(
absl::string_view type_string, XlaOp activations, const Shape& filter_shape,
XlaOp out_backprop, const ConvOpAttrs& attrs,
const PrecisionConfig* precision_config) {
TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
auto* builder = activations.builder();
TF_ASSIGN_OR_RETURN(xla::Shape activations_shape,
builder->GetShape(activations));
TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
builder->GetShape(gradients));
xla::XlaOp filter_backprop;
TF_ASSIGN_OR_RETURN(Shape activations_shape, builder->GetShape(activations));
TF_ASSIGN_OR_RETURN(Shape out_backprop_shape,
builder->GetShape(out_backprop));
XlaOp filter_backprop;
xla::Shape input_shape = activations_shape;
xla::Shape output_shape = out_backprop_shape;
Shape input_shape = activations_shape;
Shape output_shape = out_backprop_shape;
TensorShape input_tensor_shape, filter_tensor_shape, output_tensor_shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(output_shape, &output_tensor_shape));
const xla::Shape expanded_filter_shape =
const Shape expanded_filter_shape =
attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
: filter_shape;
// Reuse dimension computation logic from conv_grad_ops.cc.
@ -488,18 +545,19 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
// The filter gradients are computed by a convolution of the input
// activations and the output gradients, with some appropriate padding.
// See the comment at the top of conv_grad_ops.h for details.
xla::ConvolutionDimensionNumbers dnums;
ConvolutionDimensionNumbers dnums;
TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
type_string, attrs.num_spatial_dims, activations_shape,
expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
attrs.padding, attrs.data_format, &dims, attrs.explicit_paddings));
TF_RETURN_IF_ERROR(ConvBackpropComputeDimensions(
type_string, attrs.num_spatial_dims, activations_shape.dimensions(),
expanded_filter_shape.dimensions(), out_backprop_shape.dimensions(),
attrs.dilations, attrs.strides, attrs.explicit_paddings,
attrs.data_format, &dims));
// Obtain some useful dimensions:
// The last two dimensions of the filter are the input and output shapes.
int num_dims = attrs.num_spatial_dims + 2;
int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
int n_dim = attrs.data_format.input_batch_dimension();
int c_dim = attrs.data_format.input_feature_dimension();
int64 in_depth = input_shape.dimensions(c_dim),
filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims),
feature_group_count = in_depth / filter_in_depth;
@ -518,7 +576,7 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
// In the case of depthwise convolution with no multiplier,
// the computation can be done by the batch_group_count parameter.
bool use_batch_group_count =
filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise;
filter_shape.dimensions(num_dims - 1) == 1 && attrs.depthwise;
std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
@ -550,7 +608,7 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
}
for (int64 i = 0; i < attrs.num_spatial_dims; ++i) {
int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
int64 dim = attrs.data_format.input_spatial_dimensions(i);
dnums.add_input_spatial_dimensions(dim);
dnums.add_kernel_spatial_dimensions(dim);
rhs_dilation[i] = dims.spatial_dims[i].stride;
@ -591,11 +649,7 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
// In addition, if the padded input size is smaller than the input size,
// we need to ignore some training elements of the input. We do this by
// applying negative padding on the right/bottom.
const int64 pad_before = attrs.padding == Padding::EXPLICIT
? attrs.explicit_paddings[2 * dim]
: attrs.padding == Padding::SAME
? std::max<int64>(pad_total / 2, 0)
: 0;
const int64 pad_before = attrs.explicit_paddings[2 * dim];
padding[i] = {pad_before, pad_total - pad_before};
}
@ -608,11 +662,12 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
// This is done by specifying the window dilation factors in the
// convolution HLO below.
filter_backprop = xla::ConvGeneralDilated(
activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
filter_backprop = ConvGeneralDilated(
activations, out_backprop, window_strides, padding, /*lhs_dilation=*/ones,
rhs_dilation, dnums,
/*feature_group_count=*/feature_group_count,
/*batch_group_count=*/use_batch_group_count ? dims.in_depth : 1);
/*batch_group_count=*/use_batch_group_count ? dims.in_depth : 1,
precision_config);
if (!use_batch_group_count && attrs.depthwise) {
filter_backprop = ContractFilterForDepthwiseBackprop(
@ -622,4 +677,4 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
return filter_backprop;
}
} // namespace tensorflow
} // namespace xla

View File

@ -0,0 +1,63 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONV_OP_HELPERS_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONV_OP_HELPERS_H_
#include <vector>
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/statusor.h"
// This header exposes utilities for translating TensorFlow convolution ops into
// XLA ops.
namespace xla {
// ConvOpAttrs contains all of the metadata necessary to specify an XLA
// convolution.
struct ConvOpAttrs {
bool depthwise;
int num_spatial_dims;
std::vector<int32> dilations;
std::vector<int32> strides;
std::vector<int64> explicit_paddings;
ConvolutionDimensionNumbers data_format;
};
// Computes the convolution with the given input, filter and attributes. Errors
// returned by this function and the ones below are tagged with "type_string",
// which is the name of the TensorFlow operator using them.
StatusOr<XlaOp> MakeXlaForwardConvOp(
absl::string_view type_string, XlaOp conv_input, XlaOp filter,
const ConvOpAttrs& attrs,
const PrecisionConfig* precision_config = nullptr);
// Computes the gradient with respect to the input, given the output gradient
// and the filter.
StatusOr<XlaOp> MakeXlaBackpropInputConvOp(
absl::string_view type_string, const Shape& input_shape, XlaOp filter,
XlaOp out_backprop, const ConvOpAttrs& attrs,
const PrecisionConfig* precision_config = nullptr);
// Computes the gradient with respect to the filter, given the output gradient
// and the activations.
StatusOr<XlaOp> MakeXlaBackpropFilterConvOp(
absl::string_view type_string, XlaOp activations, const Shape& filter_shape,
XlaOp out_backprop, const ConvOpAttrs& attrs,
const PrecisionConfig* precision_config = nullptr);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONV_OP_HELPERS_H_

View File

@ -109,7 +109,7 @@ ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
}
// Converts a uint64 to two uint32s.
ThreeFry2x32State Uint64ToUint32s(XlaOp u64) {
std::array<XlaOp, 2> Uint64ToUint32s(XlaOp u64) {
XlaBuilder* builder = u64.builder();
XlaOp const32 = ConstantR0WithType(builder, U64, 32);
XlaOp fst = ConvertElementType(u64, U32);
@ -118,7 +118,7 @@ ThreeFry2x32State Uint64ToUint32s(XlaOp u64) {
}
// Converts two uint32s to a uint64.
XlaOp Uint32sToUint64(ThreeFry2x32State u32s) {
XlaOp Uint32sToUint64(std::array<XlaOp, 2> u32s) {
XlaBuilder* builder = u32s[0].builder();
return ConvertElementType(u32s[0], U64) |
ShiftLeft(ConvertElementType(u32s[1], U64),
@ -168,6 +168,157 @@ RngOutput ThreeFryRngBit64(XlaOp key, XlaOp initial_state, const Shape& shape) {
inputs_state.second};
}
// The key of the Philox random number generator.
using Philox4x32Key = std::array<XlaOp, 2>;
// The internal state of the Philox random number generator.
using Philox4x32State = std::array<XlaOp, 4>;
// Computes the Philox4x32 algorithm using 10 rounds.
Philox4x32State Philox4x32(Philox4x32State state, Philox4x32Key key) {
// Constants specified by the Philox algorithm.
static const uint32 kPhiloxW32A = 0x9E3779B9;
static const uint32 kPhiloxW32B = 0xBB67AE85;
static const uint32 kPhiloxM4x32A = 0xD2511F53;
static const uint32 kPhiloxM4x32B = 0xCD9E8D57;
struct HighLowPair {
XlaOp high;
XlaOp low;
};
// Compute the high and low words from multiplying two 32-bit integers.
auto mul_hi_low = [](XlaOp x, uint32 k) {
auto product =
ConvertElementType(x, U64) * ConstantR0<uint64>(x.builder(), k);
auto low = ConvertElementType(product, U32);
auto high =
ConvertElementType(product >> ConstantR0<uint64>(x.builder(), 32), U32);
return HighLowPair{high, low};
};
// Perform a single round of the Philox algorithm.
auto philox_round = [&](Philox4x32State x, Philox4x32Key key) {
auto product0 = mul_hi_low(x[0], kPhiloxM4x32A);
auto product1 = mul_hi_low(x[2], kPhiloxM4x32B);
return Philox4x32State{product1.high ^ x[1] ^ key[0], product1.low,
product0.high ^ x[3] ^ key[1], product0.low};
};
// Update the key after a round of Philox algorithm.
auto raise_key = [](Philox4x32Key key) {
XlaBuilder* builder = key[0].builder();
return Philox4x32Key{key[0] + ConstantR0<uint32>(builder, kPhiloxW32A),
key[1] + ConstantR0<uint32>(builder, kPhiloxW32B)};
};
static const int kNumRounds = 10;
for (int round = 0; round < kNumRounds; ++round, key = raise_key(key)) {
state = philox_round(state, key);
}
return state;
}
// Scrambles the input key so that users don't need to worry about which part
// of the key needs to be strong.
std::pair<Philox4x32State, Philox4x32Key> GeneratePhiloxInternalStateAndKey(
Philox4x32Key key) {
XlaBuilder* builder = key[0].builder();
XlaOp key0 = ConvertElementType(key[0], U64);
XlaOp key1 = ConvertElementType(key[1], U64);
Philox4x32State state = {
ConvertElementType(key0, U32),
ConvertElementType(key0 >> ScalarLike(key0, 32), U32),
ConvertElementType(key1, U32),
ConvertElementType(key1 >> ScalarLike(key1, 32), U32),
};
key = {ConstantR0<uint32>(builder, 0x3ec8f720),
ConstantR0<uint32>(builder, 0x02461e29)};
state = Philox4x32(state, key);
XlaOp zero = ConstantR0<uint32>(builder, 0);
return {Philox4x32State{zero, zero, state[2], state[3]},
Philox4x32Key{state[0], state[1]}};
}
// Adds the integers [0, 1, ..., n) to 'state', treating 'state' as a 4 U32s, to
// compute n states for generating n random numbers.
Philox4x32State GetPhiloxGeneratorInputState(Philox4x32State state, int64 n) {
XlaBuilder* builder = state[0].builder();
XlaOp iota = Iota(builder, U64, n);
XlaOp state_low = Uint32sToUint64({state[0], state[1]});
XlaOp new_state_low = state_low + iota;
std::array<XlaOp, 2> new_state_low_32s = Uint64ToUint32s(new_state_low);
XlaOp one = ConstantR0<uint64>(builder, 1);
XlaOp state_high = Uint32sToUint64({state[2], state[3]});
XlaOp new_state_high =
Select(Lt(new_state_low, state_low), Broadcast(state_high + one, {n}),
Broadcast(state_high, {n}));
std::array<XlaOp, 2> new_state_high_32s = Uint64ToUint32s(new_state_high);
return {new_state_low_32s[0], new_state_low_32s[1], new_state_high_32s[0],
new_state_high_32s[1]};
}
// Generates CeilOfRatio(num_elems, 4)*4 32bit Philox random numbers, as Philox
// numbers are generated in the unit of 128bits.
Philox4x32State GeneratePhiloxBits(int64 num_elems, Philox4x32Key key) {
Philox4x32State state;
std::tie(state, key) = GeneratePhiloxInternalStateAndKey(key);
const int64 num_vector4 = CeilOfRatio<int64>(num_elems, 4);
return Philox4x32(GetPhiloxGeneratorInputState(state, num_vector4), key);
}
// Generates an array of primitive type U32 with the given shape containing
// random bits generated by the Philox algorithm. Returns the array and the new
// state of the random number generator.
RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state,
const Shape& shape) {
XlaBuilder* builder = op_key.builder();
const int64 num_elems = ShapeUtil::ElementsIn(shape);
XlaOp new_state = initial_state + ConstantR0<uint64>(builder, num_elems);
Philox4x32Key key = Uint64ToUint32s(op_key + initial_state);
Philox4x32State state = GeneratePhiloxBits(num_elems, key);
XlaOp numbers = ConcatInDim(builder, {state[0], state[1], state[2], state[3]},
/*dimension=*/0);
numbers = Slice(numbers, /*start_indices=*/{0},
/*limit_indices=*/{num_elems},
/*strides=*/{1});
return {Reshape(numbers, AsInt64Slice(shape.dimensions())), new_state};
}
// Generates an array of primitive type U64 with the given shape containing
// random bits generated by the Philox algorithm. Returns the array and the new
// state of the random number generator.
RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state,
const Shape& shape) {
XlaBuilder* builder = op_key.builder();
const int64 num_elems = ShapeUtil::ElementsIn(shape);
XlaOp new_state = initial_state + ConstantR0<uint64>(builder, num_elems);
Philox4x32Key key = Uint64ToUint32s(op_key + initial_state);
Philox4x32State state32 = GeneratePhiloxBits(num_elems * 2, key);
auto convert_to_64 = [&](XlaOp v0, XlaOp v1) {
return ConvertElementType(v0, U64) |
ShiftLeft(ConvertElementType(v1, U64),
ConstantR0WithType(builder, U64, 32));
};
std::array<XlaOp, 2> state64;
state64[0] = convert_to_64(state32[0], state32[1]);
state64[1] = convert_to_64(state32[2], state32[3]);
XlaOp numbers = ConcatInDim(builder, {state64[0], state64[1]},
/*dimension=*/0);
numbers = Slice(numbers, /*start_indices=*/{0},
/*limit_indices=*/{num_elems},
/*strides=*/{1});
return {Reshape(numbers, AsInt64Slice(shape.dimensions())), new_state};
}
XlaOp ConvertRandomBitsToUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
XlaBuilder* builder = bits.builder();
// Form 23 random mantissa bits, with a leading 1 bit. The leading 1 bit
@ -200,8 +351,17 @@ XlaOp ConvertRandomBitsToUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
BitcastConvertType(dist - dist_div_2, type);
}
XlaOp UniformToNormalUsingSqrtErfInv(XlaOp uniform) {
return ScalarLike(uniform, std::sqrt(2.0)) * ErfInv(uniform);
// Implements the Box-Muller transform, which converts random floats in the
// range of [0, 1] from uniform distribution to normal distribution with mean 0
// and variance 1. For more detail on the Box-Muller transform, see
// http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form
std::pair<XlaOp, XlaOp> BoxMullerTransform(XlaOp x0, XlaOp x1) {
// Do not send a really small number to log().
XlaOp u1 = Max(x0, ScalarLike(x0, 1.0e-7f));
XlaOp v1 = ScalarLike(x1, 2.0f * M_PI) * x1;
XlaOp u2 = Sqrt(ScalarLike(u1, -2.0f) * Log(u1));
return {Sin(v1) * u2, Cos(v1) * u2};
}
} // namespace
@ -221,7 +381,27 @@ RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
return {key.builder()->ReportError(Unimplemented(
"Types other than F32, U32, S32, U64 and S64 "
"are not implemented by ThreeFryBitGenerator; got %s",
xla::primitive_util::LowercasePrimitiveTypeName(type))),
primitive_util::LowercasePrimitiveTypeName(type))),
initial_state};
}
}
RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state,
const Shape& shape) {
PrimitiveType type = shape.element_type();
switch (type) {
case F32:
case U32:
case S32:
return PhiloxRngBit32(key, initial_state, shape);
case U64:
case S64:
return PhiloxRngBit64(key, initial_state, shape);
default:
return {key.builder()->ReportError(Unimplemented(
"Types other than F32, U32, S32, U64 and S64 "
"are not implemented by ThreeFryBitGenerator; got %s",
primitive_util::LowercasePrimitiveTypeName(type))),
initial_state};
}
}
@ -260,11 +440,26 @@ RngOutput NormalF32Distribution(XlaOp key, XlaOp initial_state,
const Shape& shape) {
DCHECK_EQ(shape.element_type(), F32);
XlaBuilder* builder = key.builder();
const int64 num_elems = ShapeUtil::ElementsIn(shape);
const int64 num_pairs = CeilOfRatio<int64>(num_elems, 2);
RngOutput bits_state = UniformF32Distribution(
key, initial_state, bit_generator,
ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
ConstantR0<float>(builder, 1.0), shape);
XlaOp normal = UniformToNormalUsingSqrtErfInv(bits_state.value);
key, initial_state, bit_generator, ConstantR0<float>(builder, 0.0),
ConstantR0<float>(builder, 1.0),
ShapeUtil::MakeShape(F32, {num_pairs * 2}));
// Separate the bits into two groups to perform the Box-Muller transform.
XlaOp bits_0 = Slice(bits_state.value, {0}, {num_pairs}, {1});
XlaOp bits_1 = Slice(bits_state.value, {num_pairs}, {2 * num_pairs}, {1});
std::tie(bits_0, bits_1) = BoxMullerTransform(bits_0, bits_1);
// Put the numbers in the two groups back to form the requested shape.
XlaOp normal = ConcatInDim(builder, {bits_0, bits_1}, /*dimension=*/0);
if (num_elems != num_pairs * 2) {
normal = Slice(normal, /*start_indices=*/{0}, /*limit_indices=*/{num_elems},
/*strides=*/{1});
}
normal = Reshape(normal, shape.dimensions());
return {normal, bits_state.state};
}

View File

@ -50,6 +50,17 @@ using BitGeneratorTy = std::function<RngOutput(XlaOp key, XlaOp initial_state,
RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
const xla::Shape& shape);
// Implements the Philox algorithm to generate random numbers in parallel.
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
//
// The paper presents a few variants of the Philox algorithm, we picked the
// 4x32_10 version of the algorithm for the following reasons:
// . 4x32 uses 32-bit multiplication which is fast on GPUs.
// . The authors recommend the 10-round variant, and TensorFlow also uses it.
RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state,
const Shape& shape);
// Uses the given bit generator to generate random bits and then converts the
// random bits to random numbers of uniform distribution in the given range.
// Returns the random numbers and the state of the random number generator.

Some files were not shown because too many files have changed in this diff Show More