Merge remote-tracking branch origin/upstream/master

Måns Nilsson 2020-11-09 12:58:26 +01:00
commit 0dae721ffb
2883 changed files with 84460 additions and 38230 deletions

View File

@ -49,7 +49,6 @@
# rocm: Build with AMD GPU support (rocm).
# mkl: Enable full mkl support.
# tensorrt: Enable Tensorrt support.
# ngraph: Enable ngraph support.
# numa: Enable numa using hwloc.
# noaws: Disable AWS S3 storage support
# nogcp: Disable GCS support.
@ -159,6 +158,7 @@ build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain
# environment variable "TF_MKL_ROOT" every time before build.
build:mkl --define=build_with_mkl=true --define=enable_mkl=true
build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl --define=build_with_openmp=true
build:mkl -c opt
# config to build OneDNN backend with a user specified threadpool.
@ -172,6 +172,7 @@ build:mkl_threadpool -c opt
build:mkl_opensource_only --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_opensource_only --define=build_with_mkl_opensource=true
build:mkl_opensource_only --define=build_with_openmp=true
build:mkl_opensource_only -c opt
# Config setting to build with oneDNN for Arm.
@ -218,7 +219,6 @@ build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --action_env TF_NEED_ROCM=1
# Options extracted from configure script
build:ngraph --define=with_ngraph_support=true
build:numa --define=with_numa_support=true
# Options to disable default on features
@ -283,7 +283,7 @@ build:ios --copt=-w
build:linux --copt=-w
build:linux --host_copt=-w
build:macos --copt=-w
build:windows --copt=/w
build:windows --copt=/W0
# Tensorflow uses M_* math constants that only get defined by MSVC headers if
# _USE_MATH_DEFINES is defined.
@ -294,9 +294,11 @@ build:windows --host_copt=/D_USE_MATH_DEFINES
build:linux --define=PREFIX=/usr
build:linux --define=LIBDIR=$(PREFIX)/lib
build:linux --define=INCLUDEDIR=$(PREFIX)/include
build:linux --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include
build:macos --define=PREFIX=/usr
build:macos --define=LIBDIR=$(PREFIX)/lib
build:macos --define=INCLUDEDIR=$(PREFIX)/include
build:macos --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include
# TF_SYSTEM_LIBS do not work on windows.
# By default, build TF in C++ 14 mode.

View File

@ -132,12 +132,20 @@ Build Type
**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/)
**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/)
**Linux aarch64 CPU** Nightly <br> Python 3.6 | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow)](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) | [Nightly](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master)
**Linux aarch64 CPU** Stable Release | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) | Release [1.15](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) / [2.x](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)
**Linux aarch64 CPU** Nightly (Linaro)<br> Python 3.8 | [![Build Status](https://ci.linaro.org/jenkins/buildStatus/icon?job=ldcg-hpc-tensorflow)](https://ci.linaro.org/jenkins/job/ldcg-hpc-tensorflow/) | [Nightly](http://snapshots.linaro.org/hpc/python/tensorflow/latest/)
**Linux aarch64 CPU** Stable Release (Linaro) | [![Build Status](https://ci.linaro.org/jenkins/buildStatus/icon?job=ldcg-hpc-tensorflow)](https://ci.linaro.org/jenkins/job/ldcg-hpc-tensorflow/) | Release [1.x & 2.x](http://snapshots.linaro.org/hpc/python/tensorflow/latest/)
**Linux aarch64 CPU** Nightly (OpenLab)<br> Python 3.6 | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow)](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) | [Nightly](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master)
**Linux aarch64 CPU** Stable Release (OpenLab) | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) | Release [1.15](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) / [2.x](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)
**Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
**Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Stable Release | ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon) | Release [1.15](https://pypi.org/project/intel-tensorflow/1.15.0/) / [2.x](https://pypi.org/project/intel-tensorflow/)
**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br> Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/)
### Community Supported Containers
Container Type | Status | Artifacts
----------------------------------------------------------------- | ------ | ---------
**TensorFlow aarch64 Neoverse-N1 CPU** Stable (Linaro)<br> Debian | Static | Release [2.3](https://hub.docker.com/r/linaro/tensorflow-arm-neoverse-n1)
## Resources
* [TensorFlow.org](https://www.tensorflow.org)
@ -150,6 +158,7 @@ Build Type
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
* [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp)
* [TensorFlow Codelabs](https://codelabs.developers.google.com/?cat=TensorFlow)
* [TensorFlow Chat Room on StackOverflow (not actively monitored by the
  TensorFlow team)](https://chat.stackoverflow.com/rooms/216694/tensorflow)
* [TensorFlow Blog](https://blog.tensorflow.org)

View File

@ -1,3 +1,61 @@
# Release 2.5.0
<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
## Breaking Changes
* <DOCUMENT BREAKING CHANGES HERE>
* <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>
## Known Caveats
* <CAVEATS REGARDING THE RELEASE (BUT NOT BREAKING CHANGES).>
* <ADDING/BUMPING DEPENDENCIES SHOULD GO HERE>
* <KNOWN LACK OF SUPPORT ON SOME PLATFORM, SHOULD GO HERE>
## Major Features and Improvements
* <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
* TPU embedding support
* Added `profile_data_directory` to `EmbeddingConfigSpec` in
`_tpu_estimator_embedding.py`. This allows embedding lookup statistics
gathered at runtime to be used in embedding layer partitioning decisions.
## Bug Fixes and Other Changes
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* `tf.keras`:
* Improvements to Keras preprocessing layers:
* Discretization combiner implemented, with additional arg `epsilon`.
* `tf.data`:
* Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used
to control how external state should be handled during dataset
serialization or iterator checkpointing.
* XLA compilation:
* `tf.function(experimental_compile=True)` has become a stable API,
renamed `tf.function(jit_compile=True)` (see the sketch after this list).
* `tf.lite`:
* NNAPI
* Removed deprecated `Interpreter::UseNNAPI(bool)` C++ API.
* Use `NnApiDelegate()` and related delegate configuration methods
directly.
* TF Core:
* Corrected higher-order gradients of control flow constructs (`tf.cond`,
`tf.while_loop`, and compositions like `tf.foldl`) computed with
`tf.GradientTape` inside a `tf.function`.
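
As an illustration of the XLA compilation note above, here is a minimal sketch (the function and shapes are made up for the example, not taken from the release notes):

```python
import tensorflow as tf

# `experimental_compile=True` is now spelled `jit_compile=True`; both ask
# TensorFlow to compile the traced function with XLA.
@tf.function(jit_compile=True)
def dense_relu(x, w, b):
  return tf.nn.relu(tf.matmul(x, w) + b)

x = tf.random.normal([8, 16])
w = tf.random.normal([16, 4])
b = tf.zeros([4])
print(dense_relu(x, w, b).shape)  # (8, 4)
```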
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
<INSERT>, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
# Release 2.4.0
<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
@ -6,6 +64,15 @@
* <DOCUMENT BREAKING CHANGES HERE>
* <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>
* Certain float32 ops run in lower precision on Ampere based GPUs, including
matmuls and convolutions, due to the use of
[TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/).
Specifically, inputs to such ops are rounded from 23 bits of precision to 10
bits of precision. This is unlikely to cause issues in practice for deep
learning models. In some cases, TensorFloat-32 is also used for complex64 ops.
TensorFloat-32 can be disabled by running
`tf.config.experimental.enable_tensor_float_32_execution(False)` (see the
sketch at the end of this section). The "Major
Features and Improvements" section has more details.
* The byte layout for string tensors across the C-API has been updated to match
TF Core/C++; i.e., a contiguous array of `tensorflow::tstring`/`TF_TString`s.
* C-API functions `TF_StringDecode`, `TF_StringEncode`, and
@ -54,6 +121,42 @@
tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...).
* `tf.distribute.Strategy.experimental_make_numpy_dataset` is removed. Please
use `tf.data.Dataset.from_tensor_slices` instead.
* `experimental_hints` in `tf.distribute.StrategyExtended.reduce_to`,
`tf.distribute.StrategyExtended.batch_reduce_to`,
`tf.distribute.ReplicaContext.all_reduce` are renamed to `options`.
`tf.distribute.experimental.CollectiveHints` is renamed
`tf.distribute.experimental.CommunicationOptions`.
`tf.distribute.experimental.CollectiveCommunication` is renamed
`tf.distribute.experimental.CommunicationImplementation`.
* `tf.keras.mixed_precision.experimental`:
* `AutoCastVariable.dtype` now refers to the actual variable dtype, not the
dtype it will be casted to.
* When mixed precision is enabled, `tf.keras.layers.Embedding` now outputs a
float16 or bfloat16 tensor instead of a float32 tensor.
* The property
`tf.keras.mixed_precision.experimental.LossScaleOptimizer.loss_scale` is now
a tensor, not a `LossScale` object. This means to get a loss scale of a
`LossScaleOptimizer` as a tensor, you must now call `opt.loss_scale` instead
of `opt.loss_scale()`.
* The property `should_cast_variables` has been removed from
`tf.keras.mixed_precision.experimental.Policy`
* When passing a `tf.mixed_precision.experimental.DynamicLossScale` to
`tf.keras.mixed_precision.experimental.LossScaleOptimizer`, the
`DynamicLossScale`'s multiplier must be 2.
* When passing a `tf.mixed_precision.experimental.DynamicLossScale` to
`tf.keras.mixed_precision.experimental.LossScaleOptimizer`, the weights of
the `DynamicLossScale` are copied into the `LossScaleOptimizer` instead of
being reused. This means modifying the weights of the `DynamicLossScale`
will no longer affect the weights of the LossScaleOptimizer, and vice versa.
* The global policy can no longer be set to a non-floating point policy in
`tf.keras.mixed_precision.experimental.set_policy`
* In `Layer.call`, `AutoCastVariable`s will no longer be casted within
`MirroredStrategy.run` or `ReplicaContext.merge_call`. This is because a
thread local variable is used to determine whether `AutoCastVariable`s are
casted, and those two functions run with a different thread. Note this only
applies if one of these two functions is called within `Layer.call`; if one
of those two functions calls `Layer.call`, `AutoCastVariable`s will still be
casted.
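
A minimal sketch of opting out of TensorFloat-32, as referenced in the breaking-change note above (the matrices are illustrative):

```python
import tensorflow as tf

# On Ampere GPUs, float32 matmuls and convolutions use TensorFloat-32 by
# default; disabling it restores full float32 precision at some speed cost.
tf.config.experimental.enable_tensor_float_32_execution(False)

a = tf.random.normal([1024, 1024])
b = tf.random.normal([1024, 1024])
c = tf.matmul(a, b)  # computed in full float32 precision, even on Ampere GPUs
```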
## Known Caveats
@ -65,9 +168,40 @@
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
* A new module named `tf.experimental.numpy` is added, which is a NumPy-compatible API for writing TF programs. This module provides class `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are provided. Their inter-operation with TF facilities is seamless in most cases. See [tensorflow/python/ops/numpy_ops/README.md](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/numpy_ops/README.md) for details of what operations are supported and what are the differences from NumPy.
* A major refactoring of the internals of the Keras Functional API has been completed, that should improve the reliability, stability, and performance of constructing Functional models.
* Support for
[TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/)
on Ampere based GPUs has been added. TensorFloat-32, or TF32 for short, is a
math mode for NVIDIA Ampere GPUs which causes certain float32 ops, such as
matrix multiplications and convolutions, to run much faster on Ampere GPUs but
with reduced precision. This reduced precision has not been found to affect
convergence quality of deep learning models in practice. TensorFloat-32 is
enabled by default, but can be disabled with
`tf.config.experimental.enable_tensor_float_32_execution`.
* `tf.distribute`:
* `MultiWorkerMirroredStrategy` is graduated out of experimental.
* Peer failure will no longer cause the cluster to hang.
* Major issues with saving are fixed.
* See [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for a tutorial.
* Deprecated `experimental_distribute_datasets_from_function` method and renamed it to `distribute_datasets_from_function` as it is no longer experimental.
* The `tf.keras.mixed_precision` API has been made non-experimental. The major
changes to the new non-experimental API are:
* `tf.keras.mixed_precision.Policy` no longer takes in a
`tf.mixed_precision.experimental.LossScale` in the constructor, and no
longer has a `LossScale` associated with it. Instead, `Model.compile` will
automatically wrap the optimizer with a `LossScaleOptimizer` using dynamic
loss scaling if `Policy.name` is "mixed_float16".
* `tf.keras.mixed_precision.LossScaleOptimizer`'s constructor takes in
different arguments. In particular, it no longer takes in a `LossScale`, and
there is no longer a `LossScale` associated with the `LossScaleOptimizer`.
Instead, `LossScaleOptimizer` directly implements fixed or dynamic loss
scaling. See the documentation of
`tf.keras.mixed_precision.experimental.LossScaleOptimizer` for details on
the differences between the experimental `LossScaleOptimizer` and the new
non-experimental `LossScaleOptimizer`.
* `tf.mixed_precision.experimental.LossScale` and its subclasses are
deprecated, as all of its functionality now exists within
`tf.keras.mixed_precision.LossScaleOptimizer`
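
A hedged sketch of the non-experimental mixed precision API described above (the model is illustrative; `set_global_policy` is assumed here to be the non-experimental counterpart of `experimental.set_policy`):

```python
import tensorflow as tf

# With a "mixed_float16" global policy, Model.compile automatically wraps the
# optimizer in a LossScaleOptimizer that performs dynamic loss scaling.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(16,)),
    # Keep the output layer in float32 for numerical stability.
    tf.keras.layers.Dense(10, dtype="float32"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# Wrapping explicitly is also possible; no LossScale object is involved.
opt = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam())
```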
## Bug Fixes and Other Changes
@ -117,6 +251,10 @@
([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212),
[CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
[CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
* Fixes a segfault in `tf.quantization.quantize_and_dequantize`
([CVE-2020-15265](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15265))
* Fixes an undefined behavior float cast causing a crash
([CVE-2020-15266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15266))
* TF Core:
* `tf.types.experimental.TensorLike` is a new `Union` type that can be
used as type annotation for variables representing a Tensor or a value
@ -139,6 +277,7 @@
* Added `tf.config.experimental.get_memory_usage` to return total memory
usage of the device.
* Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`.
* Improve shape inference of nested function calls by supporting constant folding across Arg nodes which makes more static values available to shape inference functions.
* `tf.data`:
* tf.data service:
* Added new `tf.data.experimental.service.register_dataset` and
@ -183,7 +322,16 @@
how many times the function is called, and independent of global seed
settings.
* `tf.distribute`:
* <ADD RELEASE NOTES HERE>
* (Experimental) Parameter server training:
* Replaced the existing
`tf.distribute.experimental.ParameterServerStrategy` symbol with
a new class that is for parameter server training in TF2. Usage with
the old symbol, usually with Estimator, should be replaced with
`tf.compat.v1.distribute.experimental.ParameterServerStrategy`.
* Added `tf.distribute.experimental.coordinator.*` namespace,
including the main API `ClusterCoordinator` for coordinating the
training cluster, the related data structure `RemoteValue`
and `PerWorkerValue`.
* `tf.keras`:
* Improvements from the functional API refactoring:
* Functional model construction does not need to maintain a global
@ -218,6 +366,8 @@
* Improvements to Keras preprocessing layers:
* TextVectorization can now accept a vocabulary list or file as an
init arg.
* TextVectorization, StringLookup, and IntegerLookup can now accept a
vocabulary file via the `set_vocab_from_file` method.
* Normalization can now accept mean and variance values as init args.
* In `Attention` and `AdditiveAttention` layers, the `call()` method now
accepts a `return_attention_scores` argument. When set to
@ -239,10 +389,14 @@
True, the function may use type annotations to optimize the tracing
performance.
* Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
* AutoGraph now allows creating new symbols inside a TensorFLow loop, if
* AutoGraph now allows creating new symbols inside a TensorFlow loop, if
the values of these symbols at an iteration does not depend on the
previous iteration. These types of loops must run at least one
iteration, and will raise a runtime error otherwise.
* Variables contained in `tf.Module`s that are set as attributes of
custom Keras `Layer`s and `Model`s are now tracked in
the properties `layer.trainable_variables` and
`layer.non_trainable_variables`.
Example:
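
The original example for this entry is not shown in this hunk; the following is an illustrative sketch of the tracking behavior described above:

```python
import tensorflow as tf

class MyLayer(tf.keras.layers.Layer):
  def __init__(self):
    super().__init__()
    # A plain tf.Module attached as an attribute of a Keras layer.
    self.module = tf.Module()
    self.module.v = tf.Variable(1.0)

layer = MyLayer()
# The module's variable now shows up in the layer's tracking properties.
assert layer.trainable_variables[0] is layer.module.v
```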
@ -279,6 +433,7 @@
`TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
* `DynamicBuffer::AddJoinedString()` will now add a separator if the first
string to be joined is empty.
* Added support for cumulative sum (cumsum), both as builtin op and MLIR conversion.
* <ADD RELEASE NOTES HERE>
* `tf.random`:
@ -333,6 +488,12 @@
didn't have the keys sorted, the keys and values were not being printed
in accordance with their correct mapping.
* `TensorRT`
* We now issue a warning when the `session_config` parameter for the TF1
converter is used or the `rewrite_config_template` field in the TF2
converter parameter object is used.
* Other:
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
@ -341,6 +502,8 @@
context.
* Add `tf.config.experimental.mlir_bridge_rollout` which will help us
rollout the new MLIR TPU bridge.
* Added `tf.experimental.register_filesystem_plugin` to load modular
filesystem plugins from Python
* <ADD RELEASE NOTES HERE>
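
The filesystem-plugin entry above could be exercised as in the following hedged sketch (the plugin path is hypothetical):

```python
import tensorflow as tf

# Load a modular filesystem plugin (a shared object implementing the modular
# filesystem C API). The path below is purely illustrative.
tf.experimental.register_filesystem_plugin("/tmp/libmy_filesystem_plugin.so")

# Once registered, the scheme provided by the plugin can be used through
# tf.io.gfile like any built-in filesystem.
```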
## Thanks to our Contributors
@ -713,6 +876,7 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
* Add `tf.saved_model.LoadOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/LoadOptions?hl=en) as arg with default value `None` to choose the I/O device for loading models and weights.
* Update `tf.saved_model.SaveOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/SaveOptions?hl=en) as arg with default value `None` to choose the I/O device for saving models and weights.
* Mutable tables now restore checkpointed values when loaded from SavedModel.
* The user object metadata field in the SavedModel proto has been deprecated as part of the updates to Keras SavedModel. Keras was the only consumer of this field prior to the update.
* GPU
* TF 2.3 includes PTX kernels only for [compute capability](https://developer.nvidia.com/cuda-gpus) 7.0 to reduce the TF pip binary size. Earlier releases included PTX for a variety of older compute capabilities.
* Remove environmental variable `TF_USE_CUDNN`.
@ -741,6 +905,7 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
* Fix the issue that `strategy.reduce()` inside `tf.function` may raise exceptions when the values to reduce are from loops or if-clauses.
* Fix the issue that `tf.distribute.MirroredStrategy` cannot be used together with `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
* Add a `tf.distribute.cluster_resolver.TPUClusterResolver.connect` API to simplify TPU initialization.
* Add `tf.distribute.Strategy.gather` and `tf.distribute.ReplicaContext.all_gather` methods to gather and concatenate `tf.distribute.DistributedValues` across workers and devices.
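
A hedged sketch of the new gather API noted above (the replica function and shapes are illustrative):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def replica_fn():
  # Each replica contributes one row.
  return tf.reshape(tf.range(3, dtype=tf.float32), [1, 3])

per_replica = strategy.run(replica_fn)
# Concatenate the per-replica values along axis 0 onto the current device.
gathered = strategy.gather(per_replica, axis=0)
print(gathered.shape)  # (num_replicas, 3)
```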
### `tf.keras`:
* Introduces experimental preprocessing layers API (`tf.keras.layers.experimental.preprocessing`) to handle data preprocessing operations such as categorical feature encoding, text vectorization, data normalization, and data discretization (binning). The newly added layers provide a replacement for the legacy feature column API, and support composite tensor inputs.

View File

@ -113,26 +113,10 @@ http_archive(
# Required for dependency @com_github_grpc_grpc
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
grpc_deps()
load(
"@build_bazel_rules_apple//apple:repositories.bzl",
"apple_rules_dependencies",
)
apple_rules_dependencies()
load(
"@build_bazel_apple_support//lib:repositories.bzl",
"apple_support_dependencies",
)
apple_support_dependencies()
load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")
bazel_version_repository(name = "bazel_version")
load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps")
grpc_extra_deps()
load("//third_party/googleapis:repository_rules.bzl", "config_googleapis")

View File

@ -1163,12 +1163,9 @@ def set_system_libs_flag(environ_cp):
syslibs = ','.join(sorted(syslibs.split()))
write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs)
if 'PREFIX' in environ_cp:
  write_to_bazelrc('build --define=PREFIX=%s' % environ_cp['PREFIX'])
if 'LIBDIR' in environ_cp:
  write_to_bazelrc('build --define=LIBDIR=%s' % environ_cp['LIBDIR'])
if 'INCLUDEDIR' in environ_cp:
  write_to_bazelrc('build --define=INCLUDEDIR=%s' % environ_cp['INCLUDEDIR'])
for varname in ('PREFIX', 'LIBDIR', 'INCLUDEDIR', 'PROTOBUF_INCLUDE_PATH'):
  if varname in environ_cp:
    write_to_bazelrc('build --define=%s=%s' % (varname, environ_cp[varname]))
def is_reduced_optimize_huge_functions_available(environ_cp):
@ -1487,7 +1484,6 @@ def main():
config_info_line('mkl', 'Build with MKL support.')
config_info_line('mkl_aarch64', 'Build with oneDNN support for Aarch64.')
config_info_line('monolithic', 'Config for mostly static monolithic build.')
config_info_line('ngraph', 'Build with Intel nGraph support.')
config_info_line('numa', 'Build with NUMA support.')
config_info_line(
'dynamic_kernels',

View File

@ -3,6 +3,7 @@
# learning applications.
load("@bazel_skylib//lib:selects.bzl", "selects")
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
load(
"//tensorflow/core/platform:build_config.bzl",
@ -22,10 +23,6 @@ load(
"//tensorflow/python/tools/api/generator:api_init_files_v1.bzl", "//tensorflow/python/tools/api/generator:api_init_files_v1.bzl",
"TENSORFLOW_API_INIT_FILES_V1", # @unused "TENSORFLOW_API_INIT_FILES_V1", # @unused
) )
load(
"//third_party/ngraph:build_defs.bzl",
"if_ngraph",
)
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl_ml",
@ -238,6 +235,12 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "linux_mips64",
values = {"cpu": "mips64"},
visibility = ["//visibility:public"],
)
config_setting(
name = "debug",
values = {
@ -465,14 +468,6 @@ config_setting(
visibility = ["//visibility:public"],
)
# This flag is set from the configure step when the user selects with nGraph option.
# By default it should be false
config_setting(
name = "with_ngraph_support",
values = {"define": "with_ngraph_support=true"},
visibility = ["//visibility:public"],
)
# This flag specifies whether TensorFlow 2.0 API should be built instead
# of 1.* API. Note that TensorFlow 2.0 API is currently under development.
config_setting(
@ -563,6 +558,33 @@ selects.config_setting_group(
],
)
# 'enable_registration_v2' opts-in to a different implementation of op and
# kernel registration - REGISTER_OP, REGISTER_KERNEL_BUILDER, etc.
#
# This setting is currently experimental. The 'v2' implementation does _not_
# correspond to a particular, finalized design; rather, it relates to
# developing one.
#
# The current aim of the 'v2' implementation is to allow 'unused' ops and
# kernels to be discarded by the linker (to the benefit of binary size).
bool_flag(
name = "enable_registration_v2",
build_setting_default = False,
visibility = ["//visibility:public"],
)
config_setting(
name = "registration_v1",
flag_values = {":enable_registration_v2": "False"},
visibility = ["//visibility:public"],
)
config_setting(
name = "registration_v2",
flag_values = {":enable_registration_v2": "True"},
visibility = ["//visibility:public"],
)
# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
# Instead, please use public APIs or public build rules TF provides.
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
@ -574,10 +596,7 @@ package_group(
],
)
package_group(
name = "ndarray_tensor_allow_list",
packages = ["//learning/pathways/..."],
)
package_group(name = "ndarray_tensor_allow_list")
# Packages that use private types symbols, until they are exported.
# TODO(b/154650521) Remove.
@ -608,7 +627,7 @@ bzl_library(
"//tensorflow/core/platform/default:cuda_build_defs_bzl", "//tensorflow/core/platform/default:cuda_build_defs_bzl",
"//third_party/mkl:build_defs_bzl", "//third_party/mkl:build_defs_bzl",
"//third_party/mkl_dnn:build_defs_bzl", "//third_party/mkl_dnn:build_defs_bzl",
"//third_party/ngraph:build_defs_bzl", "@bazel_skylib//rules:common_settings",
"@local_config_cuda//cuda:build_defs_bzl", "@local_config_cuda//cuda:build_defs_bzl",
"@local_config_rocm//rocm:build_defs_bzl", "@local_config_rocm//rocm:build_defs_bzl",
"@local_config_tensorrt//:build_defs_bzl", "@local_config_tensorrt//:build_defs_bzl",
@ -709,6 +728,9 @@ tf_cc_shared_object(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
"//tensorflow/c:kernels_hdrs",
"//tensorflow/c:ops_hdrs",
"//tensorflow/cc/saved_model:loader_lite_impl", "//tensorflow/cc/saved_model:loader_lite_impl",
"//tensorflow/core/common_runtime:core_cpu_impl", "//tensorflow/core/common_runtime:core_cpu_impl",
"//tensorflow/core:framework_internal_impl", "//tensorflow/core:framework_internal_impl",
@ -812,7 +834,7 @@ tf_cc_shared_object(
"//tensorflow/cc:scope", "//tensorflow/cc:scope",
"//tensorflow/cc/profiler", "//tensorflow/cc/profiler",
"//tensorflow/core:tensorflow", "//tensorflow/core:tensorflow",
] + if_ngraph(["@ngraph_tf//:ngraph_tf"]), ],
) )
# ** Targets for Windows build (start) **

View File

@ -202,6 +202,7 @@ tf_cuda_library(
":tf_status", ":tf_status",
":tf_tensor", ":tf_tensor",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"//tensorflow/c/experimental/filesystem:modular_filesystem",
"//tensorflow/cc/saved_model:loader_lite", "//tensorflow/cc/saved_model:loader_lite",
"//tensorflow/cc:gradients", "//tensorflow/cc:gradients",
"//tensorflow/cc:ops", "//tensorflow/cc:ops",
@ -511,6 +512,18 @@ tf_cuda_library(
],
)
cc_library(
name = "kernels_hdrs",
hdrs = ["kernels.h"],
visibility = ["//tensorflow:internal"],
deps = [
":c_api_internal",
":tf_datatype",
":tf_status",
":tf_tensor",
],
)
tf_cuda_library(
name = "kernels",
srcs = [
@ -564,6 +577,16 @@ tf_cuda_library(
alwayslink = 1,
)
cc_library(
name = "ops_hdrs",
hdrs = ["ops.h"],
visibility = ["//tensorflow:internal"],
deps = [
":tf_datatype",
":tf_status",
],
)
# -----------------------------------------------------------------------------
# Tests

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/platform/platform.h" // NOLINT #include "tensorflow/core/platform/platform.h" // NOLINT
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
#include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/cc/framework/scope_internal.h"
@ -2606,4 +2607,14 @@ void TF_RegisterLogListener(void (*listener)(const char*)) {
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
}
void TF_RegisterFilesystemPlugin(const char* plugin_filename,
TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
status->status = tensorflow::errors::Unimplemented(
"FileSystem plugin functionality is not supported on mobile");
#else
status->status = tensorflow::RegisterFilesystemPlugin(plugin_filename);
#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
} // end extern "C" } // end extern "C"

View File

@ -1205,7 +1205,7 @@ typedef struct TF_Session TF_Session;
// Return a new execution session with the associated graph, or NULL on
// error. Does not take ownership of any input parameters.
//
// *`graph` must be a valid graph (not deleted or nullptr). `graph` will be be
// *`graph` must be a valid graph (not deleted or nullptr). `graph` will be
// kept alive for the lifetime of the returned TF_Session. New nodes can still
// be added to `graph` after this call.
TF_CAPI_EXPORT extern TF_Session* TF_NewSession(TF_Graph* graph,
@ -1577,6 +1577,13 @@ TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server);
TF_CAPI_EXPORT extern void TF_RegisterLogListener(
void (*listener)(const char*));
// Register a FileSystem plugin from filename `plugin_filename`.
//
// On success, place OK in status.
// On failure, place an error status in status.
TF_CAPI_EXPORT extern void TF_RegisterFilesystemPlugin(
const char* plugin_filename, TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -561,15 +561,15 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
collective_executor_handle->get()->StartAbort(status->status);
}
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
const char* task,
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto collective_executor_handle = context->GetCollectiveExecutorHandle();
tensorflow::Notification done;
collective_executor_handle->get()->remote_access()->CheckPeerHealth(
task, [&done, status](const Status& s) {
task, timeout_in_ms, [&done, status](const Status& s) {
status->status = s;
done.Notify();
});

View File

@ -86,7 +86,7 @@ TF_CAPI_EXPORT void TF_SetXlaConstantFoldingDisabled(
// Create a serialized tensorflow.ConfigProto proto, where:
//
// a) ConfigProto.optimizer_options.global_jit_level is set to to ON_1 if
// a) ConfigProto.optimizer_options.global_jit_level is set to ON_1 if
// `enable_xla_compilation` is non-zero, and OFF otherwise.
// b) ConfigProto.gpu_options.allow_growth is set to `gpu_memory_allow_growth`.
// c) ConfigProto.device_count is set to `num_cpu_devices`.
@ -241,8 +241,8 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
// Checks the health of collective ops peers. Explicit health check is needed in
// multi worker collective ops to detect failures in the cluster. If a peer is
// down, collective ops may hang.
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
const char* task,
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
TF_Status* status);
// Information about the shape of a Tensor and its type.

View File

@ -10,6 +10,9 @@ load(
"tf_cuda_library", "tf_cuda_library",
) )
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "filegroup")
@ -94,6 +97,7 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime:remote_device", "//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime:worker_interface",
"//tensorflow/core:gpu_runtime", "//tensorflow/core:gpu_runtime",
] + internal_tfrt_deps(), ] + internal_tfrt_deps(),
alwayslink = 1, alwayslink = 1,
@ -106,6 +110,7 @@ filegroup(
"abstract_function.h", "abstract_function.h",
"abstract_operation.h", "abstract_operation.h",
"abstract_tensor_handle.h", "abstract_tensor_handle.h",
"c_api.h",
"c_api_experimental.h", "c_api_experimental.h",
"c_api_internal.h", "c_api_internal.h",
"c_api_unified_experimental.h", "c_api_unified_experimental.h",
@ -638,6 +643,19 @@ cc_library(
],
)
cc_header_only_library(
name = "tfe_tensorhandle_internal_hdrs_only",
extra_deps = [
"@com_google_absl//absl/strings",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":tfe_tensorhandle_internal",
],
)
tf_cuda_library(
name = "c_api_test_util",
testonly = 1,

View File

@ -32,7 +32,7 @@ namespace tensorflow {
// environment, a traced representation etc.
class AbstractContext {
protected:
enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape };
enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape, kOpHandler };
explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
virtual ~AbstractContext() {}

View File

@ -30,7 +30,14 @@ namespace tensorflow {
// tracing or immediate execution mode.
class AbstractOperation {
protected:
enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt, kTape };
enum AbstractOperationKind {
kGraph,
kMlir,
kEager,
kTfrt,
kTape,
kOpHandler
};
explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
virtual ~AbstractOperation() {}

View File

@ -70,6 +70,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h" #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#endif // !IS_MOBILE_PLATFORM #endif // !IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/node_def_util.h"
@ -855,41 +856,42 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
// TODO(yuefengz): support partially specified `worker_name`.
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
status->status = context->GetClient(worker_name, &eager_client);
if (!status->status.ok()) {
return false;
}
// Send a rpc request to the worker to check aliveness.
tensorflow::eager::KeepAliveRequest request;
request.set_context_id(context->GetContextId());
tensorflow::eager::KeepAliveResponse response;
tensorflow::Status keep_alive_status;
tensorflow::Notification done;
eager_client->KeepAliveAsync(
&request, &response,
[&keep_alive_status, &done](const tensorflow::Status& s) {
keep_alive_status = s;
done.Notify();
});
done.WaitForNotification();
status->status = tensorflow::Status::OK();
// If `context_id` doesn't exist on the remote worker, an InvalidArgument
// error will return. But this still indicates that the remote worker is
// alive.
if (keep_alive_status.ok() ||
keep_alive_status.code() == tensorflow::error::INVALID_ARGUMENT) {
return true;
} else {
LOG(INFO) << "Remote worker " << worker_name
<< " is not alive: " << keep_alive_status.error_message();
return false;
}
tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
if (grpc_server == nullptr) {
status->status =
tensorflow::errors::Internal("Failed to get tensorflow::GrpcServer.");
return false;
}
tensorflow::WorkerInterface* wi =
grpc_server->master_env()->worker_cache->GetOrCreateWorker(worker_name);
if (wi == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"Unable to find worker interface corresponding to task ", worker_name);
return false;
}
tensorflow::GetStatusRequest request;
tensorflow::GetStatusResponse response;
tensorflow::Status remote_status;
tensorflow::Notification done;
wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true,
[&remote_status, &done](const tensorflow::Status& s) {
remote_status = s;
done.Notify();
});
done.WaitForNotification();
// We set OK status so the call does not raise any exceptions. Instead, caller
// users the return value to tell if the remote worker is alive.
status->status = tensorflow::Status::OK();
if (remote_status.ok()) {
return true;
}
LOG(INFO) << "Remote worker " << worker_name
<< " is not alive: " << remote_status.error_message();
return false;
#endif // !IS_MOBILE_PLATFORM
}
@ -1445,13 +1447,11 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->Executor().WaitForAllPendingNodes();
if (!status->status.ok()) return;
tensorflow::mutex_lock ml(*context->MetadataMu());
status->status = MessageToBuffer(*context->RunMetadataProto(), buf);
context->ClearRunMetadata();
auto* context = tensorflow::unwrap(ctx);
status->status = context->AsyncWait();
if (!status->status.ok()) return;
auto run_metadata = context->ExportRunMetadata();
status->status = MessageToBuffer(*run_metadata, buf);
}
namespace {

View File

@ -638,3 +638,19 @@ void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) {
tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
}
const char* TFE_TensorHandleDeviceType(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
return tensorflow::unwrap(h)->DeviceType(&status->status);
}
int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return -1;
}
return tensorflow::unwrap(h)->DeviceId(&status->status);
}

View File

@ -481,7 +481,7 @@ typedef struct TFE_CustomDevice {
// "/job:localhost/replica:0/task:0/device:CUSTOM:0". // "/job:localhost/replica:0/task:0/device:CUSTOM:0".
// //
// The custom device defines copy operations for moving TensorHandles on and // The custom device defines copy operations for moving TensorHandles on and
// off, and an an execution operation for named operations. Often execution will // off, and an execution operation for named operations. Often execution will
// simply wrap op execution on one or more physical devices. // simply wrap op execution on one or more physical devices.
// //
// device_info is an opaque caller-defined type stored with the custom device // device_info is an opaque caller-defined type stored with the custom device
@ -553,6 +553,14 @@ TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
unsigned char enable,
TF_Status* status);
// Returns the device type of the operation that produced `h`.
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType(
TFE_TensorHandle* h, TF_Status* status);
// Returns the device ID of the operation that produced `h`.
TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h,
TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -411,5 +411,109 @@ TEST(CAPI, TensorHandleOnDeviceMemory) {
TF_DeleteStatus(status);
}
TEST(CAPI, TensorHandleNullptr) {
TFE_TensorHandle* h = nullptr;
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
const char* device_type = TFE_TensorHandleDeviceType(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(device_type, nullptr);
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
TF_SetStatus(status.get(), TF_OK, "");
int device_id = TFE_TensorHandleDeviceID(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(device_id, -1);
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
}
TEST(CAPI, TensorHandleDevices) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
const char* device_type = TFE_TensorHandleDeviceType(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
int device_id = TFE_TensorHandleDeviceID(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(0, device_id) << device_id;
// Disable the test if no GPU is present.
string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_Op* shape_op = ShapeOp(ctx, hgpu);
TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
device_type = TFE_TensorHandleDeviceType(retvals[0], status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type;
device_id = TFE_TensorHandleDeviceID(retvals[0], status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(0, device_id) << device_id;
TFE_DeleteOp(shape_op);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(hgpu);
}
TFE_DeleteTensorHandle(hcpu);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
}
TEST(CAPI, TensorHandleDefaults) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* h_default = TestMatrixTensorHandle(ctx);
const char* device_type = TFE_TensorHandleDeviceType(h_default, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
int device_id = TFE_TensorHandleDeviceID(h_default, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(0, device_id) << device_id;
TFE_TensorHandle* h_cpu = TFE_TensorHandleCopyToDevice(
h_default, ctx, "/device:CPU:0", status.get());
const char* device_type_cpu = TFE_TensorHandleDeviceType(h_cpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_type_cpu, "CPU")) << device_type_cpu;
int device_id_cpu = TFE_TensorHandleDeviceID(h_cpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(0, device_id_cpu) << device_id_cpu;
TFE_DeleteTensorHandle(h_default);
TFE_DeleteTensorHandle(h_cpu);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -16,9 +16,9 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_ #define TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_
// Run a function containing a MatMul op and check its output. // Run a function containing a MatMul op and check its output.
// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one // If heavy_load_on_streaming_rpc is true, send some rpc requests before the one
// which creates a remote remote input, to simulate a scenario that the remote // which creates a remote input, to simulate a scenario that the remote input
// input is not ready when we start running an op or a function. // is not ready when we start running an op or a function.
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func, void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc, bool heavy_load_on_streaming_rpc,
bool remote_func_outputs = false); bool remote_func_outputs = false);

View File

@ -696,7 +696,7 @@ TEST(CAPI, ExecuteAddForwardAsync) {
/*tfrt*/ false); /*tfrt*/ false);
} }
#ifdef PLATFORM_GOOGLE #ifdef PLATFORM_GOOGLE
// TODO(b/153349425): Add add forwarding tests for TFRT // TODO(b/153349425): Add forwarding tests for TFRT
TEST(CAPI, ExecuteAddTfrt) { TEST(CAPI, ExecuteAddTfrt) {
ExecuteAdd( ExecuteAdd(
/*async=*/false, /*async=*/false,
@ -769,7 +769,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
EXPECT_NE(TF_OK, TF_GetCode(status)); EXPECT_NE(TF_OK, TF_GetCode(status));
EXPECT_EQ(nullptr, t); EXPECT_EQ(nullptr, t);
const char* msg = "Matrix size-incompatible: In[0]: [2,2], In[1]: [3,2]"; const char* msg = "In[0] mismatch In[1] shape: 2 vs. 3: [2,2] [3,2]";
EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr) EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr)
<< TF_Message(status); << TF_Message(status);
// Since error is not cleared, the following copy with correct device will // Since error is not cleared, the following copy with correct device will

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
namespace tensorflow { namespace tensorflow {
@ -56,6 +57,9 @@ Status RegisterGradients(GradientRegistry* registry) {
} }
TEST_P(GradientCheckerTest, TestGradCheckMatMul) { TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
// Computing numerical gradients with TensorFloat-32 is numerically unstable,
// so disable it for this test.
enable_tensor_float_32_execution(false);
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx; AbstractContextPtr ctx;
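A hedged aside on why TensorFloat-32 is turned off here: the checker estimates derivatives by finite differences, and TF32 matmuls keep roughly 10 mantissa bits (about 1e-3 relative error), which is the same order as the perturbation being measured. A minimal central-difference sketch, not the checker's actual code:

#include <functional>

// With ~1e-3 relative error in f (TF32), the difference below is mostly noise
// for small eps, so gradients are only checked with full float32 precision.
float NumericalGrad(const std::function<float(float)>& f, float x,
                    float eps = 1e-3f) {
  return (f(x + eps) - f(x - eps)) / (2.0f * eps);
}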

View File

@ -63,6 +63,8 @@ Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer)); TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer)); TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer)); TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Mul", MulRegisterer));
return Status::OK(); return Status::OK();
} }
@ -74,11 +76,11 @@ Status AddGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> outputs, absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) { const GradientRegistry& registry) {
TapeVSpace vspace(ctx); TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false); auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x. tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y. tape->Watch(ToId(inputs[1])); // Watch y.
std::vector<AbstractTensorHandle*> add_outputs(1); std::vector<AbstractTensorHandle*> add_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs, TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs,
absl::MakeSpan(add_outputs), absl::MakeSpan(add_outputs),
"Add")); // Compute x+y. "Add")); // Compute x+y.
@ -97,7 +99,6 @@ Status AddGradModel(AbstractContext* ctx,
} }
outputs[0] = out_grads[0]; outputs[0] = out_grads[0];
outputs[1] = out_grads[1]; outputs[1] = out_grads[1];
delete tape;
return Status::OK(); return Status::OK();
} }
@ -109,10 +110,10 @@ Status ExpGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> outputs, absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) { const GradientRegistry& registry) {
TapeVSpace vspace(ctx); TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false); auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x. tape->Watch(ToId(inputs[0])); // Watch x.
std::vector<AbstractTensorHandle*> exp_outputs(1); std::vector<AbstractTensorHandle*> exp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
ops::Exp(tape_ctx.get(), inputs, absl::MakeSpan(exp_outputs), "Exp")); ops::Exp(tape_ctx.get(), inputs, absl::MakeSpan(exp_outputs), "Exp"));
std::unordered_map<tensorflow::int64, TapeTensor> std::unordered_map<tensorflow::int64, TapeTensor>
@ -128,7 +129,6 @@ Status ExpGradModel(AbstractContext* ctx,
exp_output->Unref(); exp_output->Unref();
} }
outputs[0] = out_grads[0]; outputs[0] = out_grads[0];
delete tape;
return Status::OK(); return Status::OK();
} }
@ -140,10 +140,10 @@ Status SqrtGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> outputs, absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) { const GradientRegistry& registry) {
TapeVSpace vspace(ctx); TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false); auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x. tape->Watch(ToId(inputs[0])); // Watch x.
std::vector<AbstractTensorHandle*> sqrt_outputs(1); std::vector<AbstractTensorHandle*> sqrt_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt")); ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt"));
std::unordered_map<tensorflow::int64, TapeTensor> std::unordered_map<tensorflow::int64, TapeTensor>
@ -159,7 +159,6 @@ Status SqrtGradModel(AbstractContext* ctx,
sqrt_output->Unref(); sqrt_output->Unref();
} }
outputs[0] = out_grads[0]; outputs[0] = out_grads[0];
delete tape;
return Status::OK(); return Status::OK();
} }
@ -172,12 +171,12 @@ Status IdentityNGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> outputs, absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) { const GradientRegistry& registry) {
TapeVSpace vspace(ctx); TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false); auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); tape->Watch(ToId(inputs[0]));
tape->Watch(ToId(inputs[1])); tape->Watch(ToId(inputs[1]));
vector<AbstractTensorHandle*> identity_n_outputs(2); vector<AbstractTensorHandle*> identity_n_outputs(2);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(ops::IdentityN( TF_RETURN_IF_ERROR(ops::IdentityN(
tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN")); tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN"));
@ -195,7 +194,6 @@ Status IdentityNGradModel(AbstractContext* ctx,
} }
outputs[0] = out_grads[0]; outputs[0] = out_grads[0];
outputs[1] = out_grads[1]; outputs[1] = out_grads[1];
delete tape;
return Status::OK(); return Status::OK();
} }
@ -207,11 +205,11 @@ Status NegGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> outputs, absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) { const GradientRegistry& registry) {
TapeVSpace vspace(ctx); TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false); auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); tape->Watch(ToId(inputs[0]));
std::vector<AbstractTensorHandle*> neg_outputs(1); std::vector<AbstractTensorHandle*> neg_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
ops::Neg(tape_ctx.get(), inputs, absl::MakeSpan(neg_outputs), "Neg")); ops::Neg(tape_ctx.get(), inputs, absl::MakeSpan(neg_outputs), "Neg"));
@ -227,6 +225,74 @@ Status NegGradModel(AbstractContext* ctx,
neg_output->Unref(); neg_output->Unref();
} }
outputs[0] = out_grads[0]; outputs[0] = out_grads[0];
return Status::OK();
}
// Computes
// y = inputs[0] - inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status SubGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
std::vector<AbstractTensorHandle*> sub_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs,
absl::MakeSpan(sub_outputs),
"Sub")); // Compute x-y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(sub_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto sub_output : sub_outputs) {
sub_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
return Status::OK();
}
// Computes
// y = inputs[0] * inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status MulGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
std::vector<AbstractTensorHandle*> mul_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), inputs,
absl::MakeSpan(mul_outputs),
"Mul")); // Compute x*y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(mul_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto mul_output : mul_outputs) {
mul_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
return Status::OK(); return Status::OK();
} }
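A quick worked check on the two models added above, matching the expectations in the tests that follow:

// Sub: x = 2, y = 2, z = x - y  ->  dz/dx = 1,    dz/dy = -1
// Mul: x = 1, y = 2, z = x * y  ->  dz/dx = y = 2, dz/dy = x = 1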
@ -612,6 +678,128 @@ TEST_P(CppGradients, TestNegGrad) {
result_tensor = nullptr; result_tensor = nullptr;
} }
TEST_P(CppGradients, TestSubGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
y.reset(y_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// tape.watch(y)
// y = x - y
// outputs = tape.gradient(y, [x, y])
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(SubGradModel, ctx.get(), {x.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
s = getValue(outputs[1], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, -1.0);
outputs[1]->Unref();
TF_DeleteTensor(result_tensor);
}
TEST_P(CppGradients, TestMulGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
y.reset(y_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// tape.watch(y)
// y = x * y
// outputs = tape.gradient(y, [x, y])
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(MulGradModel, ctx.get(), {x.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 2.0);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
s = getValue(outputs[1], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[1]->Unref();
TF_DeleteTensor(result_tensor);
}
TEST_P(CppGradients, TestSetAttrString) { TEST_P(CppGradients, TestSetAttrString) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);
@ -651,7 +839,7 @@ TEST_P(CppGradients, TestSetAttrString) {
int num_retvals = 1; int num_retvals = 1;
std::vector<AbstractTensorHandle*> outputs(1); std::vector<AbstractTensorHandle*> outputs(1);
GradientRegistry registry; GradientRegistry registry;
std::unique_ptr<Tape> tape(new Tape(/*persistent=*/false)); auto tape = std::make_unique<Tape>(/*persistent=*/false);
s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs), s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs),
&num_retvals, &forward_op, tape.get(), registry); &num_retvals, &forward_op, tape.get(), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message(); ASSERT_EQ(errors::OK, s.code()) << s.error_message();

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h" #include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow { namespace tensorflow {
@ -124,6 +125,13 @@ class ImmediateExecutionContext : public AbstractContext {
// Returns the device placement policy for the current thread. // Returns the device placement policy for the current thread.
virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0; virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;
// Configure graph collection in RunMetadata.
virtual void SetShouldStoreGraphs(bool value) = 0;
// Return the collected RunMetadata. This method will transfer the ownership
// to the caller.
virtual std::unique_ptr<RunMetadata> ExportRunMetadata() = 0;
// For LLVM style RTTI. // For LLVM style RTTI.
static bool classof(const AbstractContext* ptr) { static bool classof(const AbstractContext* ptr) {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt; return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
@ -149,9 +157,6 @@ class ImmediateExecutionContext : public AbstractContext {
// Update the Eager Executor for current thread. // Update the Eager Executor for current thread.
virtual void SetExecutorForThread(EagerExecutor* executor) = 0; virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
// Configure graph collection in RunMetadata.
virtual void SetShouldStoreGraphs(bool value) = 0;
protected: protected:
explicit ImmediateExecutionContext(AbstractContextKind kind) explicit ImmediateExecutionContext(AbstractContextKind kind)
: AbstractContext(kind) {} : AbstractContext(kind) {}
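A sketch of how a concrete context might satisfy the new ExportRunMetadata contract, handing the accumulated proto to the caller and starting a fresh one (the class and member names here are assumptions, not the diff's actual EagerContext fields):

std::unique_ptr<RunMetadata> MyEagerContext::ExportRunMetadata() {
  mutex_lock l(metadata_mu_);  // hypothetical lock guarding the metadata
  std::unique_ptr<RunMetadata> result = std::move(run_metadata_);
  run_metadata_ = std::make_unique<RunMetadata>();  // reset for future runs
  return result;
}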

View File

@ -44,6 +44,10 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
virtual const char* DeviceName(Status* status) const = 0; virtual const char* DeviceName(Status* status) const = 0;
// Returns the device where the tensor was placed. // Returns the device where the tensor was placed.
virtual const char* BackingDeviceName(Status* status) const = 0; virtual const char* BackingDeviceName(Status* status) const = 0;
// Returns the device type which created the handle.
virtual const char* DeviceType(Status* status) const = 0;
// Returns the device ID which created the handle.
virtual int DeviceId(Status* status) const = 0;
// Returns a tensor for the handle. If tensor is remote, it will be copied. // Returns a tensor for the handle. If tensor is remote, it will be copied.
virtual AbstractTensorInterface* Resolve(Status* status) = 0; virtual AbstractTensorInterface* Resolve(Status* status) = 0;

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
namespace tensorflow { namespace tensorflow {
@ -43,6 +44,11 @@ class CppGradients
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get()); Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message(); CHECK_EQ(errors::OK, s.code()) << s.error_message();
// Computing numerical gradients with TensorFloat-32 is numerically
// unstable. Some forward pass tests also fail with TensorFloat-32 due to
// low tolerances.
enable_tensor_float_32_execution(false);
} }
}; };

View File

@ -182,9 +182,8 @@ hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file,
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
std::string cacheKey(scheme); std::string cacheKey(scheme);
hdfsBuilder* builder = libhdfs->hdfsNewBuilder();
if (scheme == "file") { if (scheme == "file") {
libhdfs->hdfsBuilderSetNameNode(builder, nullptr); namenode = "";
} else if (scheme == "viewfs") { } else if (scheme == "viewfs") {
char* defaultFS = nullptr; char* defaultFS = nullptr;
libhdfs->hdfsConfGetStr("fs.defaultFS", &defaultFS); libhdfs->hdfsConfGetStr("fs.defaultFS", &defaultFS);
@ -200,24 +199,27 @@ hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file,
// The default NameNode configuration will be used (from the XML // The default NameNode configuration will be used (from the XML
// configuration files). See: // configuration files). See:
// https://github.com/tensorflow/tensorflow/blob/v1.0.0/third_party/hadoop/hdfs.h#L259 // https://github.com/tensorflow/tensorflow/blob/v1.0.0/third_party/hadoop/hdfs.h#L259
libhdfs->hdfsBuilderSetNameNode(builder, "default"); namenode = "default";
} else if (scheme == "har") { } else if (scheme == "har") {
std::string path_har = path; std::string path_har = path;
SplitArchiveNameAndPath(&path_har, &namenode, status); SplitArchiveNameAndPath(&path_har, &namenode, status);
if (TF_GetCode(status) != TF_OK) return nullptr; if (TF_GetCode(status) != TF_OK) return nullptr;
libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str());
cacheKey += namenode;
} else { } else {
libhdfs->hdfsBuilderSetNameNode( if (namenode.empty()) {
builder, namenode.empty() ? "default" : namenode.c_str()); namenode = "default";
cacheKey += namenode;
} }
}
cacheKey += namenode;
absl::MutexLock l(&hadoop_file->connection_cache_lock); absl::MutexLock l(&hadoop_file->connection_cache_lock);
if (hadoop_file->connection_cache.find(cacheKey) == if (hadoop_file->connection_cache.find(cacheKey) ==
hadoop_file->connection_cache.end()) { hadoop_file->connection_cache.end()) {
hdfsBuilder* builder = libhdfs->hdfsNewBuilder();
libhdfs->hdfsBuilderSetNameNode(
builder, namenode.empty() ? nullptr : namenode.c_str());
auto cacheFs = libhdfs->hdfsBuilderConnect(builder); auto cacheFs = libhdfs->hdfsBuilderConnect(builder);
if (cacheFs == nullptr) { if (cacheFs == nullptr) {
TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno)); TF_SetStatusFromIOError(status, TF_ABORTED, strerror(errno));
return cacheFs; return cacheFs;
} }
hadoop_file->connection_cache[cacheKey] = cacheFs; hadoop_file->connection_cache[cacheKey] = cacheFs;
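For reference, after this change the cache keeps one hdfsFS per (scheme, namenode) pair, and the builder is only created on a cache miss. Illustrative key values, inferred from the code above rather than taken from the diff:

// cacheKey = scheme + namenode, e.g.
//   hdfs://nn1:8020/a/b  -> "hdfs" + "nn1:8020" -> "hdfsnn1:8020"
//   hdfs:///a/b          -> "hdfs" + "default"  -> "hdfsdefault"
//   file:///tmp/x        -> "file" + ""         -> "file" (builder gets a null namenode)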

View File

@ -221,6 +221,74 @@ class NegGradientFunction : public GradientFunction {
~NegGradientFunction() override {} ~NegGradientFunction() override {}
}; };
class SubGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
/* Given upstream grad U and a Sub op A-B, the gradients are:
*
* dA = U
* dB = -U
*
*/
grad_outputs->resize(2);
// Grad for A
DCHECK(grad_inputs[0]);
(*grad_outputs)[0] = grad_inputs[0];
(*grad_outputs)[0]->Ref();
// Grad for B
// negate the upstream grad
std::vector<AbstractTensorHandle*> neg_outputs(1);
std::string name = "Neg_Sub_Grad_B";
TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(neg_outputs), name.c_str()));
(*grad_outputs)[1] = neg_outputs[0];
return Status::OK();
}
~SubGradientFunction() override {}
};
class MulGradientFunction : public GradientFunction {
public:
explicit MulGradientFunction(vector<AbstractTensorHandle*> f_inputs)
: forward_inputs(f_inputs) {}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
/* Given upstream grad U and a mul op A*B, the gradients are:
*
* dA = U * B
* dB = A * U
*
*/
AbstractTensorHandle* upstream_grad = grad_inputs[0];
grad_outputs->resize(2);
std::vector<AbstractTensorHandle*> mul_outputs(1);
// Gradient for A
std::string name = "Mul_Grad_A";
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {upstream_grad, forward_inputs[1]},
absl::MakeSpan(mul_outputs), name.c_str()));
(*grad_outputs)[0] = mul_outputs[0];
// Gradient for B
name = "Mul_Grad_B";
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {forward_inputs[0], upstream_grad},
absl::MakeSpan(mul_outputs), name.c_str()));
(*grad_outputs)[1] = mul_outputs[0];
return Status::OK();
}
~MulGradientFunction() override {}
private:
vector<AbstractTensorHandle*> forward_inputs;
};
} // namespace } // namespace
BackwardFunction* AddRegisterer(const ForwardOperation& op) { BackwardFunction* AddRegisterer(const ForwardOperation& op) {
@ -268,5 +336,23 @@ BackwardFunction* NegRegisterer(const ForwardOperation& op) {
return new BackwardFunction(gradient_function, default_gradients); return new BackwardFunction(gradient_function, default_gradients);
} }
BackwardFunction* SubRegisterer(const ForwardOperation& op) {
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto gradient_function = new SubGradientFunction;
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
BackwardFunction* MulRegisterer(const ForwardOperation& op) {
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto gradient_function = new MulGradientFunction(op.inputs);
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
} // namespace gradients } // namespace gradients
} // namespace tensorflow } // namespace tensorflow

View File

@ -25,6 +25,8 @@ BackwardFunction* ExpRegisterer(const ForwardOperation& op);
BackwardFunction* MatMulRegisterer(const ForwardOperation& op); BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
BackwardFunction* SqrtRegisterer(const ForwardOperation& op); BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
BackwardFunction* NegRegisterer(const ForwardOperation& op); BackwardFunction* NegRegisterer(const ForwardOperation& op);
BackwardFunction* SubRegisterer(const ForwardOperation& op);
BackwardFunction* MulRegisterer(const ForwardOperation& op);
} // namespace gradients } // namespace gradients
} // namespace tensorflow } // namespace tensorflow

View File

@ -25,7 +25,7 @@ TapeOperation::TapeOperation(AbstractOperation* parent_op, Tape* tape,
parent_op_(parent_op), parent_op_(parent_op),
tape_(tape), tape_(tape),
registry_(registry) { registry_(registry) {
// TODO(srbs): Make AbstractOperation RefCounted. // TODO(b/172003047): Consider making AbstractOperation RefCounted.
// parent_op_->Ref(); // parent_op_->Ref();
} }
void TapeOperation::Release() { void TapeOperation::Release() {
@ -33,7 +33,7 @@ void TapeOperation::Release() {
delete this; delete this;
} }
TapeOperation::~TapeOperation() { TapeOperation::~TapeOperation() {
// TODO(srbs): Make AbstractOperation RefCounted. // TODO(b/172003047): Consider making AbstractOperation RefCounted.
// parent_op->Unref(); // parent_op->Unref();
} }
Status TapeOperation::Reset(const char* op, const char* raw_device_name) { Status TapeOperation::Reset(const char* op, const char* raw_device_name) {

View File

@ -0,0 +1,43 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
package(
licenses = ["notice"], # Apache 2.0
)
tf_cc_test(
name = "internal_test",
srcs = ["internal_test.cc"],
deps = [
":internal",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:errors",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "internal",
srcs = ["internal.cc"],
hdrs = ["internal.h"],
deps = [
":wrapper_operation",
"//tensorflow/c:conversion_macros",
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core/platform:refcount",
"//tensorflow/core/platform:types",
],
)
cc_library(
name = "wrapper_operation",
srcs = ["wrapper_operation.cc"],
hdrs = ["wrapper_operation.h"],
deps = ["//tensorflow/c/eager:abstract_operation"],
)

View File

@ -0,0 +1,79 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_CC_
#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_CC_
#include "tensorflow/c/experimental/op_handler/internal.h"
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/experimental/op_handler/wrapper_operation.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
OpHandlerContext::OpHandlerContext(AbstractContext* parent_ctx)
: AbstractContext(kOpHandler), parent_ctx_(parent_ctx) {}
OpHandlerContext::~OpHandlerContext() {}
void OpHandlerContext::Release() { delete this; }
Status OpHandlerContext::RegisterFunction(AbstractFunction* function) {
return parent_ctx_->RegisterFunction(function);
}
Status OpHandlerContext::RemoveFunction(const string& function) {
return parent_ctx_->RemoveFunction(function);
}
void OpHandlerContext::set_default_handler(OpHandler* handler) {
handler->Ref();
default_handler_.reset(handler);
}
OpHandlerOperation* OpHandlerContext::CreateOperation() {
OpHandlerOperation* result =
new OpHandlerOperation(parent_ctx_->CreateOperation());
if (default_handler_ != nullptr) {
result->set_handler(default_handler_.get());
}
return result;
}
OpHandlerOperation::OpHandlerOperation(AbstractOperation* parent_op)
: WrapperOperation(parent_op, kOpHandler) {}
OpHandler* OpHandlerOperation::get_handler() { return handler_.get(); }
void OpHandlerOperation::set_handler(OpHandler* handler) {
if (handler != nullptr) {
handler->Ref();
}
handler_.reset(handler);
}
Status OpHandlerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) {
if (handler_ == nullptr) {
return WrapperOperation::Execute(retvals, num_retvals);
} else {
return handler_->Execute(this, retvals, num_retvals);
}
}
} // namespace tensorflow
#endif  // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_CC_

View File

@ -0,0 +1,117 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/experimental/op_handler/wrapper_operation.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class OpHandlerOperation;
// Op handlers are a convenient way to intercept and transform computation.
//
// The implementation is currently experimental and incomplete, but aims
// eventually to support tracing and replay of function bodies, gradients
// through copy operations, and a variety of hooks for things like debug
// strings. A public C API for op handlers is planned.
class OpHandler : public core::RefCounted {
public:
// Called on operation->Execute when operation->get_handler() == this.
//
// Allows the handler to customize or inspect `operation`'s execution.
virtual Status Execute(OpHandlerOperation* operation,
absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) = 0;
// Creates a new handler by merging this handler with `next_handler`.
//
// The new handler is expected to transform operations first with this handler
// and then execute the resulting operations on `next_handler` (by calling
// `OpHandlerOperation::set_handler` and passing `next_handler`). If this is
// not possible then the merge operation should fail.
virtual Status Merge(OpHandler* next_handler,
core::RefCountPtr<OpHandler>& merged_handler) = 0;
};
// Keeps some handler-specific metadata, but otherwise wraps a single
// AbstractOperation in the underlying context. The operation is created, its
// attributes set, etc., and at execution time it is presented to its handler,
// which may choose to execute it or simply inspect it and do something else.
//
// This is somewhat different than the Context approach, where the operation's
// construction is streamed through each layered Context. The streaming approach
// would require a much larger op handler public API, one function pointer per
// attribute type, and there is some ambiguity before an op is finalized about
// whether it should be presented as-is to handlers (regular operations) or
// replayed (function calls and control flow operations).
class OpHandlerOperation : public WrapperOperation {
public:
explicit OpHandlerOperation(AbstractOperation*);
OpHandler* get_handler();
void set_handler(OpHandler* handler);
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) override;
protected:
core::RefCountPtr<OpHandler> handler_;
};
// A context which allows a default handler to be set for new operations. It
// otherwise defers to the context it wraps.
//
// TODO(allenl): A stack of contexts and a stack of handlers look pretty similar
// in some ways. Having each handler be its own context seems almost doable,
// with things like copy operations and function/control flow replay being
// somewhat tricky (since they should be generated at the top of the handler
// stack and "caught" at the bottom). After handlers have evolved for a bit we
// should re-evaluate whether the handler+context concepts can be merged.
class OpHandlerContext : public AbstractContext {
public:
explicit OpHandlerContext(AbstractContext*);
void Release() override;
OpHandlerOperation* CreateOperation() override;
Status RegisterFunction(AbstractFunction*) override;
Status RemoveFunction(const string&) override;
// For LLVM style RTTI.
static bool classof(const AbstractContext* ptr) {
return ptr->getKind() == kOpHandler;
}
~OpHandlerContext() override;
void set_default_handler(OpHandler* handler);
private:
AbstractContext* parent_ctx_; // Not owned.
core::RefCountPtr<OpHandler> default_handler_;
};
class ReleaseOpHandlerOperation {
public:
void operator()(OpHandlerOperation* operation) { operation->Release(); }
};
typedef std::unique_ptr<OpHandlerOperation, ReleaseOpHandlerOperation>
OpHandlerOperationPtr;
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
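To make the contract above concrete, an illustrative handler (not part of this change) that records the op name and then lets the wrapped operation run unmodified; it assumes the same headers as internal_test.cc below:

class RecordingOpHandler : public OpHandler {
 public:
  Status Execute(OpHandlerOperation* operation,
                 absl::Span<AbstractTensorHandle*> retvals,
                 int* num_retvals) override {
    last_name_ = operation->Name();
    operation->set_handler(nullptr);  // fall through to the wrapped operation
    return operation->Execute(retvals, num_retvals);
  }
  Status Merge(OpHandler* next_handler,
               core::RefCountPtr<OpHandler>& merged_handler) override {
    return errors::Unimplemented("RecordingOpHandler does not merge");
  }
  std::string last_name_;
};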

View File

@ -0,0 +1,102 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/op_handler/internal.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
class TestOpHandler : public OpHandler {
public:
TestOpHandler() : last_operation_(new std::string("")) {}
Status Execute(OpHandlerOperation* operation,
absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) override {
CHECK(operation->get_handler() == this);
*last_operation_ = operation->Name();
operation->set_handler(next_handler_.get());
return operation->Execute(retvals, num_retvals);
}
Status Merge(OpHandler* next_handler,
core::RefCountPtr<OpHandler>& merged_handler) override {
merged_handler.reset(new TestOpHandler(next_handler, last_operation_));
return Status::OK();
}
core::RefCountPtr<OpHandler> next_handler_ = nullptr;
// Shared between merged handlers of this type.
std::shared_ptr<std::string> last_operation_;
private:
TestOpHandler(OpHandler* next_handler,
std::shared_ptr<std::string> last_operation)
: next_handler_(next_handler), last_operation_(last_operation) {
next_handler->Ref();
}
};
TEST(INTERNAL_TEST, UseOpHandler) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_ExecutionContext, decltype(&TF_DeleteExecutionContext)>
c_ctx(TF_NewEagerExecutionContext(opts.get(), status.get()),
TF_DeleteExecutionContext);
OpHandlerContext ctx(unwrap(c_ctx.get()));
core::RefCountPtr<TestOpHandler> outer_handler(new TestOpHandler());
core::RefCountPtr<TestOpHandler> inner_handler(new TestOpHandler());
ctx.set_default_handler(outer_handler.get());
OpHandlerOperationPtr op(ctx.CreateOperation());
Status s = op->Reset("NoOp", "");
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
std::vector<AbstractTensorHandle*> retvals;
int num_retvals = 0;
EXPECT_EQ("", *outer_handler->last_operation_);
s = op->Execute(absl::Span<AbstractTensorHandle*>(retvals), &num_retvals);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
EXPECT_EQ("NoOp", *outer_handler->last_operation_);
*outer_handler->last_operation_ = "";
EXPECT_EQ("", *inner_handler->last_operation_);
// This op executes on both handlers, changing the state of `inner_handler`
// since the handler has decided to preserve that state across merges.
core::RefCountPtr<OpHandler> merged;
s = inner_handler->Merge(outer_handler.get(), merged);
ctx.set_default_handler(merged.get());
op.reset(ctx.CreateOperation());
s = op->Reset("NoOp", "");
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
s = op->Execute(absl::Span<AbstractTensorHandle*>(retvals), &num_retvals);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
EXPECT_EQ("NoOp", *inner_handler->last_operation_);
EXPECT_EQ("NoOp", *outer_handler->last_operation_);
inner_handler.reset();
outer_handler.reset();
op.reset(ctx.CreateOperation());
s = op->Reset("NoOp", "");
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
s = op->Execute(absl::Span<AbstractTensorHandle*>(retvals), &num_retvals);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}
} // namespace tensorflow

View File

@ -0,0 +1,120 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/op_handler/wrapper_operation.h"
namespace tensorflow {
WrapperOperation::WrapperOperation(AbstractOperation* parent_op,
AbstractOperationKind kind)
: AbstractOperation(kind), parent_op_(parent_op) {
// TODO(b/172003047): Consider making AbstractOperation RefCounted.
// parent_op_->Ref();
}
void WrapperOperation::Release() {
parent_op_->Release();
// TODO(b/172003047): Consider making AbstractOperation RefCounted.
delete this;
}
Status WrapperOperation::Reset(const char* op, const char* raw_device_name) {
return parent_op_->Reset(op, raw_device_name);
}
const string& WrapperOperation::Name() const { return parent_op_->Name(); }
const string& WrapperOperation::DeviceName() const {
return parent_op_->DeviceName();
}
Status WrapperOperation::SetDeviceName(const char* name) {
return parent_op_->SetDeviceName(name);
}
Status WrapperOperation::AddInput(AbstractTensorHandle* input) {
return parent_op_->AddInput(input);
}
Status WrapperOperation::AddInputList(
absl::Span<AbstractTensorHandle* const> inputs) {
return parent_op_->AddInputList(inputs);
}
Status WrapperOperation::SetAttrString(const char* attr_name, const char* data,
size_t length) {
return parent_op_->SetAttrString(attr_name, data, length);
}
Status WrapperOperation::SetAttrInt(const char* attr_name, int64_t value) {
return parent_op_->SetAttrInt(attr_name, value);
}
Status WrapperOperation::SetAttrFloat(const char* attr_name, float value) {
return parent_op_->SetAttrFloat(attr_name, value);
}
Status WrapperOperation::SetAttrBool(const char* attr_name, bool value) {
return parent_op_->SetAttrBool(attr_name, value);
}
Status WrapperOperation::SetAttrType(const char* attr_name, DataType value) {
return parent_op_->SetAttrType(attr_name, value);
}
Status WrapperOperation::SetAttrShape(const char* attr_name,
const int64_t* dims, const int num_dims) {
return parent_op_->SetAttrShape(attr_name, dims, num_dims);
}
Status WrapperOperation::SetAttrFunction(const char* attr_name,
const AbstractOperation* value) {
return parent_op_->SetAttrFunction(attr_name, value);
}
Status WrapperOperation::SetAttrFunctionName(const char* attr_name,
const char* value, size_t length) {
return parent_op_->SetAttrFunctionName(attr_name, value, length);
}
Status WrapperOperation::SetAttrTensor(const char* attr_name,
AbstractTensorInterface* tensor) {
return parent_op_->SetAttrTensor(attr_name, tensor);
}
Status WrapperOperation::SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths,
int num_values) {
return parent_op_->SetAttrStringList(attr_name, values, lengths, num_values);
}
Status WrapperOperation::SetAttrFloatList(const char* attr_name,
const float* values, int num_values) {
return parent_op_->SetAttrFloatList(attr_name, values, num_values);
}
Status WrapperOperation::SetAttrIntList(const char* attr_name,
const int64_t* values, int num_values) {
return parent_op_->SetAttrIntList(attr_name, values, num_values);
}
Status WrapperOperation::SetAttrTypeList(const char* attr_name,
const DataType* values,
int num_values) {
return parent_op_->SetAttrTypeList(attr_name, values, num_values);
}
Status WrapperOperation::SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) {
return parent_op_->SetAttrBoolList(attr_name, values, num_values);
}
Status WrapperOperation::SetAttrShapeList(const char* attr_name,
const int64_t** dims,
const int* num_dims, int num_values) {
return parent_op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
Status WrapperOperation::SetAttrFunctionList(
const char* attr_name, absl::Span<const AbstractOperation*> values) {
return parent_op_->SetAttrFunctionList(attr_name, values);
}
AbstractOperation* WrapperOperation::GetBackingOperation() {
return parent_op_;
}
Status WrapperOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) {
return parent_op_->Execute(retvals, num_retvals);
}
} // namespace tensorflow

View File

@ -0,0 +1,74 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_
#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_
#include "tensorflow/c/eager/abstract_operation.h"
namespace tensorflow {
// Forwards all of the AbstractOperation's methods to its wrapped operation.
//
// Useful as a base class to default to forwarding while adding some
// customization.
class WrapperOperation : public AbstractOperation {
public:
explicit WrapperOperation(AbstractOperation*, AbstractOperationKind kind);
void Release() override;
Status Reset(const char* op, const char* raw_device_name) override;
const string& Name() const override;
const string& DeviceName() const override;
Status SetDeviceName(const char* name) override;
Status AddInput(AbstractTensorHandle* input) override;
Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) override;
Status SetAttrString(const char* attr_name, const char* data,
size_t length) override;
Status SetAttrInt(const char* attr_name, int64_t value) override;
Status SetAttrFloat(const char* attr_name, float value) override;
Status SetAttrBool(const char* attr_name, bool value) override;
Status SetAttrType(const char* attr_name, DataType value) override;
Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) override;
Status SetAttrFunction(const char* attr_name,
const AbstractOperation* value) override;
Status SetAttrFunctionName(const char* attr_name, const char* value,
size_t length) override;
Status SetAttrTensor(const char* attr_name,
AbstractTensorInterface* tensor) override;
Status SetAttrStringList(const char* attr_name, const void* const* values,
const size_t* lengths, int num_values) override;
Status SetAttrFloatList(const char* attr_name, const float* values,
int num_values) override;
Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) override;
Status SetAttrTypeList(const char* attr_name, const DataType* values,
int num_values) override;
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
int num_values) override;
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) override;
Status SetAttrFunctionList(
const char* attr_name,
absl::Span<const AbstractOperation*> values) override;
AbstractOperation* GetBackingOperation();
private:
AbstractOperation* parent_op_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_

View File

@ -146,6 +146,7 @@ cc_library(
":tf_signature_def_function", ":tf_signature_def_function",
":variable", ":variable",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
], ],
) )

View File

@ -343,7 +343,8 @@ Status InitializeCreateResourceFunctions(ImmediateExecutionContext* ctx,
std::unique_ptr<TFConcreteFunction> out; std::unique_ptr<TFConcreteFunction> out;
TF_RETURN_IF_ERROR(CreateConcreteFunction(ctx, *create_resource_fn, TF_RETURN_IF_ERROR(CreateConcreteFunction(ctx, *create_resource_fn,
obj_graph, objects, &out)); obj_graph, objects, &out));
revived->concrete_functions[create_resource_fn->node_id] = std::move(out); revived->concrete_functions.Insert(std::move(out),
create_resource_fn->node_id);
} }
return Status(); return Status();
} }
@ -352,8 +353,6 @@ Status InitializeAllFunctions(ImmediateExecutionContext* ctx,
const SavedObjectGraph& obj_graph, const SavedObjectGraph& obj_graph,
const PartiallyRevivedObjects& objects, const PartiallyRevivedObjects& objects,
RevivedObjects* revived) { RevivedObjects* revived) {
gtl::FlatMap<int, std::unique_ptr<TFConcreteFunction>>* destination_func_map =
&revived->concrete_functions;
gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>>* gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>>*
destination_sig_map = &revived->signature_def_functions; destination_sig_map = &revived->signature_def_functions;
@ -361,7 +360,7 @@ Status InitializeAllFunctions(ImmediateExecutionContext* ctx,
int node_id = id_and_func.first; int node_id = id_and_func.first;
const TFConcreteFunctionRevivalState& func = id_and_func.second; const TFConcreteFunctionRevivalState& func = id_and_func.second;
if (destination_func_map->find(node_id) != destination_func_map->end()) { if (revived->concrete_functions.Find(node_id)) {
// The function has already been initialized in the destination_map, // The function has already been initialized in the destination_map,
// so we can skip this node. This can occur because we initialize // so we can skip this node. This can occur because we initialize
// CreateResource functions before calling this function. // CreateResource functions before calling this function.
@ -371,7 +370,7 @@ Status InitializeAllFunctions(ImmediateExecutionContext* ctx,
std::unique_ptr<TFConcreteFunction> out; std::unique_ptr<TFConcreteFunction> out;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
CreateConcreteFunction(ctx, func, obj_graph, objects, &out)); CreateConcreteFunction(ctx, func, obj_graph, objects, &out));
(*destination_func_map)[node_id] = std::move(out); revived->concrete_functions.Insert(std::move(out), node_id);
} }
for (const auto& id_and_func : objects.signature_def_functions) { for (const auto& id_and_func : objects.signature_def_functions) {
@ -398,20 +397,16 @@ Status CreateAllResourceHandles(ImmediateExecutionContext* ctx,
for (auto& id_and_resource : objects->restored_resources) { for (auto& id_and_resource : objects->restored_resources) {
RestoredResourceRevivalState& resource = id_and_resource.second; RestoredResourceRevivalState& resource = id_and_resource.second;
int create_resource_fn_node = resource.create_resource->node_id; int create_resource_fn_node = resource.create_resource->node_id;
const gtl::FlatMap<int, std::unique_ptr<TFConcreteFunction>>&
revived_functions = revived->concrete_functions;
const auto& revived_functions_iter = const TFConcreteFunction* create_resource_fn =
revived_functions.find(create_resource_fn_node); revived->concrete_functions.Find(create_resource_fn_node);
if (revived_functions_iter == revived_functions.end()) { if (create_resource_fn == nullptr) {
return errors::FailedPrecondition( return errors::FailedPrecondition(
"ConcreteFunction at node ", create_resource_fn_node, "ConcreteFunction at node ", create_resource_fn_node,
" should have been initialized prior to being called."); " should have been initialized prior to being called.");
} }
const TFConcreteFunction& create_resource_fn =
*revived_functions_iter->second;
ImmediateOpPtr function_op; ImmediateOpPtr function_op;
TF_RETURN_IF_ERROR(create_resource_fn.MakeCallOp({}, &function_op)); TF_RETURN_IF_ERROR(create_resource_fn->MakeCallOp({}, &function_op));
TF_RETURN_IF_ERROR(function_op->SetDeviceName(resource.device.c_str())); TF_RETURN_IF_ERROR(function_op->SetDeviceName(resource.device.c_str()));
AbstractTensorHandle* resource_handle = nullptr; AbstractTensorHandle* resource_handle = nullptr;
@ -431,21 +426,6 @@ Status CreateAllResourceHandles(ImmediateExecutionContext* ctx,
return Status(); return Status();
} }
// Finds a ConcreteFunction with node id `node` in `objects`, and sets *out to
// point to it. If node doesn't exist in `objects`, out is untouched, and an
// error status is returned.
Status FindConcreteFunction(int node, RevivedObjects* objects,
TFConcreteFunction** out) {
auto func_iter = objects->concrete_functions.find(node);
if (func_iter == objects->concrete_functions.end()) {
return errors::FailedPrecondition(
"Failed to find ConcreteFunction with node id ", node,
" in revived objects");
}
*out = func_iter->second.get();
return Status();
}
Status BuildResources(ImmediateExecutionContext* ctx, Status BuildResources(ImmediateExecutionContext* ctx,
const SavedObjectGraph& obj_graph, const SavedObjectGraph& obj_graph,
PartiallyRevivedObjects* objects, PartiallyRevivedObjects* objects,
@ -460,22 +440,35 @@ Status BuildResources(ImmediateExecutionContext* ctx,
// Check all the functions associated with the resource have already been // Check all the functions associated with the resource have already been
// initialized in `revived` // initialized in `revived`
if (resource_revival_state.create_resource != nullptr) { if (resource_revival_state.create_resource != nullptr) {
TF_RETURN_IF_ERROR( create_resource = revived->concrete_functions.Find(
FindConcreteFunction(resource_revival_state.create_resource->node_id, resource_revival_state.create_resource->node_id);
revived, &create_resource)); if (create_resource == nullptr) {
return errors::FailedPrecondition(
"'create_resource' function with node id ",
resource_revival_state.create_resource->node_id, " not found");
}
} }
TFConcreteFunction* initialize = nullptr; TFConcreteFunction* initialize = nullptr;
if (resource_revival_state.initialize != nullptr) { if (resource_revival_state.initialize != nullptr) {
TF_RETURN_IF_ERROR(FindConcreteFunction( initialize = revived->concrete_functions.Find(
resource_revival_state.initialize->node_id, revived, &initialize)); resource_revival_state.initialize->node_id);
if (initialize == nullptr) {
return errors::FailedPrecondition(
"'initialize' function with node id ",
resource_revival_state.initialize->node_id, " not found");
}
} }
TFConcreteFunction* destroy_resource = nullptr; TFConcreteFunction* destroy_resource = nullptr;
if (resource_revival_state.destroy_resource != nullptr) { if (resource_revival_state.destroy_resource != nullptr) {
TF_RETURN_IF_ERROR( destroy_resource = revived->concrete_functions.Find(
FindConcreteFunction(resource_revival_state.destroy_resource->node_id, resource_revival_state.destroy_resource->node_id);
revived, &destroy_resource)); if (destroy_resource == nullptr) {
return errors::FailedPrecondition(
"'destroy_resource' function with node id ",
resource_revival_state.destroy_resource->node_id, " not found");
}
} }
if (resource_revival_state.resource_handle == nullptr) { if (resource_revival_state.resource_handle == nullptr) {
View File
@ -19,6 +19,7 @@ limitations under the License.
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h"
@ -29,6 +30,43 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
// A container for revived saved model objects.
//
// Most of the objects will be revived from nodes in the object graph, and for
// those objects this container provides a map from node id to the revived
// objects.
//
// For objects that have to be revived but are not part of the object graph,
// this container provides a place where the objects can be stored so they are
// available to the runtime.
template <typename T>
class RevivedObjectContainer {
public:
// Insert an object that is not related to a node id. This usually means the
// object was not referenced by the object_graph, but is needed by other
// objects.
void Insert(std::unique_ptr<T> object) {
objects_.push_back(std::move(object));
}
// Insert an object that is tied to the given object graph node id.
void Insert(std::unique_ptr<T> object, int node_id) {
objects_by_id_[node_id] = object.get();
Insert(std::move(object));
}
// Find an object by the object graph node id.
// Returns nullptr if there is no such object.
T* Find(int node_id) {
auto it = objects_by_id_.find(node_id);
return it == objects_by_id_.end() ? nullptr : it->second;
}
private:
std::vector<std::unique_ptr<T>> objects_;
absl::flat_hash_map<int, T*> objects_by_id_;
};
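As a rough illustration of the container's contract (not part of the diff): objects inserted with a node id are both owned and reachable through Find(), while objects inserted without one are only kept alive. The ToyFunction type and the demo function below are made up for this sketch; the template is the RevivedObjectContainer defined above.

    #include <memory>

    struct ToyFunction {  // stand-in payload for TFConcreteFunction
      int value;
    };

    void RevivedObjectContainerDemo() {
      RevivedObjectContainer<ToyFunction> functions;
      // Tied to object graph node 7: owned and reachable via Find(7).
      functions.Insert(std::make_unique<ToyFunction>(ToyFunction{7}), /*node_id=*/7);
      // Not referenced by the object graph (e.g. a while_body revived from the
      // function library): owned by the container, but never returned by Find().
      functions.Insert(std::make_unique<ToyFunction>(ToyFunction{-1}));
      ToyFunction* found = functions.Find(7);     // non-null
      ToyFunction* missing = functions.Find(42);  // nullptr
      (void)found;
      (void)missing;
    }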
// RevivedObjects is mainly used as a container for all the "state" owned by // RevivedObjects is mainly used as a container for all the "state" owned by
// SavedModel. It stores all non-"user object" nodes from a SavedModel // SavedModel. It stores all non-"user object" nodes from a SavedModel
// (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L57-L62) // (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L57-L62)
@ -37,12 +75,14 @@ namespace tensorflow {
// (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L25-L29) // (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L25-L29)
// to the revived object of the corresponding type. // to the revived object of the corresponding type.
struct RevivedObjects { struct RevivedObjects {
// Order of declaration is important here: we want the RestoredResources to be
// freed after TFConcreteFunctions, for example.
gtl::FlatMap<int, std::unique_ptr<Variable>> variables; gtl::FlatMap<int, std::unique_ptr<Variable>> variables;
gtl::FlatMap<int, std::unique_ptr<Asset>> assets; gtl::FlatMap<int, std::unique_ptr<Asset>> assets;
gtl::FlatMap<int, std::unique_ptr<Constant>> constants; gtl::FlatMap<int, std::unique_ptr<Constant>> constants;
gtl::FlatMap<int, std::unique_ptr<TFConcreteFunction>> concrete_functions;
gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>> gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>>
signature_def_functions; signature_def_functions;
RevivedObjectContainer<TFConcreteFunction> concrete_functions;
gtl::FlatMap<int, RestoredResource> restored_resources; gtl::FlatMap<int, RestoredResource> restored_resources;
gtl::FlatMap<std::string, int> signatures_map; gtl::FlatMap<std::string, int> signatures_map;
}; };
View File
@ -46,8 +46,6 @@ class SavedModelAPI {
virtual Status GetSignatureDefFunction(const std::string& signature_def_key, virtual Status GetSignatureDefFunction(const std::string& signature_def_key,
SignatureDefFunction** function) = 0; SignatureDefFunction** function) = 0;
virtual std::vector<ConcreteFunction*> ListFunctions() = 0;
virtual ~SavedModelAPI() = default; virtual ~SavedModelAPI() = default;
}; };
View File
@ -73,7 +73,6 @@ using FlatTensorFunctionMap =
namespace { namespace {
const TrackableObjectGraph::TrackableObject::SerializedTensor* const TrackableObjectGraph::TrackableObject::SerializedTensor*
FindSerializedTensorInTrackable( FindSerializedTensorInTrackable(
const TrackableObjectGraph::TrackableObject& trackable_object, const TrackableObjectGraph::TrackableObject& trackable_object,
@ -181,12 +180,11 @@ Status TFSavedModelAPI::GetFunction(const std::string& function_path,
return errors::NotFound("No saved object found at path ", function_path); return errors::NotFound("No saved object found at path ", function_path);
} }
auto function_iter = revived_objects_.concrete_functions.find(*node); *function = revived_objects_.concrete_functions.Find(*node);
if (function_iter == revived_objects_.concrete_functions.end()) { if (*function == nullptr) {
return errors::NotFound("No function found at path ", function_path); return errors::NotFound("No function found at path ", function_path);
} }
*function = function_iter->second.get();
return Status(); return Status();
} }
@ -211,15 +209,6 @@ Status TFSavedModelAPI::GetSignatureDefFunction(
return Status(); return Status();
} }
std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
std::vector<ConcreteFunction*> result;
result.reserve(revived_objects_.concrete_functions.size());
for (auto& index_and_function : revived_objects_.concrete_functions) {
result.push_back(index_and_function.second.get());
}
return result;
}
Status TFSavedModelAPI::GetVariable(const std::string& variable_path, Status TFSavedModelAPI::GetVariable(const std::string& variable_path,
Variable** variable) { Variable** variable) {
absl::optional<int> node = absl::optional<int> node =
@ -263,10 +252,10 @@ Status TFSavedModelAPI::Load(
// This occurs in python here: // This occurs in python here:
// https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454 // https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454
// Step 1: For each node in the graph, we should initialize an object of the // For each node in the graph, we should initialize an object of the
// corresponding type. For objects that depend on the initialization of other // corresponding type. For objects that depend on the initialization of other
// objects (like functions which capture resources), we will initialize them // objects (like functions which capture resources), we will initialize them
// in step 2. // later.
PartiallyRevivedObjects partially_revived_objects; PartiallyRevivedObjects partially_revived_objects;
TF_RETURN_IF_ERROR(internal::PartiallyReviveSavedModelObjects( TF_RETURN_IF_ERROR(internal::PartiallyReviveSavedModelObjects(
bundle.meta_graph_def(), context, directory, &partially_revived_objects)); bundle.meta_graph_def(), context, directory, &partially_revived_objects));
@ -275,6 +264,22 @@ Status TFSavedModelAPI::Load(
TF_RETURN_IF_ERROR(partially_revived_objects.Build( TF_RETURN_IF_ERROR(partially_revived_objects.Build(
context, bundle.saved_object_graph(), &revived_objects)); context, bundle.saved_object_graph(), &revived_objects));
// Revive function library functions as concrete functions without captures.
// This is necessary because object graph functions may refer to functions
// _not_ in the object graph: A while loop, for example, will create two
// auxiliary `while_cond` and `while_body` functions that are only present in
// the graph def function library.
for (const FunctionDef& function :
bundle.meta_graph_def().graph_def().library().function()) {
std::unique_ptr<TFConcreteFunction> concrete_function;
TF_RETURN_IF_ERROR(TFConcreteFunction::Create(/*function_def=*/&function,
/*captures=*/{},
/*metadata=*/{},
/*ctx=*/context,
/*out=*/&concrete_function));
revived_objects.concrete_functions.Insert(std::move(concrete_function));
}
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
RestoreCheckpoint(&bundle, revived_objects, directory, context)); RestoreCheckpoint(&bundle, revived_objects, directory, context));
View File
@ -66,8 +66,6 @@ class TFSavedModelAPI : public SavedModelAPI {
ImmediateExecutionContext* context, ImmediateExecutionContext* context,
std::unique_ptr<TFSavedModelAPI>* out); std::unique_ptr<TFSavedModelAPI>* out);
std::vector<ConcreteFunction*> ListFunctions() override;
~TFSavedModelAPI() override = default; ~TFSavedModelAPI() override = default;
Status GetVariable(const std::string& variable_path, Variable** variable); Status GetVariable(const std::string& variable_path, Variable** variable);
View File
@ -122,9 +122,4 @@ TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model,
return tensorflow::wrap(result); return tensorflow::wrap(result);
} }
TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
return new TF_ConcreteFunctionList{
tensorflow::unwrap(model)->ListFunctions()};
}
} // end extern "C" } // end extern "C"
View File
@ -524,6 +524,62 @@ TEST_P(CSavedModelAPITest, LoadSavedModelWithUninitializedVariable) {
TFE_DeleteContext(ctx); TFE_DeleteContext(ctx);
} }
TEST_P(CSavedModelAPITest, LoadSavedModelWithWhileLoop) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
bool use_tfrt = GetParam();
if (use_tfrt) {
TFE_DeleteContextOptions(opts);
TF_DeleteStatus(status);
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
std::string model_dir = tensorflow::io::JoinPath(
tensorflow::testing::TensorFlowSrcRoot(),
"c/experimental/saved_model/internal/testdata/SimpleWhileLoop");
TF_SavedModel* saved_model =
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_ConcreteFunction* while_fn =
TF_GetSavedModelConcreteFunction(saved_model, "compute", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
std::vector<TFE_TensorHandle*> while_fn_inputs;
while_fn_inputs.push_back(TestScalarTensorHandle(ctx, 10.0f));
TFE_Op* while_fn_op = TF_ConcreteFunctionMakeCallOp(
while_fn, while_fn_inputs.data(), while_fn_inputs.size(), status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* while_fn_outputs[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(while_fn_op, &while_fn_outputs[0], &num_retvals, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* result = TFE_TensorHandleResolve(while_fn_outputs[0], status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
ASSERT_EQ(TF_NumDims(result), 0);
float output_value = *static_cast<float*>(TF_TensorData(result));
ASSERT_FLOAT_EQ(output_value, 55); // 10+9+...+1
TF_DeleteTensor(result);
TFE_DeleteTensorHandle(while_fn_outputs[0]);
TFE_DeleteOp(while_fn_op);
TFE_DeleteTensorHandle(while_fn_inputs[0]);
TF_DeleteSavedModel(saved_model);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
}
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest, INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest,
::testing::Bool()); ::testing::Bool());
View File
@ -12,6 +12,8 @@ py_strict_binary(
srcs = ["gen_saved_models.py"], srcs = ["gen_saved_models.py"],
python_version = "PY3", python_version = "PY3",
deps = [ deps = [
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes", "//tensorflow/python:dtypes",
"//tensorflow/python:platform", "//tensorflow/python:platform",
"//tensorflow/python:resource_variable_ops", "//tensorflow/python:resource_variable_ops",
@ -21,7 +23,6 @@ py_strict_binary(
"//tensorflow/python/eager:def_function", "//tensorflow/python/eager:def_function",
"//tensorflow/python/module", "//tensorflow/python/module",
"//tensorflow/python/saved_model", "//tensorflow/python/saved_model",
"//tensorflow/python/saved_model:save_options",
], ],
) )
@ -29,6 +30,7 @@ py_strict_binary(
filegroup( filegroup(
name = "saved_models", name = "saved_models",
srcs = glob([ srcs = glob([
"SimpleWhileLoop/**",
"UninitializedVariable/**", "UninitializedVariable/**",
]), ]),
visibility = [ visibility = [
View File
@ -30,9 +30,11 @@ import os
from tensorflow.python.compat import v2_compat from tensorflow.python.compat import v2_compat
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.module import module from tensorflow.python.module import module
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import app from tensorflow.python.platform import app
@ -72,11 +74,32 @@ def _gen_uninitialized_variable(base_dir):
to_save, export_dir=os.path.join(base_dir, "UninitializedVariable")) to_save, export_dir=os.path.join(base_dir, "UninitializedVariable"))
def _gen_simple_while_loop(base_dir):
"""Generates a saved model with a while loop."""
class Module(module.Module):
"""A module with a while loop."""
@def_function.function(
input_signature=[tensor_spec.TensorSpec((), dtypes.float32)])
def compute(self, value):
acc, _ = control_flow_ops.while_loop(
cond=lambda acc, i: i > 0,
body=lambda acc, i: (acc + i, i - 1),
loop_vars=(constant_op.constant(0.0), value))
return acc
to_save = Module()
saved_model.save(
to_save, export_dir=os.path.join(base_dir, "SimpleWhileLoop"))
def main(args): def main(args):
if len(args) != 2: if len(args) != 2:
raise app.UsageError("Expected one argument (base_dir).") raise app.UsageError("Expected one argument (base_dir).")
_, base_dir = args _, base_dir = args
_gen_uninitialized_variable(base_dir) _gen_uninitialized_variable(base_dir)
_gen_simple_while_loop(base_dir)
if __name__ == "__main__": if __name__ == "__main__":
View File
@ -100,11 +100,6 @@ TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model,
const char* signature_def_key, const char* signature_def_key,
TF_Status* status); TF_Status* status);
// Returns a list of all ConcreteFunctions stored in this SavedModel.
// The lifetime of the returned list is bound to `model`.
TF_CAPI_EXPORT extern TF_ConcreteFunctionList* TF_ListSavedModelFunctions(
TF_SavedModel* model);
#ifdef __cplusplus #ifdef __cplusplus
} // end extern "C" } // end extern "C"
#endif // __cplusplus #endif // __cplusplus
View File
@ -11,11 +11,21 @@ package(
licenses = ["notice"], # Apache 2.0 licenses = ["notice"], # Apache 2.0
) )
cc_library(
name = "stream_executor_hdrs",
hdrs = ["stream_executor.h"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status_headers",
],
)
cc_library( cc_library(
name = "stream_executor", name = "stream_executor",
srcs = ["stream_executor.cc"], srcs = ["stream_executor.cc"],
hdrs = ["stream_executor.h"], hdrs = ["stream_executor.h"],
visibility = ["//visibility:public"], visibility = ["//tensorflow:internal"],
deps = [ deps = [
":stream_executor_internal", ":stream_executor_internal",
"//tensorflow/c:c_api_macros", "//tensorflow/c:c_api_macros",
View File
@ -148,7 +148,7 @@ void RegisterBitcastOpKernel() {
<< "Error while registering bitcast kernel"; << "Error while registering bitcast kernel";
} }
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
{ {
auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_GPU, auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_GPU,
&BitcastOp_Create, &BitcastOp_Compute, &BitcastOp_Create, &BitcastOp_Compute,
View File
@ -44,7 +44,7 @@ class DummyDevice : public DeviceBase {
} }
}; };
// Helper for comparing ouput and expected output // Helper for comparing output and expected output
void ExpectSummaryMatches(const Summary& actual, const string& expected_str) { void ExpectSummaryMatches(const Summary& actual, const string& expected_str) {
Summary expected; Summary expected;
ASSERT_TRUE(protobuf::TextFormat::ParseFromString(expected_str, &expected)); ASSERT_TRUE(protobuf::TextFormat::ParseFromString(expected_str, &expected));
View File
@ -352,7 +352,7 @@ class DeviceKernelOpTest : public OpsTestBase {
EXPECT_EQ(TF_OK, TF_GetCode(status)); EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_DeleteStatus(status); TF_DeleteStatus(status);
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
std::unique_ptr<Device> device( std::unique_ptr<Device> device(
DeviceFactory::NewDevice(device_name_, {}, "/job:a/replica:0/task:0")); DeviceFactory::NewDevice(device_name_, {}, "/job:a/replica:0/task:0"));
OpsTestBase::SetDevice(DEVICE_GPU, std::move(device)); OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
@ -361,7 +361,7 @@ class DeviceKernelOpTest : public OpsTestBase {
TF_ASSERT_OK(InitOp()); TF_ASSERT_OK(InitOp());
} }
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
const char* device_name_ = tensorflow::DEVICE_GPU; const char* device_name_ = tensorflow::DEVICE_GPU;
#else #else
const char* device_name_ = tensorflow::DEVICE_CPU; const char* device_name_ = tensorflow::DEVICE_CPU;
@ -468,7 +468,7 @@ TEST_F(DeviceKernelOpTest, TestAllocateTempSizeOne) {
int64_t dim = 1; int64_t dim = 1;
TF_AllocatorAttributes alloc_attrs; TF_AllocatorAttributes alloc_attrs;
alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE; alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
alloc_attrs.on_host = 0; alloc_attrs.on_host = 0;
#else #else
alloc_attrs.on_host = 1; alloc_attrs.on_host = 1;
@ -505,7 +505,7 @@ TEST_F(DeviceKernelOpTest, TestAllocateTempEmpty) {
int64_t dim = 0; int64_t dim = 0;
TF_AllocatorAttributes alloc_attrs; TF_AllocatorAttributes alloc_attrs;
alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE; alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
alloc_attrs.on_host = 0; alloc_attrs.on_host = 0;
#else #else
alloc_attrs.on_host = 1; alloc_attrs.on_host = 1;
@ -538,7 +538,7 @@ TEST_F(DeviceKernelOpTest, TestAllocateTempSize2x3) {
int64_t dim[2] = {2, 3}; int64_t dim[2] = {2, 3};
TF_AllocatorAttributes alloc_attrs; TF_AllocatorAttributes alloc_attrs;
alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE; alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
alloc_attrs.on_host = 0; alloc_attrs.on_host = 0;
#else #else
alloc_attrs.on_host = 1; alloc_attrs.on_host = 1;
@ -646,7 +646,7 @@ template <typename T>
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes, void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
TF_OpKernelContext* ctx) { TF_OpKernelContext* ctx) {
T* data = reinterpret_cast<T*>(TF_TensorData(tensor)); T* data = reinterpret_cast<T*>(TF_TensorData(tensor));
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx); OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, values, cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, values,
tensor_size_bytes); tensor_size_bytes);
View File
@ -76,7 +76,7 @@ class Tensor {
// unknown rank. // unknown rank.
int dims() const; int dims() const;
// Returns the number of elements in in demension `d`. // Returns the number of elements in dimension `d`.
// REQUIRES: `0 <= d < dims()` // REQUIRES: `0 <= d < dims()`
int64_t dim_size(int d) const; int64_t dim_size(int d) const;
@ -154,7 +154,7 @@ inline Tensor Tensor::FromBuffer(TF_DataType dtype,
// 1. Only a function pointer is sent across the C API (&DeleterFunction) // 1. Only a function pointer is sent across the C API (&DeleterFunction)
// 2. DeleterFunction is defined in the same build artifact that constructed // 2. DeleterFunction is defined in the same build artifact that constructed
// the std::function (so there isn't confusion about std::function ABI). // the std::function (so there isn't confusion about std::function ABI).
// Note that 2. is satisifed by the fact that this is a header-only API, where // Note that 2. is satisfied by the fact that this is a header-only API, where
// the function implementations are inline. // the function implementations are inline.
DeleterStruct* deleter_struct = new DeleterStruct{deleter}; DeleterStruct* deleter_struct = new DeleterStruct{deleter};
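To make the two points above concrete, here is a hedged sketch of the trampoline pattern they describe; the real DeleterStruct and DeleterFunction are defined earlier in this header, so the member names and exact signatures below are illustrative rather than copied from the source.

    #include <cstddef>
    #include <functional>

    // Heap-allocated wrapper so the std::function itself never crosses the C API.
    struct DeleterStruct {
      std::function<void(void*, size_t)> deleter;
    };

    // Plain function pointer handed across the C API (e.g. to TF_NewTensor); it
    // unwraps the struct, invokes the user's std::function, and frees the wrapper.
    static void DeleterFunction(void* memory, size_t len, void* arg) {
      DeleterStruct* ds = static_cast<DeleterStruct*>(arg);
      ds->deleter(memory, len);
      delete ds;
    }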
View File
@ -67,7 +67,7 @@ bool IsZero(const Scope& scope, const Output& grad) {
// mat: A 2-D tensor of dimension [D0, D1] // mat: A 2-D tensor of dimension [D0, D1]
// //
// Returns: // Returns:
// A tensor of dimension [D0, D1], the result fo vec * mat. // A tensor of dimension [D0, D1], the result for vec * mat.
Output BroadcastMul(const Scope& scope, const Output& vec, const Output& mat) { Output BroadcastMul(const Scope& scope, const Output& vec, const Output& mat) {
auto reshaped = ExpandDims(scope, vec, -1); auto reshaped = ExpandDims(scope, vec, -1);
return Multiply(scope, reshaped, mat); return Multiply(scope, reshaped, mat);
View File
@ -84,9 +84,6 @@ class SavedModelAPI {
SignatureDefFunction* GetSignatureDefFunction( SignatureDefFunction* GetSignatureDefFunction(
const std::string& function_path, Status* status); const std::string& function_path, Status* status);
// Lists all Conrete Functions available from the SavedModel.
std::vector<ConcreteFunction*> ListFunctions();
// SavedModelAPI is movable, but not copyable. // SavedModelAPI is movable, but not copyable.
SavedModelAPI(SavedModelAPI&&) = default; SavedModelAPI(SavedModelAPI&&) = default;
SavedModelAPI& operator=(SavedModelAPI&&) = default; SavedModelAPI& operator=(SavedModelAPI&&) = default;
@ -151,11 +148,6 @@ inline SignatureDefFunction* SavedModelAPI::GetSignatureDefFunction(
return SignatureDefFunction::wrap(function); return SignatureDefFunction::wrap(function);
} }
inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
ConcreteFunctionList list(TF_ListSavedModelFunctions(saved_model_.get()));
return list.ToVector();
}
} // namespace cc } // namespace cc
} // namespace experimental } // namespace experimental
} // namespace tensorflow } // namespace tensorflow
View File
@ -404,10 +404,12 @@ Status RestoreSession(const RunOptions& run_options,
const uint64 read_start_microseconds = Env::Default()->NowMicros(); const uint64 read_start_microseconds = Env::Default()->NowMicros();
std::vector<AssetFileDef> asset_file_defs; std::vector<AssetFileDef> asset_file_defs;
TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(meta_graph, &asset_file_defs)); TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(meta_graph, &asset_file_defs));
if (meta_graph.has_saver_def()) {
TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir, TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir,
meta_graph.saver_def().restore_op_name(), meta_graph.saver_def().restore_op_name(),
meta_graph.saver_def().filename_tensor_name(), meta_graph.saver_def().filename_tensor_name(),
asset_file_defs, session->get())); asset_file_defs, session->get()));
}
// Record walltime spent in restoring graph from disk, but postpone metric // Record walltime spent in restoring graph from disk, but postpone metric
// increments until graph init finishes. // increments until graph init finishes.
const uint64 restore_graph_walltime = const uint64 restore_graph_walltime =
View File
@ -138,7 +138,7 @@ class FreezeTest : public ::testing::Test {
} }
TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
// "c" isnt dependent on the variable, so nothing should be frozen. // "c" isn't dependent on the variable, so nothing should be frozen.
TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
graph_def, {"c:0"}, "assign", &saved_model_bundle)); graph_def, {"c:0"}, "assign", &saved_model_bundle));
@ -183,7 +183,7 @@ class FreezeTest : public ::testing::Test {
} }
Output c = ops::Mul(scope.WithOpName("c"), a, read_var); Output c = ops::Mul(scope.WithOpName("c"), a, read_var);
TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
// "c" isnt dependent on the variable, so nothing should be frozen. // "c" isn't dependent on the variable, so nothing should be frozen.
TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
graph_def, {"c:0"}, "assign", &saved_model_bundle)); graph_def, {"c:0"}, "assign", &saved_model_bundle));
@ -244,7 +244,7 @@ class FreezeTest : public ::testing::Test {
Output c = ops::Mul(scope.WithOpName("c"), a, read_var); Output c = ops::Mul(scope.WithOpName("c"), a, read_var);
TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
// "c" isnt dependent on the variable, so nothing should be frozen. // "c" isn't dependent on the variable, so nothing should be frozen.
TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
graph_def, {"c:0"}, "assign", &saved_model_bundle)); graph_def, {"c:0"}, "assign", &saved_model_bundle));
View File
@ -7,6 +7,9 @@ load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_
load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts") load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts")
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm") load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
# buildifier: disable=same-origin-load # buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow:tensorflow.bzl", "filegroup")
@ -283,6 +286,7 @@ cc_library(
"//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/base", "@com_google_absl//absl/base",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
@ -291,7 +295,7 @@ cc_library(
# Header-only version of "flags" library, for linking from the shared object # Header-only version of "flags" library, for linking from the shared object
# without ODR violations. # without ODR violations.
cc_library( cc_library(
name = "flags_headers_only", name = "flags_headers",
hdrs = ["flags.h"], hdrs = ["flags.h"],
visibility = [":friends"], visibility = [":friends"],
deps = [ deps = [
@ -302,6 +306,11 @@ cc_library(
], ],
) )
cc_header_only_library(
name = "flags_headers_only",
deps = [":flags_headers"],
)
cc_library( cc_library(
name = "common", name = "common",
srcs = [ srcs = [
@ -361,6 +370,7 @@ cc_library(
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
"//tensorflow/compiler/mlir:mlir_bridge_rollout_policy",
"//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla:xla_context", "//tensorflow/compiler/tf2xla:xla_context",
@ -447,8 +457,8 @@ cc_library(
# Header-only version of "flags" library, for linking from the shared object # Header-only version of "flags" library, for linking from the shared object
# without ODR violations. # without ODR violations.
cc_library( cc_library(
name = "get_compiler_ir_hdrs_only", name = "get_compiler_ir_hdrs",
hdrs = ["get_compiler_ir.h"], textual_hdrs = ["get_compiler_ir.h"],
visibility = [ visibility = [
":internal", ":internal",
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__", "//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
@ -463,6 +473,23 @@ cc_library(
], ],
) )
cc_header_only_library(
name = "get_compiler_ir_hdrs_only",
deps = [":get_compiler_ir_hdrs"],
)
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",
visibility = ["//visibility:public"],
deps = [
":xla_cpu_device",
":xla_cpu_jit",
":xla_gpu_device",
":xla_gpu_jit",
],
)
cc_library( cc_library(
name = "xla_kernel_creator", name = "xla_kernel_creator",
srcs = [ srcs = [
@ -481,6 +508,7 @@ cc_library(
":flags", ":flags",
":jit_compilation_passes", ":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration", "//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration",
"//tensorflow/compiler/mlir:mlir_bridge_rollout_policy",
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla:xla_op_registry",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
@ -520,8 +548,8 @@ cc_library(
hdrs = ["resource_operation_safety_analysis.h"], hdrs = ["resource_operation_safety_analysis.h"],
deps = [ deps = [
":xla_cluster_util", ":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/xla/service/graphcycles",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:graph", "//tensorflow/core:graph",
"//tensorflow/core:lib", "//tensorflow/core:lib",
@ -692,7 +720,6 @@ cc_library(
"//tensorflow/cc:ops", "//tensorflow/cc:ops",
"//tensorflow/cc:scope", "//tensorflow/cc:scope",
"//tensorflow/cc:scope_internal", "//tensorflow/cc:scope_internal",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:side_effect_util",
@ -705,6 +732,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:union_find",
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service/graphcycles",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
@ -732,9 +760,9 @@ cc_library(
deps = [ deps = [
":flags", ":flags",
":xla_activity_proto_cc", ":xla_activity_proto_cc",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/service/graphcycles",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
@ -842,9 +870,12 @@ tf_cc_test(
"partially_decluster_pass_test.cc", "partially_decluster_pass_test.cc",
"rearrange_function_argument_pass_test.cc", "rearrange_function_argument_pass_test.cc",
], ],
# TODO(b/141643254) Re-enable msan after fixing use-of-uninitialized-value tags = [
# error. # TODO(b/141643254) Re-enable msan after fixing
tags = ["nomsan"] + tf_cuda_tests_tags(), # use-of-uninitialized-value error.
"nomsan",
"no_cuda_asan", # TODO(b/171317460): re-enable.
] + tf_cuda_tests_tags(),
deps = [ deps = [
":common", ":common",
":compilability_check_util", ":compilability_check_util",
@ -965,13 +996,13 @@ cc_library(
":xla_activity_listener", ":xla_activity_listener",
":xla_activity_proto_cc", ":xla_activity_proto_cc",
":xla_cluster_util", ":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla:xla_op_registry",
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:union_find",
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service/graphcycles",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:graph", "//tensorflow/core:graph",
@ -1075,15 +1106,3 @@ cc_library(
], ],
alwayslink = 1, alwayslink = 1,
) )
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",
visibility = ["//visibility:public"],
deps = [
":xla_cpu_device",
":xla_cpu_jit",
":xla_gpu_device",
":xla_gpu_jit",
],
)
View File
@ -34,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h" #include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/compiler/jit/xla_activity_listener.h"
@ -42,6 +41,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/util.h"
View File
@ -24,11 +24,11 @@ limitations under the License.
#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h" #include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/util.h"
View File
@ -27,11 +27,11 @@ limitations under the License.
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function.h"
View File
@ -167,8 +167,16 @@ void AllocateAndParseFlags() {
jitter_flags = new IntroduceFloatingPointJitterPassFlags; jitter_flags = new IntroduceFloatingPointJitterPassFlags;
jitter_flags->jitter_amount = 1e-5; jitter_flags->jitter_amount = 1e-5;
mlir_flags = new MlirCommonFlags; // The `enable_mlir_bridge` flag allows the user to explicitly request that
mlir_flags->tf_mlir_enable_mlir_bridge = false; // their program is (or isn't) compiled using the MLIR-based TF-to-XLA bridge.
//
// The `enable_mlir_bridge_is_explicit` variable tracks whether or not the
// user has made an explicit request. That is, if this variable is set to
// true, the program honors the user's request as per `enable_mlir_bridge`; if
// it's set to false, the default behavior is used (which may run either
// bridge, on a per-graph basis).
bool enable_mlir_bridge = false;
bool enable_mlir_bridge_is_explicit = false;
auto setter_for_jitter_tensor_names = [](string sequence) { auto setter_for_jitter_tensor_names = [](string sequence) {
jitter_flags->tensor_names = absl::StrSplit(sequence, ','); jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
@ -184,11 +192,11 @@ void AllocateAndParseFlags() {
"XLA clusters."), "XLA clusters."),
Flag("tf_xla_check_cluster_input_numerics", Flag("tf_xla_check_cluster_input_numerics",
&build_ops_flags->tf_xla_check_cluster_input_numerics, &build_ops_flags->tf_xla_check_cluster_input_numerics,
"If true then insert CheckNumerics nodes to to check all cluster " "If true then insert CheckNumerics nodes to check all cluster "
"inputs."), "inputs."),
Flag("tf_xla_check_cluster_output_numerics", Flag("tf_xla_check_cluster_output_numerics",
&build_ops_flags->tf_xla_check_cluster_output_numerics, &build_ops_flags->tf_xla_check_cluster_output_numerics,
"If true then insert CheckNumerics nodes to to check all cluster " "If true then insert CheckNumerics nodes to check all cluster "
"outputs."), "outputs."),
Flag("tf_xla_disable_constant_folding", Flag("tf_xla_disable_constant_folding",
&build_ops_flags->tf_xla_disable_constant_folding, &build_ops_flags->tf_xla_disable_constant_folding,
@ -217,12 +225,24 @@ void AllocateAndParseFlags() {
"The amount of jitter to introduce. This amount is added to each " "The amount of jitter to introduce. This amount is added to each "
"element in the tensors named in `tensor_names."), "element in the tensors named in `tensor_names."),
Flag("tf_mlir_enable_mlir_bridge", Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge,
&mlir_flags->tf_mlir_enable_mlir_bridge, "Enables experimental MLIR-Based TensorFlow Compiler Bridge.",
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.")}); &enable_mlir_bridge_is_explicit)});
AppendMarkForCompilationPassFlagsInternal(flag_list); AppendMarkForCompilationPassFlagsInternal(flag_list);
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list); xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
mlir_flags = new MlirCommonFlags;
if (!enable_mlir_bridge_is_explicit) {
mlir_flags->tf_mlir_enable_mlir_bridge =
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
} else if (enable_mlir_bridge) {
mlir_flags->tf_mlir_enable_mlir_bridge =
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
} else {
mlir_flags->tf_mlir_enable_mlir_bridge =
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
}
} }
} // namespace } // namespace
View File
@ -19,6 +19,7 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow { namespace tensorflow {
@ -135,7 +136,7 @@ struct IntroduceFloatingPointJitterPassFlags {
// Flags for common MLIR configurations. // Flags for common MLIR configurations.
struct MlirCommonFlags { struct MlirCommonFlags {
bool tf_mlir_enable_mlir_bridge; ConfigProto::Experimental::MlirBridgeRollout tf_mlir_enable_mlir_bridge;
}; };
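Since the field changes from a bool to a tri-state rollout enum, a call site that used to test the flag directly now has to handle the UNSPECIFIED case as well. The helper below is a hypothetical sketch, assuming the GetMlirCommonFlags() accessor already used in this commit; only the enum values come from the diff.

    // Hypothetical helper: decide whether to use the MLIR bridge given the
    // tri-state flag and a fallback decision from the default per-graph policy.
    bool ShouldUseMlirBridge(bool default_policy_says_yes) {
      switch (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
        case ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED:
          return true;   // user explicitly opted in
        case ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED:
          return false;  // user explicitly opted out
        case ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED:
        default:
          return default_policy_says_yes;  // no explicit request; use the default
      }
    }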
// Return a pointer to the DumpGraphFlags struct; // Return a pointer to the DumpGraphFlags struct;
View File
@ -30,12 +30,12 @@ limitations under the License.
#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h" #include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/util.h"
@ -1801,11 +1801,11 @@ absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable() {
"Range", "Rank", "Reshape", "Shape", "ShapeN", "Size", "Squeeze", "Range", "Rank", "Reshape", "Shape", "ShapeN", "Size", "Squeeze",
"Transpose", "ZerosLike", "OnesLike", "BiasAdd" /*PW + Broadcast*/, "Transpose", "ZerosLike", "OnesLike", "BiasAdd" /*PW + Broadcast*/,
"BroadcastArgs", "BroadcastGradientArgs", "OneHot", "Concat", "ConcatV2", "BroadcastArgs", "BroadcastGradientArgs", "OneHot", "Concat", "ConcatV2",
"ConcatOffset", "Const", "MirrorPad", "Pack", "Pad", "PadV2", "Reverse", "ConcatOffset", "Const", "MirrorPad", "MirrorPadGrad", "Pack", "Pad",
"ReverseV2", "ReverseSequence", "Slice", "Split", "SplitV", "PadV2", "Reverse", "ReverseV2", "ReverseSequence", "Slice", "Split",
"StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign", "SplitV", "StridedSlice", "StridedSliceGrad",
"Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex", "ResourceStridedSliceAssign", "Tile", "Transpose", "InvertPermutation",
"TensorStridedSliceUpdate", "Unpack", "DeviceIndex", "TensorStridedSliceUpdate",
}}}; }}};
// clang-format on // clang-format on
return result; return result;
@ -2061,6 +2061,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"XlaSelfAdjointEig", "XlaSelfAdjointEig",
"XlaSend", "XlaSend",
"XlaSetBound", "XlaSetBound",
"XlaSetDynamicDimensionSize",
"XlaSharding", "XlaSharding",
"XlaSort", "XlaSort",
"XlaSpmdFullToShardShape", "XlaSpmdFullToShardShape",
View File
@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ #ifndef TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_
#define TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ #define TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph.h"
View File
@ -21,8 +21,8 @@ limitations under the License.
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h" #include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/algorithm.h"
View File
@ -17,6 +17,7 @@ limitations under the License.
#include <numeric> #include <numeric>
#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
#include "absl/base/call_once.h" #include "absl/base/call_once.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
@ -173,7 +174,8 @@ Status XlaCompilationCache::BuildExecutable(
build_options.set_result_layout(result.xla_output_shape); build_options.set_result_layout(result.xla_output_shape);
build_options.set_device_allocator(options.device_allocator); build_options.set_device_allocator(options.device_allocator);
build_options.set_alias_passthrough_params(options.alias_passthrough_params); build_options.set_alias_passthrough_params(options.alias_passthrough_params);
build_options.mutable_debug_options()->set_xla_detailed_logging(
options.detailed_logging);
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
auto executables, auto executables,
client_->Compile(*result.computation, argument_layouts, build_options)); client_->Compile(*result.computation, argument_layouts, build_options));
@ -283,14 +285,12 @@ Status XlaCompilationCache::CompileSingleOp(
const NodeDef& node_def = ctx->op_kernel().def(); const NodeDef& node_def = ctx->op_kernel().def();
TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes)); TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
// TODO(b/155596779): Support TensorList args.
bool has_tensor_list_arg =
absl::c_any_of(args, [](const XlaCompiler::Argument arg) {
return arg.kind == XlaCompiler::Argument::kTensorList;
});
const ConfigProto* config = ctx->function_library()->config_proto(); const ConfigProto* config = ctx->function_library()->config_proto();
bool use_mlir = config && config->experimental().enable_mlir_bridge() && // TODO(b/171039585): Support tf.VarIsInitializedOp using MLIR.
!has_tensor_list_arg; bool use_mlir = config &&
GetMlirBridgeRolloutPolicy(*config) ==
MlirBridgeRolloutPolicy::kEnabledByUser &&
node_def.op() != "VarIsInitializedOp";
#ifdef LIBTPU_ON_GCE #ifdef LIBTPU_ON_GCE
if (use_mlir) { if (use_mlir) {
LOG(WARNING) << "MLIR is not supported in this environment."; LOG(WARNING) << "MLIR is not supported in this environment.";
View File
@ -78,7 +78,9 @@ TEST(XlaCompilationCacheTest, TestDisabledXlaCompilation) {
absl::StrContains(status.error_message(), "XLA compilation disabled")); absl::StrContains(status.error_message(), "XLA compilation disabled"));
} }
static void BM_BuildSignature(int iters, int n_args) { void BM_BuildSignature(::testing::benchmark::State& state) {
const int n_args = state.range(0);
NameAttrList fn; NameAttrList fn;
fn.set_name("afunction"); fn.set_name("afunction");
for (int i = 0; i < n_args; i++) { for (int i = 0; i < n_args; i++) {
@ -93,7 +95,7 @@ static void BM_BuildSignature(int iters, int n_args) {
args[i].constant_value = Tensor(DT_INT32, {4, 0}); args[i].constant_value = Tensor(DT_INT32, {4, 0});
} }
while (--iters > 0) { for (auto i : state) {
xla::StatusOr<XlaCompilationCache::Signature> s = xla::StatusOr<XlaCompilationCache::Signature> s =
XlaCompilationCache::BuildSignature(fn, args); XlaCompilationCache::BuildSignature(fn, args);
CHECK(s.ok()); CHECK(s.ok());
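With the benchmark now taking ::testing::benchmark::State, the argument count comes from state.range(0), so the benchmark is registered with explicit argument values instead of a second function parameter. The registration line is outside this hunk; a typical form, assuming the standard BENCHMARK macro with Arg() and illustrative argument values, would be:

    BENCHMARK(BM_BuildSignature)->Arg(0)->Arg(1)->Arg(2)->Arg(5)->Arg(10);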
View File
@ -132,7 +132,8 @@ Status XlaCompileOnDemandOp::Compile(
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info_, platform_info_,
/*has_ref_vars=*/true, &tf_allocator_adapter); /*has_ref_vars=*/true, &tf_allocator_adapter);
// No detailed logging from on demand op.
options.detailed_logging = false;
XlaCompiler::CompileOptions compile_options; XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true; compile_options.is_entry_computation = true;
// Optimization: where possible, have the computation return a naked array // Optimization: where possible, have the computation return a naked array
View File
@ -398,7 +398,7 @@ static void ShowXlaDeviceDeprecationWarning(
absl::call_once(once, [] { absl::call_once(once, [] {
LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be " LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be "
"removed in subsequent releases. Instead, use either " "removed in subsequent releases. Instead, use either "
"@tf.function(experimental_compile=True) for must-compile " "@tf.function(jit_compile=True) for must-compile "
"semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 " "semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
"for auto-clustering best-effort compilation."; "for auto-clustering best-effort compilation.";
}); });
View File
@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator.h" #include "tensorflow/compiler/jit/xla_kernel_creator.h"
#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
@ -89,7 +90,9 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
XlaOpRegistry::RegisterCompilationKernels(); XlaOpRegistry::RegisterCompilationKernels();
// Only check for compilability if the MLIR bridge is not enabled. // Only check for compilability if the MLIR bridge is not enabled.
if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(absl::nullopt);
if (policy == MlirBridgeRolloutPolicy::kDisabledByUser ||
policy == MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis) {
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map; RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) { if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo> std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
View File
@ -426,7 +426,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
ShapedBuffer buffer( ShapedBuffer buffer(
xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}), xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}),
xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}), xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}),
output.platform(), output.device_ordinal()); output.device_ordinal());
buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(), buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(),
/*source_base_index=*/{}, /*source_base_index=*/{},
/*target_base_index=*/{0}); /*target_base_index=*/{0});
@ -583,7 +583,11 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
XlaCompiler::Argument& arg = out[input_num]; XlaCompiler::Argument& arg = out[input_num];
if (absl::c_binary_search(must_be_constant_idxs, input_num)) { if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
// Handles compile-time constants. // Handles compile-time constants.
TF_RET_CHECK(input->dtype() != DT_RESOURCE);
// TODO(b/157241314): Support constants located in resource variables.
TF_RET_CHECK(input->dtype() != DT_RESOURCE)
<< "tf2xla bridge does not support must-be-constants located in "
"resource variables; try moving them to a tensor";
arg.kind = XlaCompiler::Argument::kConstant; arg.kind = XlaCompiler::Argument::kConstant;
arg.type = input->dtype(); arg.type = input->dtype();
arg.shape = input->shape(); arg.shape = input->shape();

View File

@ -191,6 +191,18 @@ tf_cc_binary(
], ],
) )
cc_library(
name = "mlir_bridge_rollout_policy",
srcs = ["mlir_bridge_rollout_policy.cc"],
hdrs = ["mlir_bridge_rollout_policy.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/jit:flags",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:optional",
],
)
filegroup( filegroup(
name = "litfiles", name = "litfiles",
srcs = glob(["runlit*py"]), srcs = glob(["runlit*py"]),

View File

@ -512,10 +512,20 @@ cc_library(
":lhlo", ":lhlo",
":map_hlo_to_lhlo_op", ":map_hlo_to_lhlo_op",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
], ],
) )
cc_library(
name = "map_chlo_to_hlo_op",
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"],
deps = [
":hlo",
"@llvm-project//mlir:IR",
],
)
cc_library( cc_library(
name = "map_hlo_to_lhlo_op", name = "map_hlo_to_lhlo_op",
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"], hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"],
@ -537,6 +547,7 @@ cc_library(
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:TransformUtils",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -557,22 +568,6 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "lhlo_legalize_to_llvm",
srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc"],
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"],
deps = [
":lhlo",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:LLVMTransforms",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
)
cc_library( cc_library(
name = "legalize_to_linalg", name = "legalize_to_linalg",
srcs = ["lib/Dialect/mhlo/transforms/legalize_to_linalg.cc"], srcs = ["lib/Dialect/mhlo/transforms/legalize_to_linalg.cc"],
@ -604,9 +599,11 @@ cc_library(
], ],
deps = [ deps = [
":hlo", ":hlo",
":map_chlo_to_hlo_op",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Shape", "@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms", "@llvm-project//mlir:Transforms",
@ -649,6 +646,7 @@ cc_library(
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:ViewLikeInterface",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -762,6 +760,7 @@ cc_library(
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -889,6 +888,7 @@ cc_library(
deps = [ deps = [
":chlo_legalize_to_hlo_inc_gen", ":chlo_legalize_to_hlo_inc_gen",
":hlo", ":hlo",
":map_chlo_to_hlo_op",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Shape", "@llvm-project//mlir:Shape",
@ -936,7 +936,6 @@ cc_library(
srcs = [ srcs = [
"include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h", "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
"lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc", "lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc",
"lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc",
"lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc", "lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc",
"lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc", "lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc",
"lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc", "lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc",
@ -946,7 +945,6 @@ cc_library(
":chlo_legalize_to_hlo", # build-cleaner: keep ":chlo_legalize_to_hlo", # build-cleaner: keep
":hlo", ":hlo",
":lhlo", ":lhlo",
":lhlo_legalize_to_llvm", # build-cleaner: keep
":materialize_broadcasts", # build-cleaner: keep ":materialize_broadcasts", # build-cleaner: keep
":pass_details", ":pass_details",
":unfuse_batch_norm", # build-cleaner: keep ":unfuse_batch_norm", # build-cleaner: keep

View File

@ -41,6 +41,8 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
# Options and settings # Options and settings
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
option(MHLO_BUILD_EMBEDDED "Build MHLO as part of another project" OFF)
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
# MSVC defaults # MSVC defaults
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
@ -57,11 +59,16 @@ endif()
# MLIR/LLVM Configuration # MLIR/LLVM Configuration
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
find_package(MLIR REQUIRED CONFIG) # Find MLIR to install if we are building standalone. If building as part of
message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") # another project, let it handle the MLIR dependency. The dependent project
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") # might use a bundled version of MLIR instead of installing, for instance.
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") if(NOT MHLO_BUILD_EMBEDDED)
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") find_package(MLIR REQUIRED CONFIG)
message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}")
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
endif()
if(LLVM_ENABLE_ZLIB) if(LLVM_ENABLE_ZLIB)
find_package(ZLIB) find_package(ZLIB)

View File

@ -22,7 +22,7 @@ upstream.
## QuickStart: building and testing ## QuickStart: building and testing
These instructions work on Linux, you may have to adjust for your plaform. These instructions work on Linux, you may have to adjust for your platform.
To build the code in this repository, you need a clone of the LLVM/MLIR git To build the code in this repository, you need a clone of the LLVM/MLIR git
repository: repository:

View File

@ -89,10 +89,9 @@ class HLOClient_BroadcastBinaryElementwiseOp<
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
); );
let builders = [OpBuilder< let builders = [
"OpBuilder &builder, OperationState &result, Value left, Value right, " OpBuilderDAG<(ins "Value":$left, "Value":$right,
"DenseIntElementsAttr broadcast_dimensions" "DenseIntElementsAttr":$broadcast_dimensions)>];
>];
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);
@ -427,7 +426,10 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
string summary = "Compare operator (with optional broadcasting)"; string summary = "Compare operator (with optional broadcasting)";
string description = [{ string description = [{
Compares `lhs` and `rhs` elementwise according to `comparison_direction`. Compares `lhs` and `rhs` elementwise according to `comparison_direction`
and `compare_type`. If unspecified, `compare_type` is FLOAT for float element
types, SIGNED for signed element types and UNSIGNED for unsigned element
types.
See See
https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
@ -437,14 +439,15 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
HLO_Tensor:$lhs, HLO_Tensor:$lhs,
HLO_Tensor:$rhs, HLO_Tensor:$rhs,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions, OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
HLO_ComparisonDirectionAttr:$comparison_direction HLO_ComparisonDirectionAttr:$comparison_direction,
OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
); );
let results = (outs HLO_PredTensor); let results = (outs HLO_PredTensor);
let builders = [OpBuilder< let builders = [
"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " OpBuilderDAG<(ins "Value":$lhs, "Value":$rhs,
"DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction" "DenseIntElementsAttr":$broadcast_dimensions,
>]; "StringAttr":$comparison_direction, CArg<"StringAttr", "{}">:$compare_type)>];
} }
#endif // CHLO_OPS #endif // CHLO_OPS
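The description above states the default for `compare_type` when it is left unspecified: FLOAT for floating-point element types, SIGNED for signed integers, UNSIGNED for unsigned integers. As a minimal illustration of that selection rule only (the enum and function below are hypothetical helpers, not part of the dialect or of this diff):

```cpp
#include <cassert>
#include <string>

// Illustrative only: mirrors the documented default for compare_type when the
// attribute is omitted on a (broadcast) compare op.
enum class ElementKind { kFloat, kSignedInt, kUnsignedInt };

std::string DefaultCompareType(ElementKind kind) {
  switch (kind) {
    case ElementKind::kFloat:       return "FLOAT";
    case ElementKind::kUnsignedInt: return "UNSIGNED";
    case ElementKind::kSignedInt:   return "SIGNED";
  }
  return "SIGNED";  // unreachable; placates compilers without exhaustive-switch checks
}

int main() {
  assert(DefaultCompareType(ElementKind::kFloat) == "FLOAT");
  assert(DefaultCompareType(ElementKind::kSignedInt) == "SIGNED");
  assert(DefaultCompareType(ElementKind::kUnsignedInt) == "UNSIGNED");
}
```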

View File

@ -58,9 +58,8 @@ def HLO_ConstOp : HLO_Op<"constant",
HLO_StaticShapeTensor:$output HLO_StaticShapeTensor:$output
); );
let builders = [OpBuilder< let builders = [
"OpBuilder &builder, OperationState &result, Attribute value" OpBuilderDAG<(ins "Attribute":$value)>];
>];
let assemblyFormat = "attr-dict $value"; let assemblyFormat = "attr-dict $value";
@ -147,9 +146,8 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
[NoSideEffect, SameOperandsAndResultShape], [NoSideEffect, SameOperandsAndResultShape],
TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp { TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp {
let builders = [OpBuilder< let builders = [
"Value operand" OpBuilderDAG<(ins "Value":$operand)>];
>];
} }
def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt", def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt",
@ -162,9 +160,8 @@ def HLO_ConvertOp : HLO_UnaryElementwiseOp<
"convert", [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>, "convert", [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>,
BASE_HLO_ConvertOp { BASE_HLO_ConvertOp {
let builders = [OpBuilder< let builders = [
"Value operand, Type result_element_ty" OpBuilderDAG<(ins "Value":$operand, "Type":$result_element_ty)>];
>];
let hasFolder = 1; let hasFolder = 1;
@ -241,7 +238,9 @@ def HLO_RealOp: HLO_UnaryElementwiseOp<"real",
} }
def HLO_RoundOp: HLO_UnaryElementwiseOp<"round_nearest_afz", def HLO_RoundOp: HLO_UnaryElementwiseOp<"round_nearest_afz",
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp; [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp {
let hasFolder = 1;
}
def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt", def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt",
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
@ -616,10 +615,9 @@ def HLO_ReduceOp: HLO_Op<"reduce", [
let results = (outs Variadic<HLO_TensorOrTuple>); let results = (outs Variadic<HLO_TensorOrTuple>);
let builders = [OpBuilder< let builders = [
"OpBuilder &, OperationState &state, ValueRange operands, " OpBuilderDAG<(ins "ValueRange":$operands, "ValueRange":$init_values,
"ValueRange init_values, DenseIntElementsAttr dimensions" "DenseIntElementsAttr":$dimensions)>];
>];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
bool isFusibleWithConsumer() { bool isFusibleWithConsumer() {
@ -655,18 +653,16 @@ def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO
let hasFolder = 1; let hasFolder = 1;
let builders = [OpBuilder< let builders = [
"OpBuilder &builder, OperationState &results, " OpBuilderDAG<(ins "Value":$value, "int32_t":$index)>];
"Value value, int32_t index">];
} }
def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
let arguments = (ins Variadic<HLO_TensorOrTokenOrTuple>:$val); let arguments = (ins Variadic<HLO_TensorOrTokenOrTuple>:$val);
let results = (outs HLO_Tuple); let results = (outs HLO_Tuple);
let builders = [OpBuilder< let builders = [
"OpBuilder &builder, OperationState &results, " OpBuilderDAG<(ins "ValueRange":$values)>];
"ValueRange values">];
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
} }
@ -678,16 +674,19 @@ def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands,
let arguments = (ins let arguments = (ins
HLO_Tensor:$lhs, HLO_Tensor:$lhs,
HLO_Tensor:$rhs, HLO_Tensor:$rhs,
HLO_ComparisonDirectionAttr:$comparison_direction HLO_ComparisonDirectionAttr:$comparison_direction,
OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
); );
let results = (outs HLO_PredTensor); let results = (outs HLO_PredTensor);
let hasFolder = 1; let hasFolder = 1;
let builders = [OpBuilder< let builders = [
"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " OpBuilderDAG<(ins "Value":$lhs, "Value":$rhs,
"StringAttr comparison_direction" "StringAttr":$comparison_direction, CArg<"StringAttr", "{}">:$compare_type)>,
>]; ];
let hasCustomHLOConverter = 1;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -697,7 +696,8 @@ def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands,
def HLO_SliceOp: HLO_Op< def HLO_SliceOp: HLO_Op<
"slice", "slice",
[NoSideEffect, SameOperandsAndResultElementType, [NoSideEffect, SameOperandsAndResultElementType,
AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> { AllTypesMatch<["start_indices", "limit_indices", "strides"]>,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let arguments = (ins let arguments = (ins
HLO_Tensor:$operand, HLO_Tensor:$operand,
I64ElementsAttr:$start_indices, I64ElementsAttr:$start_indices,
@ -709,21 +709,6 @@ def HLO_SliceOp: HLO_Op<
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
let hasFolder = 1; let hasFolder = 1;
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value operand, "
"DenseIntElementsAttr start_indices, DenseIntElementsAttr limit_indices, "
"DenseIntElementsAttr strides"
>];
let extraClassDeclaration = [{
// Infers output type for given operand and attributes. Result type is
// unranked if any of the attributes is illegal.
static Type InferOutputTypes(Builder *builder, Value operand,
DenseIntElementsAttr start_indices,
DenseIntElementsAttr limit_indices,
DenseIntElementsAttr strides);
}];
} }
def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice", def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice",
@ -1032,7 +1017,7 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>,
BASE_HLO_GetDimensionSizeOp { BASE_HLO_GetDimensionSizeOp {
let arguments = (ins let arguments = (ins
HLO_Tensor:$operand, HLO_Tensor:$operand,
I32Attr:$dimension I64Attr:$dimension
); );
// TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the // TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the
// XLA semantics is available. This limitation is because of the current XLA // XLA semantics is available. This limitation is because of the current XLA
@ -1144,9 +1129,11 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>,
let arguments = (ins let arguments = (ins
HLO_Tensor:$operand, HLO_Tensor:$operand,
I32Tensor:$size, I32Tensor:$size,
I32Attr:$dimension I64Attr:$dimension
); );
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);
let hasFolder = 1;
} }
def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp { def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp {
@ -1160,10 +1147,9 @@ def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShap
let regions = (region SizedRegion<1>:$comparator); let regions = (region SizedRegion<1>:$comparator);
let builders = [OpBuilder< let builders = [
"OpBuilder &builder, OperationState &state, ValueRange operands, " OpBuilderDAG<(ins "ValueRange":$operands, CArg<"int64_t", "-1">:$dimension,
"int64_t dimension = -1, bool is_stable = false" CArg<"bool", "false">:$is_stable)>];
>];
// TODO(b/129422361): SortOp has special conversion logic to HLO. // TODO(b/129422361): SortOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1; let hasCustomHLOConverter = 1;
@ -1199,6 +1185,8 @@ def HLO_PadOp: HLO_Op<"pad",
// TODO(b/129422361): PadOp has a custom constructor for HLO. // TODO(b/129422361): PadOp has a custom constructor for HLO.
let hasCustomHLOConverter = 1; let hasCustomHLOConverter = 1;
let hasFolder = 1;
} }
def HLO_TraceOp: HLO_Op<"trace", []>, BASE_HLO_TraceOp { def HLO_TraceOp: HLO_Op<"trace", []>, BASE_HLO_TraceOp {

View File

@ -749,11 +749,30 @@ def HLO_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection",
HLO_COMPARISON_DIRECTION_LT HLO_COMPARISON_DIRECTION_LT
]>; ]>;
def HLO_DEFAULT_COMPARISON_TYPE : NativeCodeCall<"StringAttr()">;
def HLO_COMPARISON_TYPE_FLOAT : StrEnumAttrCase<"FLOAT">;
def HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER : StrEnumAttrCase<"TOTALORDER">;
def HLO_COMPARISON_TYPE_SIGNED : StrEnumAttrCase<"SIGNED">;
def HLO_COMPARISON_TYPE_UNSIGNED : StrEnumAttrCase<"UNSIGNED">;
def HLO_ComparisonTypeAttr : StrEnumAttr<"ComparisonType",
"Which comparison type to use.",
[
HLO_COMPARISON_TYPE_FLOAT,
HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER,
HLO_COMPARISON_TYPE_SIGNED,
HLO_COMPARISON_TYPE_UNSIGNED
]>;
class BASE_HLO_CompareOp { class BASE_HLO_CompareOp {
string summary = "Comparison operator"; string summary = "Comparison operator";
string description = [{ string description = [{
Compares `lhs` and `rhs` elementwise according to `comparison_direction`. Compares `lhs` and `rhs` elementwise according to `comparison_direction`
and `compare_type`. If unspecified, `compare_type` is FLOAT for float element
types, SIGNED for signed element types and UNSIGNED for unsigned element
types.
See See
https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
@ -1363,7 +1382,7 @@ class BASE_HLO_BitcastOp {
string description = [{ string description = [{
This op changes the shape of the input in the way that the physical This op changes the shape of the input in the way that the physical
arranggment of elements are unchanged. arrangement of elements are unchanged.
However, the op needs layout information to make sense of "physical However, the op needs layout information to make sense of "physical
arrangement of elements". Layout support in MHLO is currently under arrangement of elements". Layout support in MHLO is currently under

View File

@ -284,7 +284,8 @@ def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp {
Arg<LHLO_Buffer, "", [MemRead]>:$rhs, Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_PredBuffer, "", [MemWrite]>:$out, Arg<LHLO_PredBuffer, "", [MemWrite]>:$out,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions, OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
HLO_ComparisonDirectionAttr:$comparison_direction HLO_ComparisonDirectionAttr:$comparison_direction,
OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
); );
} }
@ -313,167 +314,6 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
); );
} }
//===----------------------------------------------------------------------===//
// StaticMemRefCastOp
//===----------------------------------------------------------------------===//
def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
[NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
let summary = [{
modifies the offset, sizes and strides of a statically shaped memref
}];
let description = [{
Casts the statically shaped memref operand to a memref with optionally
modified offsets, sizes and strides.
Example:
```mlir
%buf_transformed =
lmhlo.static_memref_cast %buf
: memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]>
// The result of the op is a rank-1 memref with `[5]` shape, stride 1 and
// offset 2.
```
}];
let arguments = (ins Arg<LHLO_Buffer, "", []>:$operand);
let results = (outs Res<LHLO_Buffer, "", []>:$result);
let builders = [OpBuilder<"MemRefType resultType, Value operand",
[{
$_state.addOperands(operand);
$_state.types.push_back(resultType);
}]>];
let extraClassDeclaration = [{
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
}];
let verifier = [{ return Verify(*this); }];
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result)
}];
}
//===----------------------------------------------------------------------===//
// DynamicMemRefCastOp
//===----------------------------------------------------------------------===//
def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
[SameVariadicOperandSize, NoSideEffect,
DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
let summary = "dynamic memref cast operation";
let description = [{
Change sizes and strides of a memref using the values computed in runtime.
Example:
```mlir
%buf_transformed =
lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
// The result of the op is a type-erased memref with `[%size_X, %size_Y]`
// shape and `[%step_X, %step_Y]` strides. The offset will be inherited
// from the input.
```
}];
let arguments = (ins
Arg<LHLO_Buffer, "", []>:$operand,
Variadic<Index>:$sizes,
Variadic<Index>:$strides
);
let results = (outs Res<LHLO_Buffer, "", []>:$result);
let builders = [
OpBuilder<"MemRefType resultType, Value operand, ValueRange sizes, "
"ValueRange strides", [{
$_state.addOperands(operand);
$_state.addOperands(sizes);
$_state.addOperands(strides);
$_state.types.push_back(resultType);
}]>];
let extraClassDeclaration = [{
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
}];
let verifier = [{ return Verify(*this); }];
let assemblyFormat = [{
$operand `(` $sizes `)` `[` $strides `]` attr-dict `:` type($operand) `->`
type($result)
}];
}
//===----------------------------------------------------------------------===//
// ReshapeMemRefCastOp
//===----------------------------------------------------------------------===//
def ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
NoSideEffect]> {
let summary = "reshape memref cast operation";
let description = [{
The `reshape_memref_cast` operation converts a memref from one type to an
equivalent type with a provided shape. The data is never copied or moved.
The source and destination types are compatible if both have the same
element type, address space and identity layout map. The following
combinations are possible:
a. Both are ranked memref types.
```mlir
// Reshape statically-shaped memref.
%dst = reshape_memref_cast %src(%shape)
: (memref<4x1xf32>, memref<1xi32>) to memref<4xf32>
%dst0 = reshape_memref_cast %src(%shape0)
: (memref<4x1xf32>, memref<2xi32>) to memref<2x2xf32>
```
b. Source type is ranked, destination type is unranked.
```mlir
// Reshape dynamically-shaped 1D memref.
%dst = reshape_memref_cast %src(%shape)
: (memref<?xf32>, memref<?xi32>) to memref<*xf32>
```
c. Source type is unranked, destination type is ranked.
```mlir
// Flatten unranked memref.
%dst = reshape_memref_cast %src(%shape)
: (memref<*xf32>, memref<1xi32>) to memref<?xf32>
```
d. Both are unranked memref types.
```mlir
// Reshape unranked memref.
%dst = reshape_memref_cast %src(%shape)
: (memref<*xf32>, memref<?xi32>) to memref<*xf32>
```
}];
let arguments = (ins
AnyRankedOrUnrankedMemRef:$operand,
LHLO_ExtentBuffer:$shape
);
let results = (outs AnyRankedOrUnrankedMemRef:$result);
let extraClassDeclaration = [{
BaseMemRefType getType() {
return getResult().getType().cast<BaseMemRefType>(); }
}];
let verifier = [{ return Verify(*this); }];
let assemblyFormat = [{
$operand `(` $shape `)` attr-dict `:` `(` type($operand) `,` type($shape)
`)` `->` type($result)
}];
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// LMHLO Other op definitions. // LMHLO Other op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -602,11 +442,8 @@ def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp {
let arguments = (ins let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand, Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices, Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices,
I64Attr:$index_vector_dim, GatherDimensionNumbers:$dimension_numbers,
I64ElementsAttr:$offset_dims,
I64ElementsAttr:$slice_sizes, I64ElementsAttr:$slice_sizes,
I64ElementsAttr:$collapsed_slice_dims,
I64ElementsAttr:$start_index_map,
Arg<LHLO_Buffer, "", [MemWrite]>:$output Arg<LHLO_Buffer, "", [MemWrite]>:$output
); );
} }
@ -815,7 +652,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
let skipDefaultBuilders = 1; let skipDefaultBuilders = 1;
let builders = [ let builders = [
OpBuilder<"ArrayRef<NamedAttribute> attributes"> OpBuilderDAG<(ins "ArrayRef<NamedAttribute>":$attributes)>
]; ];
} }
@ -825,9 +662,9 @@ def TerminatorOp :
let description = [{ let description = [{
Terminator operation for the LHLO dialect. Terminator operation for the LHLO dialect.
}]; }];
let builders = [OpBuilder<"ValueRange operands", let builders = [
[{ build($_builder, $_state, llvm::None, operands, llvm::None); }] OpBuilderDAG<(ins "ValueRange":$operands),
>]; [{ build($_builder, $_state, llvm::None, operands, llvm::None); }]>];
} }
#endif // LHLO_OPS #endif // LHLO_OPS

View File

@ -46,12 +46,6 @@ def LhloLegalizeToGpuPass : Pass<"lhlo-legalize-to-gpu", "FuncOp"> {
} }
def TestLhloToLLVMPass : Pass<"test-lhlo-legalize-to-llvm", "FuncOp"> {
let summary = "Legalize from LHLO dialect to LLVM.";
let constructor = "createTestLhloToLLVMPass()";
}
def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "FuncOp"> { def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "FuncOp"> {
let summary = "Legalize from LHLO dialect to parallel loops."; let summary = "Legalize from LHLO dialect to parallel loops.";
let constructor = "createLegalizeLhloToParallelLoopsPass()"; let constructor = "createLegalizeLhloToParallelLoopsPass()";

View File

@ -0,0 +1,97 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H_
#include <type_traits>
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
namespace chlo {
struct HloComplexAdaptor {
static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs);
}
};
template <typename FromOpTy, typename ToOpTy>
struct HloBinaryElementwiseAdaptor {
static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<ToOpTy>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs);
}
};
struct HloCompareAdaptor {
static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<mhlo::CompareOp>(
from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
from_op.comparison_direction(), from_op.compare_typeAttr());
}
};
// Populate a pattern for each Broadcasting CHlo op. This requires the pattern
// to take a ChloOpTy, MhloOpTy, and an Adaptor as templated values.
template <template <typename, typename, typename> class Pattern,
typename... ConstructorArgs>
void PopulateForBroadcastingBinaryOp(MLIRContext *context,
OwningRewritePatternList *patterns,
ConstructorArgs &&...args) {
#define POPULATE_BCAST(ChloOp, HloOp) \
patterns->insert< \
Pattern<ChloOp, HloOp, HloBinaryElementwiseAdaptor<ChloOp, HloOp>>>( \
context, args...);
POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);
// Broadcasting ops requiring special construction.
patterns
->insert<Pattern<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>>(
context, args...);
patterns
->insert<Pattern<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>>(
context, args...);
#undef POPULATE_BCAST
}
} // namespace chlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H_
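To make the comment's contract concrete ("the pattern takes a ChloOpTy, MhloOpTy, and an Adaptor as templated values"), here is a hedged sketch of a consumer of `PopulateForBroadcastingBinaryOp`. The pattern class and registration helper are hypothetical and deliberately naive; they are not the patterns used by the actual CHLO-to-HLO lowering.

```cpp
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace chlo {

// Hypothetical pattern: handles only the degenerate case where lhs and rhs
// already have the same type, so no broadcast is needed before delegating to
// the adaptor's CreateOp.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct NaiveBroadcastingOpPattern : public OpRewritePattern<ChloOpTy> {
  using OpRewritePattern<ChloOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(ChloOpTy op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.lhs();
    Value rhs = op.rhs();
    if (lhs.getType() != rhs.getType()) return failure();
    Value replacement =
        Adaptor::CreateOp(op, op.getResult().getType(), lhs, rhs, rewriter);
    rewriter.replaceOp(op, replacement);
    return success();
  }
};

// Registers one instance of the hypothetical pattern per broadcasting CHLO op.
inline void AddNaiveBroadcastingPatterns(MLIRContext *context,
                                         OwningRewritePatternList *patterns) {
  PopulateForBroadcastingBinaryOp<NaiveBroadcastingOpPattern>(context,
                                                              patterns);
}

}  // namespace chlo
}  // namespace mlir
```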

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/TypeUtilities.h"
namespace mlir { namespace mlir {
namespace lmhlo { namespace lmhlo {
@ -96,7 +97,7 @@ template <typename SupportedType, typename StdScalarOp, typename... Args>
struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> { struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
Value operator()(Location loc, ArrayRef<Type> result_types, Value operator()(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) { ArrayRef<Value> args, OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = getElementTypeOrSelf(args.front().getType());
if (element_type.isa<SupportedType>()) { if (element_type.isa<SupportedType>()) {
return b->template create<StdScalarOp>(loc, result_types, args, return b->template create<StdScalarOp>(loc, result_types, args,
mlir::None); mlir::None);
@ -120,7 +121,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
ArrayRef<Type> result_types, ArrayRef<Type> result_types,
ArrayRef<Value> args, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = getElementTypeOrSelf(args.front().getType());
if (element_type.isa<FloatType>()) { if (element_type.isa<FloatType>()) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}( return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
loc, result_types, args, b); loc, result_types, args, b);
@ -130,8 +131,11 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
Value lhs = args[0]; Value lhs = args[0];
auto integer_type = element_type.dyn_cast<IntegerType>(); auto integer_type = element_type.dyn_cast<IntegerType>();
auto zero_intval = Value zero_intval =
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval);
}
auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge, auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge,
lhs, zero_intval); lhs, zero_intval);
auto neg_val = b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs); auto neg_val = b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
@ -196,7 +200,7 @@ inline Value MapCompareOpToStdScalarOp(Location loc,
ArrayRef<Value> args, OpBuilder* b) { ArrayRef<Value> args, OpBuilder* b) {
const auto& lhs = args[0]; const auto& lhs = args[0];
const auto& rhs = args[1]; const auto& rhs = args[1];
Type element_type = lhs.getType(); Type element_type = getElementTypeOrSelf(lhs.getType());
if (element_type.isSignlessInteger()) { if (element_type.isSignlessInteger()) {
Optional<CmpIPredicate> predicate = Optional<CmpIPredicate> predicate =
getCmpPredicate<CmpIPredicate>(comparison_direction); getCmpPredicate<CmpIPredicate>(comparison_direction);
@ -268,8 +272,8 @@ template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type sourceType = args.front().getType(); Type sourceType = getElementTypeOrSelf(args.front().getType());
Type targetType = result_types.front(); Type targetType = getElementTypeOrSelf(result_types.front());
if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) { if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) {
return b->create<mlir::SIToFPOp>(loc, result_types, args, mlir::None); return b->create<mlir::SIToFPOp>(loc, result_types, args, mlir::None);
@ -390,7 +394,7 @@ struct CompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
static Value map(Location loc, StringRef comparison_direction, static Value map(Location loc, StringRef comparison_direction,
ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = getElementTypeOrSelf(args.front().getType());
if (element_type.isa<SupportedType>()) { if (element_type.isa<SupportedType>()) {
auto predicate = getCmpPredicate<Predicate>(comparison_direction); auto predicate = getCmpPredicate<Predicate>(comparison_direction);
assert(predicate.hasValue() && "expected valid comparison direction"); assert(predicate.hasValue() && "expected valid comparison direction");
@ -439,7 +443,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
ArrayRef<Type> result_types, ArrayRef<Type> result_types,
ArrayRef<Value> args, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = getElementTypeOrSelf(args.front().getType());
if (element_type.isa<FloatType>()) { if (element_type.isa<FloatType>()) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}( return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
loc, result_types, args, b); loc, result_types, args, b);
@ -449,8 +453,11 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
Value lhs = args[0]; Value lhs = args[0];
auto integer_type = element_type.dyn_cast<IntegerType>(); auto integer_type = element_type.dyn_cast<IntegerType>();
auto zero_intval = Value zero_intval =
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval);
}
return b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs); return b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
} }
return nullptr; return nullptr;
@ -461,11 +468,14 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
ArrayRef<Type> result_types, ArrayRef<Type> result_types,
ArrayRef<Value> args, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = getElementTypeOrSelf(args.front().getType());
if (auto integer_type = element_type.dyn_cast<IntegerType>()) { if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
// lmhlo.not(x) -> x ^ -1 // lmhlo.not(x) -> x ^ -1
auto all_ones = Value all_ones =
b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth()); b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth());
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
all_ones = b->create<::mlir::SplatOp>(loc, vec_type, all_ones);
}
return b->create<::mlir::XOrOp>(loc, all_ones, args[0]); return b->create<::mlir::XOrOp>(loc, all_ones, args[0]);
} }
return nullptr; return nullptr;
@ -493,26 +503,35 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
ArrayRef<Type> result_types, ArrayRef<Type> result_types,
ArrayRef<Value> args, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = getElementTypeOrSelf(args.front().getType());
if (auto float_type = element_type.dyn_cast<FloatType>()) { if (auto float_type = element_type.dyn_cast<FloatType>()) {
bool ignored; bool ignored;
APFloat one_apfloat(1.0f); APFloat one_apfloat(1.0f);
one_apfloat.convert(float_type.getFloatSemantics(), one_apfloat.convert(float_type.getFloatSemantics(),
APFloat::rmNearestTiesToEven, &ignored); APFloat::rmNearestTiesToEven, &ignored);
Value one = b->create<mlir::ConstantFloatOp>(loc, one_apfloat, float_type); Value one = b->create<mlir::ConstantFloatOp>(loc, one_apfloat, float_type);
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
one = b->create<::mlir::SplatOp>(loc, vec_type, one);
}
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]); return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
} else if (auto integer_type = element_type.dyn_cast<IntegerType>()) { } else if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
// sign(x) = x == 0 ? 0 : ((x s>> 31) | 1) // sign(x) = x == 0 ? 0 : ((x s>> 31) | 1)
Value zero = Value zero =
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
Value cmp =
b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero);
Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>( Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>(
loc, integer_type.getWidth() - 1, integer_type.getWidth()); loc, integer_type.getWidth() - 1, integer_type.getWidth());
Value ashr =
b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one);
Value one = Value one =
b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth()); b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth());
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
zero = b->create<::mlir::SplatOp>(loc, vec_type, zero);
bitwidth_minus_one =
b->create<::mlir::SplatOp>(loc, vec_type, bitwidth_minus_one);
one = b->create<::mlir::SplatOp>(loc, vec_type, one);
}
Value cmp =
b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero);
Value ashr =
b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one);
Value or_op = b->create<::mlir::OrOp>(loc, ashr, one); Value or_op = b->create<::mlir::OrOp>(loc, ashr, one);
return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op); return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op);
} }
@ -583,6 +602,27 @@ struct HloOpToStdScalarOp {
return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>( return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
op.getLoc(), comparison_direction, result_types, args, b); op.getLoc(), comparison_direction, result_types, args, b);
} }
// Implementation for LHLO ops except lmhlo::CompareOp.
template <typename LhloOpTy,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
std::false_type>::value>>
static Value map(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(loc, result_types, args, b);
}
// Implementation for lmhlo::CompareOp.
template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
LhloOpTy, lmhlo::CompareOp>::value>>
static Value map(Location loc, StringRef comparison_direction,
ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
loc, comparison_direction, result_types, args, b);
}
}; };
} // namespace lmhlo } // namespace lmhlo
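As a quick sanity check of the integer identity used in the lmhlo::SignOp lowering above, sign(x) = x == 0 ? 0 : ((x s>> (bits-1)) | 1), here is a small standalone C++ snippet; it is illustrative only and not part of the lowering.

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the shift-and-or trick for 32-bit signed integers. The arithmetic
// right shift yields 0 for non-negative x and -1 (all ones) for negative x;
// OR-ing with 1 then produces +1 or -1, and the zero case is handled apart.
static int32_t SignViaShift(int32_t x) {
  int32_t ashr = x >> 31;  // arithmetic shift on signed int (implementation-
                           // defined pre-C++20, but arithmetic in practice)
  return x == 0 ? 0 : (ashr | 1);
}

int main() {
  assert(SignViaShift(-5) == -1);
  assert(SignViaShift(0) == 0);
  assert(SignViaShift(42) == 1);
}
```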

View File

@ -48,12 +48,8 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(); std::unique_ptr<FunctionPass> createChloLegalizeToHloPass();
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary /// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
/// buffers if necessary. If `results_escape_functions` is set to true, /// buffers if necessary.
/// allocated buffers for function results will be returned and escape the std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass();
/// function. Otherwise, the signature is rewritten with extra arguments for the
/// buffers that are to be used for results.
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
bool results_escape_functions = false);
// Lowers from HLO dialect to Linalg dialect. // Lowers from HLO dialect to Linalg dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass(); std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();

View File

@ -35,8 +35,6 @@ inline void registerAllMhloPasses() { registerMHLOPasses(); }
namespace lmhlo { namespace lmhlo {
std::unique_ptr<Pass> createTestLhloToLLVMPass();
#define GEN_PASS_REGISTRATION #define GEN_PASS_REGISTRATION
#include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc" #include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc"

View File

@ -24,10 +24,7 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
namespace mlir { namespace mlir {
class LLVMTypeConverter;
class LowerToLLVMOptions;
class OwningRewritePatternList; class OwningRewritePatternList;
class BufferAssignmentPlacer;
// Populates a collection of rewrite patterns to realize element-wise operations // Populates a collection of rewrite patterns to realize element-wise operations
// on ranked tensors where possible. // on ranked tensors where possible.
@ -95,14 +92,6 @@ void PopulateTrigonometricToApproximationPatterns(
} // namespace mhlo } // namespace mhlo
namespace lmhlo {
/// Collect a set of patterns to convert from the LHLO dialect to LLVM.
void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter,
OwningRewritePatternList *patterns);
} // namespace lmhlo
namespace chlo { namespace chlo {
// Populates a collection of conversion patterns for legalizing client-HLO to // Populates a collection of conversion patterns for legalizing client-HLO to

View File

@ -190,11 +190,12 @@ LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result, void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
Value lhs, Value rhs, Value lhs, Value rhs,
DenseIntElementsAttr broadcast_dimensions, DenseIntElementsAttr broadcast_dimensions,
StringAttr comparison_direction) { StringAttr comparison_direction,
StringAttr compare_type) {
auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(), auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(),
builder.getI1Type(), broadcast_dimensions); builder.getI1Type(), broadcast_dimensions);
build(builder, result, new_type, lhs, rhs, broadcast_dimensions, build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
comparison_direction); comparison_direction, compare_type);
} }
LogicalResult BroadcastCompareOp::inferReturnTypeComponents( LogicalResult BroadcastCompareOp::inferReturnTypeComponents(

View File

@ -86,6 +86,26 @@ namespace {
// Utilities for the canonicalize patterns // Utilities for the canonicalize patterns
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Verifies that dimension attribute for the op correctly indexes in operand or
// result shape.
template <typename OpT>
static LogicalResult VerifyDimAttr(OpT op) {
int64_t rank = -1;
if (auto ty = op.operand().getType().template dyn_cast<RankedTensorType>()) {
rank = ty.getRank();
} else if (auto ty = op.getType().template dyn_cast<RankedTensorType>()) {
rank = ty.getRank();
} else {
return success();
}
int64_t dim = op.dimension();
if (dim < 0 || dim >= rank)
return op.emitOpError() << "requires dimension attribute in range [0, "
<< rank << "); found (" << dim << ")";
return success();
}
// Returns 1D 64-bit dense elements attribute with the given values. // Returns 1D 64-bit dense elements attribute with the given values.
DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values, DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
Builder* builder) { Builder* builder) {
@ -245,10 +265,14 @@ void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// GetDimensionSizeOp // GetDimensionSizeOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
//
static LogicalResult Verify(GetDimensionSizeOp op) { return VerifyDimAttr(op); }
/// Fold get_dimension_size when the said shape dimension is a constant. /// Fold get_dimension_size when the said shape dimension is a constant.
OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) { OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
RankedTensorType type = operand().getType().cast<RankedTensorType>(); RankedTensorType type = operand().getType().dyn_cast<RankedTensorType>();
if (!type) return {};
int32_t dim = dimension(); int32_t dim = dimension();
if (type.isDynamic(dim)) return {}; if (type.isDynamic(dim)) return {};
// The result type is always a 0-d i32 tensor. // The result type is always a 0-d i32 tensor.
@ -741,7 +765,7 @@ static LogicalResult Verify(BroadcastInDimOp op) {
if (dimIndex >= resultRank) { if (dimIndex >= resultRank) {
return op.emitOpError( return op.emitOpError(
llvm::formatv("broadcast_dimensions contains invalid value {0} for " llvm::formatv("broadcast_dimensions contains invalid value {0} for "
"result result with rank {1}", "result with rank {1}",
dimIndex, resultRank)); dimIndex, resultRank));
} }
@ -828,7 +852,7 @@ static LogicalResult Verify(DynamicBroadcastInDimOp op) {
if (dimIndex >= resultRank) { if (dimIndex >= resultRank) {
return op.emitOpError( return op.emitOpError(
llvm::formatv("broadcast_dimensions contains invalid value {0} for " llvm::formatv("broadcast_dimensions contains invalid value {0} for "
"result result with rank {1}", "result with rank {1}",
dimIndex, resultRank)); dimIndex, resultRank));
} }
@ -1053,6 +1077,9 @@ LogicalResult ConcatenateOp::inferReturnTypes(
return success(); return success();
} }
if (first_type.getRank() == 0)
return emitOptionalError(location, "rank-0 values cannot be concatenated");
auto out_shape = llvm::to_vector<6>(first_type.getShape()); auto out_shape = llvm::to_vector<6>(first_type.getShape());
// Determine what the non-concatenate dimensions should be. // Determine what the non-concatenate dimensions should be.
@ -1721,6 +1748,35 @@ LogicalResult SelectOp::reifyReturnTypeShapes(
&reifiedReturnShapes); &reifiedReturnShapes);
} }
//===----------------------------------------------------------------------===//
// SetDimensionSizeOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(SetDimensionSizeOp op) {
if (auto size = op.size().getType().dyn_cast<RankedTensorType>()) {
if (size.getRank() != 0)
return op.emitOpError() << "size operand should be of rank-0";
}
return VerifyDimAttr(op);
}
OpFoldResult SetDimensionSizeOp::fold(ArrayRef<Attribute> operands) {
DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
if (input) return input;
DenseElementsAttr size = operands[1].dyn_cast_or_null<DenseElementsAttr>();
if (!size || !size.isSplat()) return {};
auto ty = getType().dyn_cast<RankedTensorType>();
if (!ty) return {};
int64_t dim_size = ty.getDimSize(dimension());
if (dim_size == size.getSplatValue().cast<IntegerAttr>().getInt())
return operand();
return {};
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// PadOp // PadOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1784,6 +1840,61 @@ static LogicalResult Verify(PadOp op) {
return success(); return success();
} }
OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
// If all padding is zero then it is an identity pad.
auto is_zero = [](const APInt& i) { return i == 0; };
if (llvm::all_of(edge_padding_low().getIntValues(), is_zero) &&
llvm::all_of(edge_padding_high().getIntValues(), is_zero) &&
llvm::all_of(interior_padding().getIntValues(), is_zero))
return operand();
// If any padding is negative then it isn't supported by the folder (yet).
auto is_negative = [](const APInt& i) { return i.slt(0); };
if (llvm::any_of(edge_padding_low().getIntValues(), is_negative) ||
llvm::any_of(edge_padding_high().getIntValues(), is_negative) ||
llvm::any_of(interior_padding().getIntValues(), is_negative))
return {};
DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
DenseElementsAttr padding = operands[1].dyn_cast_or_null<DenseElementsAttr>();
RankedTensorType return_type = getType().dyn_cast_or_null<RankedTensorType>();
if (!input || !input.getType().hasRank() || !padding || !return_type ||
!return_type.hasStaticShape())
return {};
// Fill the full result tensor with the padding value.
llvm::SmallVector<Attribute, 4> result(return_type.getNumElements(),
padding.getValue({}));
auto next_index = [](llvm::SmallVector<uint64_t, 8>& index,
llvm::ArrayRef<int64_t> shape) {
for (int64_t i = index.size() - 1; i >= 0; --i) {
++index[i];
if (index[i] < shape[i]) return true;
index[i] = 0;
}
return false;
};
// Iterate over all elements of the input tensor and copy each one to the
// correct location in the output tensor.
llvm::SmallVector<uint64_t, 8> index(input.getType().getRank(), 0);
do {
uint64_t linear_index = 0;
uint64_t linear_index_multiplyer = 1;
for (int64_t i = index.size() - 1; i >= 0; --i) {
linear_index +=
(edge_padding_low().getValue<int64_t>({uint64_t(i)}) +
index[i] *
(interior_padding().getValue<int64_t>({uint64_t(i)}) + 1)) *
linear_index_multiplyer;
linear_index_multiplyer *= return_type.getShape()[i];
}
result[linear_index] = input.getValue(index);
} while (next_index(index, input.getType().getShape()));
return DenseElementsAttr::get(return_type, result);
}
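To make the index arithmetic in the new `PadOp::fold` concrete: element `i` of the input lands at `edge_padding_low + i * (interior_padding + 1)` along each dimension, and every other slot holds the padding value. A minimal 1-D analogue, assuming a non-empty input and non-negative padding as the folder requires (the helper name is hypothetical, not part of the dialect):

```cpp
#include <cassert>
#include <vector>

// 1-D analogue of the folder's mapping from input elements to result slots.
static std::vector<int> PadFold1D(const std::vector<int>& input, int pad_value,
                                  int low, int high, int interior) {
  const int n = static_cast<int>(input.size());
  std::vector<int> result(low + high + n + (n - 1) * interior, pad_value);
  for (int i = 0; i < n; ++i) result[low + i * (interior + 1)] = input[i];
  return result;
}

int main() {
  // Pad [1, 2, 3] with value 0, edge low = 1, edge high = 2, interior = 1.
  assert((PadFold1D({1, 2, 3}, 0, 1, 2, 1) ==
          std::vector<int>{0, 1, 0, 2, 0, 3, 0, 0}));
}
```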
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ReshapeOp // ReshapeOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1930,6 +2041,14 @@ static Attribute UnaryFolder(Op* op, ArrayRef<Attribute> attrs) {
return DenseElementsAttr::get(type, values); return DenseElementsAttr::get(type, values);
} }
struct round {
APFloat operator()(const APFloat& f) {
APFloat r = f;
r.roundToIntegral(llvm::RoundingMode::NearestTiesToAway);
return r;
}
};
#define UNARY_FOLDER(Op, Func) \ #define UNARY_FOLDER(Op, Func) \
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \ OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \ if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
@ -1939,7 +2058,15 @@ static Attribute UnaryFolder(Op* op, ArrayRef<Attribute> attrs) {
return {}; \ return {}; \
} }
#define UNARY_FOLDER_FLOAT(Op, Func) \
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
return UnaryFolder<Op, FloatType, APFloat, Func>(this, attrs); \
return {}; \
}
UNARY_FOLDER(NegOp, std::negate); UNARY_FOLDER(NegOp, std::negate);
UNARY_FOLDER_FLOAT(RoundOp, round);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// BinaryOps // BinaryOps
@ -2068,14 +2195,77 @@ BINARY_FOLDER(MinOp, min);
// SliceOp // SliceOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void SliceOp::build(OpBuilder& builder, OperationState& result, Value operand, // Returns output dimension size for slice result for the given arguments.
DenseIntElementsAttr start_indices, // Returns -1 if arguments are illegal.
DenseIntElementsAttr limit_indices, static int64_t InferSliceDim(int64_t input_dim, int64_t start, int64_t end,
DenseIntElementsAttr strides) { int64_t stride) {
return build(builder, result, if (input_dim == -1 || start < 0 || start > end || end > input_dim ||
InferOutputTypes(&builder, operand, start_indices, limit_indices, stride == 0)
strides), return -1;
operand, start_indices, limit_indices, strides);
return llvm::divideCeil(end - start, stride);
}
LogicalResult SliceOp::inferReturnTypes(
MLIRContext* context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type>& inferredReturnTypes) {
SliceOpAdaptor slice(operands, attributes);
// TODO(jpienaar): Update this code after refactoring verify.
if (failed(slice.verify(location.getValueOr(UnknownLoc::get(context))))) {
return failure();
}
Type ty = slice.operand().getType();
RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
if (!ranked_ty) {
// The operand type is unranked, so the best we can infer for the result
// type is an unranked tensor with the same element type as the operand
// type.
inferredReturnTypes.assign({ty});
return success();
}
ShapedType attr_ty = slice.start_indices().getType();
if (attr_ty.getRank() != 1) {
return emitOptionalError(location, "start_indices has rank ",
attr_ty.getRank(), " instead of required rank 1");
}
int64_t rank = ranked_ty.getRank();
if (attr_ty.getNumElements() != rank) {
return emitOptionalError(
location, "the number of elements in start_indices (",
attr_ty.getNumElements(), ") does not match the rank of the operand (",
rank, ")");
}
if (!attr_ty.getElementType().isSignlessInteger(64) ||
slice.limit_indices().getType() != attr_ty ||
slice.strides().getType() != attr_ty) {
// Unfortunately we can't rely on the AllTypesMatch trait for the SliceOp
// having been verified at this point. Emit an error message that matches
// the one that would be reported by AllTypesMatch for a more consistent
// user experience.
// TODO(b/171567182): Clean this up after AllTypesMatch has been refactored.
return emitOptionalError(location,
"failed to verify that all of {start_indices, "
"limit_indices, strides} have same type");
}
SmallVector<int64_t, 4> start(slice.start_indices().getValues<int64_t>());
SmallVector<int64_t, 4> limit(slice.limit_indices().getValues<int64_t>());
SmallVector<int64_t, 4> stride_vals(slice.strides().getValues<int64_t>());
SmallVector<int64_t, 4> shape;
shape.reserve(rank);
for (int64_t i = 0, e = rank; i != e; i++) {
shape.push_back(InferSliceDim(ranked_ty.getDimSize(i), start[i], limit[i],
stride_vals[i]));
}
inferredReturnTypes.assign(
{RankedTensorType::get(shape, ranked_ty.getElementType())});
return success();
}
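For intuition, a minimal standalone sketch of the per-dimension arithmetic InferSliceDim performs (the shapes and bounds below are made-up example values, not taken from this change):

#include <cstdint>
#include <iostream>

// Mirrors the rule above: a legal slice dimension has size
// ceil((end - start) / stride).
static int64_t SliceDim(int64_t start, int64_t end, int64_t stride) {
  return (end - start + stride - 1) / stride;
}

int main() {
  // Slicing a tensor<8x8xi32> with start=[0,2], limit=[4,8], strides=[1,2]
  // would be inferred as tensor<4x3xi32>.
  std::cout << SliceDim(0, 4, 1) << "x" << SliceDim(2, 8, 2) << "\n";  // 4x3
  return 0;
}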
template <typename I, typename E>
@ -2258,46 +2448,6 @@ void SliceOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
  results.insert<SimplifyConcatSlice>(context);
}
// Returns output dimension size for slice result for the given arguments.
// Returns -1 if arguments are illegal.
static int64_t InferSliceDim(int64_t input_dim, int64_t start, int64_t end,
int64_t stride) {
if (input_dim == -1 || start < 0 || start > end || end > input_dim ||
stride == 0)
return -1;
return llvm::divideCeil(end - start, stride);
}
Type SliceOp::InferOutputTypes(Builder* builder, Value operand,
DenseIntElementsAttr start_indices,
DenseIntElementsAttr limit_indices,
DenseIntElementsAttr strides) {
Type ty = operand.getType();
RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
if (!ranked_ty) return ty;
int64_t rank = ranked_ty.getRank();
// Illegal attributes.
ShapedType attr_ty = start_indices.getType();
if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank ||
!attr_ty.getElementType().isSignlessInteger(64) ||
limit_indices.getType() != attr_ty || strides.getType() != attr_ty)
return ty;
SmallVector<int64_t, 4> start(start_indices.getValues<int64_t>());
SmallVector<int64_t, 4> limit(limit_indices.getValues<int64_t>());
SmallVector<int64_t, 4> stride_vals(strides.getValues<int64_t>());
SmallVector<int64_t, 4> shape;
shape.reserve(rank);
for (int64_t i = 0, e = rank; i != e; i++) {
shape.push_back(InferSliceDim(ranked_ty.getDimSize(i), start[i], limit[i],
stride_vals[i]));
}
return RankedTensorType::get(shape, ranked_ty.getElementType());
}
//===----------------------------------------------------------------------===//
// SortOp
//===----------------------------------------------------------------------===//
@ -2525,10 +2675,12 @@ void UnaryEinsumOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===//
void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
                      Value rhs, StringAttr comparison_direction,
                      StringAttr compare_type) {
  auto new_type =
      UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type());
  build(builder, result, new_type, lhs, rhs, comparison_direction,
        compare_type);
}
LogicalResult CompareOp::inferReturnTypeComponents(
@ -2799,15 +2951,22 @@ namespace mhlo {
namespace {
struct HLOInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;
// Allow all call operations to be inlined.
bool isLegalToInline(Operation* call, Operation* callable,
bool wouldBeCloned) const final {
return true;
}
  // We don't have any special restrictions on what can be inlined into
  // destination regions (e.g. while/conditional bodies). Always allow it.
  bool isLegalToInline(Region* dest, Region* src, bool wouldBeCloned,
                       BlockAndValueMapping& valueMapping) const final {
    return true;
  }
  // Operations in mhlo dialect are always legal to inline since they are
  // pure.
  bool isLegalToInline(Operation*, Region*, bool,
                       BlockAndValueMapping&) const final {
    return true;
  }
};


@ -18,15 +18,13 @@ limitations under the License.
include "mlir/Dialect/Shape/IR/ShapeOps.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>;
// Canonicalization patterns.

def DynamicBroadcastToOwnShape_1 : Pat<
  (HLO_DynamicBroadcastInDimOp:$op $arg0,
    (Shape_ToExtentTensorOp (Shape_ShapeOfOp $arg1)), $attr),
  (replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>;
def DynamicBroadcastToOwnShape_2 : Pat<
  (HLO_DynamicBroadcastInDimOp:$op $arg0, (Shape_ShapeOfOp $arg1), $attr),
  (replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>;

def DynamicBroadcastToOwnShape_1 : Pat<
  (HLO_DynamicBroadcastInDimOp:$op $x,
    (Shape_ToExtentTensorOp (Shape_ShapeOfOp $x)), $attr),
  (replaceWithValue $x)>;
def DynamicBroadcastToOwnShape_2 : Pat<
  (HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr),
  (replaceWithValue $x)>;


@ -88,76 +88,6 @@ void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
  results.insert<EraseConstOp>(context);
}
//===----------------------------------------------------------------------===//
// StaticMemRefCastOp
//===----------------------------------------------------------------------===//
Value StaticMemRefCastOp::getViewSource() { return *getODSOperands(0).begin(); }
static LogicalResult Verify(StaticMemRefCastOp op) {
if (!op.operand().getType().cast<ShapedType>().hasStaticShape())
return op.emitOpError("operand must have static shape");
if (!op.getType().hasStaticShape())
return op.emitOpError("result must have static shape");
return success();
}
//===----------------------------------------------------------------------===//
// DynamicMemRefCastOp
//===----------------------------------------------------------------------===//
Value DynamicMemRefCastOp::getViewSource() {
return *getODSOperands(0).begin();
}
static LogicalResult Verify(DynamicMemRefCastOp op) {
// Check if `sizes` and `strides` args are compatible with the result type.
if (op.sizes().size() != op.getType().getRank())
return op.emitOpError(
"`sizes` args count must be equal to the rank of the output memref");
return success();
}
//===----------------------------------------------------------------------===//
// ReshapeMemrefCastOp
//===----------------------------------------------------------------------===//
Value ReshapeMemRefCastOp::getViewSource() { return operand(); }
static LogicalResult Verify(ReshapeMemRefCastOp op) {
Type operandType = op.operand().getType();
Type resultType = op.result().getType();
Type operandElementType = operandType.cast<ShapedType>().getElementType();
Type resultElementType = resultType.cast<ShapedType>().getElementType();
if (operandElementType != resultElementType)
return op.emitOpError(
"element types of source and destination memref "
"types should be the same");
if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
if (!operandMemRefType.getAffineMaps().empty())
return op.emitOpError(
"operand memref type should have identity affine map");
int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
auto resultMemRefType = resultType.dyn_cast<MemRefType>();
if (resultMemRefType) {
if (shapeSize == ShapedType::kDynamicSize)
return op.emitOpError(
"cannot use shape operand with dynamic length to "
"cast statically-ranked memref type");
if (shapeSize != resultMemRefType.getRank())
return op.emitOpError(
"length of shape operand differs from the result's memref rank");
if (!resultMemRefType.getAffineMaps().empty())
return op.emitOpError(
"result memref type should have identity affine map");
}
return success();
}
}  // namespace lmhlo
}  // namespace mlir


@ -67,6 +67,7 @@ add_mlir_library(MhloPasses
  DEPENDS
  MLIRhlo_opsIncGen
  MLIRMhloLowerComplexIncGen
MLIRMhloPassIncGen
  LINK_COMPONENTS
  Core
@ -133,8 +134,6 @@ add_mlir_library(LmhloPasses
  lhlo_fuse_linalg.cc
  lhlo_legalize_to_affine.cc
  lhlo_legalize_to_gpu.cc
lhlo_legalize_to_llvm.cc
lhlo_legalize_to_llvm_pass.cc
  lhlo_legalize_to_parallel_loops.cc
  DEPENDS


@ -17,6 +17,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir-hlo/utils/broadcast_utils.h" #include "mlir-hlo/utils/broadcast_utils.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
@ -69,13 +70,18 @@ struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
// Converts binary ops that statically are determined to not broadcast directly
// to the corresponding mhlo non-broadcasting op.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertTrivialNonBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only rewrite for statically determinable non-broadcasting cases.
    typename ChloOpTy::Adaptor transformed(operands);
    auto lhs_type =
        transformed.lhs().getType().template dyn_cast<RankedTensorType>();
    auto rhs_type =
        transformed.rhs().getType().template dyn_cast<RankedTensorType>();
    if (!lhs_type || !rhs_type) return failure();

    // Requires rank broadcast.
@ -93,8 +99,9 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
      }
    }
    rewriter.replaceOp(
        op, {Adaptor::CreateOp(op, op.getResult().getType(), operands[0],
                               operands[1], rewriter)});
    return success();
  }
};
@ -113,13 +120,15 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
// `shape.broadcast` op, which only supports prefix-padding.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertRankedDynamicBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only support ranked operands.
    typename ChloOpTy::Adaptor transformed(operands);
    Value lhs = transformed.lhs();
    Value rhs = transformed.rhs();
    auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
    auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
    auto result_type =
@ -193,324 +202,6 @@ struct ConvertRankedDynamicBroadcastBinaryOp
  }
};
// Converts a broadcasting binary operation with a scalar operand and an
// unranked operand to a ranked broadcasting operation by dynamically reshaping
// the unranked operand to a 1D tensor. This will always be safe because
// broadcasting from a scalar to another shape always works.
template <typename ChloOpTy, typename HloOpTy>
struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
: public OpRewritePattern<ChloOpTy> {
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ChloOpTy op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Value lhs = op.lhs();
Value rhs = op.rhs();
auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();
auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();
bool lhs_is_scalar = lhs_ranked_type &&
lhs_ranked_type.getShape().empty() &&
rhs_unranked_type;
bool rhs_is_scalar = rhs_ranked_type &&
rhs_ranked_type.getShape().empty() &&
lhs_unranked_type;
// Only support the case where exactly one operand is scalar and the other
// is unranked. Other patterns in this file will create more efficient
// lowerings for cases where both ranks are known or will handle the more
// generic case of both inputs being unranked.
if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
// Reshape the non-scalar value into a dynamically sized, rank-1 tensor
Value shape =
rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
Value size_tensor =
rewriter.create<TensorFromElementsOp>(loc, num_elements);
Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
loc, RankedTensorType::get({-1}, result_type.getElementType()),
lhs_is_scalar ? rhs : lhs, size_tensor);
// Create a new ranked Chlo op that will be further lowered by other
// patterns into Mhlo.
SmallVector<Value, 2> operands{lhs_is_scalar ? lhs : reshaped,
rhs_is_scalar ? rhs : reshaped};
Value computed = rewriter.create<ChloOpTy>(
loc, SmallVector<Type, 1>{reshaped.getType()}, operands, op.getAttrs());
// Reshape the result back into an unranked tensor.
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
computed, shape);
return success();
}
};
// Handles lowering of the following pattern to patterns that will be further
// matched by other patterns until they result in LHLO:
// %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
//
// The sequence of specializations this handles is:
// - Either operand being scalar
// - Operands having equal shapes
// - The resulting value being any of ranks [2,6]
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertUnrankedDynamicBroadcastBinaryOp
: public OpRewritePattern<ChloOpTy> {
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ChloOpTy op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Value lhs = op.lhs();
Value rhs = op.rhs();
auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
// Only support unranked operands. If either operand is ranked, another
// pattern will handle the lowering.
if (!lhs_type || !rhs_type) return failure();
// If lhs is scalar
auto if_op = rewriter.create<scf::IfOp>(
loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder();
Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
op.getAttrs());
if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);
// If lhs is NOT scalar
//
// See if rhs is scalar
OpBuilder else_lhs_scalar_builder = if_op.getElseBodyBuilder();
auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
true);
else_lhs_scalar_builder.create<scf::YieldOp>(loc,
if_rhs_scalar_op.getResult(0));
OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder();
Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
op.getAttrs());
if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);
// If NEITHER shape is scalar
//
// See if shapes are equal.
OpBuilder else_no_scalars_builder = if_rhs_scalar_op.getElseBodyBuilder();
Value shape_of_lhs =
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
Value shape_of_rhs =
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
loc, shape_of_lhs, shape_of_rhs);
auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
loc, result_type, equal_shapes, true);
else_no_scalars_builder.create<scf::YieldOp>(loc,
if_eq_shapes_op.getResult(0));
OpBuilder if_eq_shapes_builder = if_eq_shapes_op.getThenBodyBuilder();
Value non_broadcast_op =
Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
// If shapes are not scalar, nor equal
//
// See if values are of a rank that we support.
OpBuilder if_neq_shapes_builder = if_eq_shapes_op.getElseBodyBuilder();
if_neq_shapes_builder.create<scf::YieldOp>(
loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));
rewriter.replaceOp(op, {if_op.getResult(0)});
return success();
}
private:
// Returns the dynamic result of checking whether the given value is a scalar
// tensor.
Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
auto loc = op.getLoc();
Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
Value rank_tensor = rewriter.create<shape::RankOp>(
loc, rewriter.getIndexType(), shape_of_tensor);
return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
rank_tensor,
rewriter.create<ConstantIndexOp>(loc, 0));
}
// Create the if statement and code for a broadcasting op with a result of a
// given rank.
scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op,
Value lhs, Value rhs,
Value actual_rank,
int targeted_rank) const {
auto loc = op.getLoc();
// Create the if block to place the current specialized logic in.
Value greater_rank_is_n = builder.create<CmpIOp>(
loc, CmpIPredicate::eq, actual_rank,
builder.create<ConstantIndexOp>(loc, targeted_rank));
auto if_op =
builder.create<scf::IfOp>(loc, lhs.getType(), greater_rank_is_n, true);
OpBuilder if_builder = if_op.getThenBodyBuilder();
// Handle shape broadcasting and inference.
Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
{RankedTensorType::kDynamicSize}, builder.getIndexType());
auto known_rank_extent_tensor_type =
RankedTensorType::get({targeted_rank}, builder.getIndexType());
auto reshaped_type = RankedTensorType::get(
llvm::SmallVector<int64_t, 6>(targeted_rank,
RankedTensorType::kDynamicSize),
lhs.getType().template dyn_cast<TensorType>().getElementType());
Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
loc, known_rank_extent_tensor_type,
mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
ranked_shape));
Value extended_lhs = if_builder.create<shape::BroadcastOp>(
loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
nullptr);
Value extended_lhs_casted = if_builder.create<TensorCastOp>(
loc, known_rank_extent_tensor_type, extended_lhs);
Value extended_rhs = if_builder.create<shape::BroadcastOp>(
loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
nullptr);
Value extended_rhs_casted = if_builder.create<TensorCastOp>(
loc, known_rank_extent_tensor_type, extended_rhs);
// 1. Reshape operands to the given rank (with the same number of elements)
// 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
// can be broadcasted and do the actual broadcasting)
// 3. Type erase the output back to unranked
Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
loc, reshaped_type, lhs, extended_lhs_casted);
Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
loc, reshaped_type, rhs, extended_rhs_casted);
Value result = if_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{reshaped_type},
ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
Value reshaped_result = if_builder.create<TensorCastOp>(
loc, UnrankedTensorType::get(reshaped_type.getElementType()), result);
if_builder.create<scf::YieldOp>(loc, reshaped_result);
// Return the if_op, so the result can be used and the else block can be
// used for the next rank specialized step.
return if_op;
}
// Iterates over the desired ranks to be specialized and generates the code
// snippet for each case.
Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
Value rhs) const {
constexpr int max_rank_specialization = 7;
auto loc = op.getLoc();
// Find the larger rank of the 2 operands.
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
rewriter.getIndexType());
Value lhs_shape =
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
Value rhs_shape =
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
Value lhs_rank =
rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape);
Value rhs_rank =
rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape);
Value greater_rank_lhs =
rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
Value greater_rank =
rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);
// Generate a list of nested if/else statements to handle rank
// specializations from 2-6.
scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs,
rhs, greater_rank, 2);
// Put each subsequent rank specialization inside the else statement of the
// previous one.
OpBuilder else_builder = if_op.getElseBodyBuilder();
for (int i = 3; i < max_rank_specialization; i++) {
auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs,
rhs, greater_rank, i);
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
else_builder = inner_if.getElseBodyBuilder();
}
// Fire an assertion if none of the rank specializations applied (one of the
// ranks was greater than 6).
else_builder.create<AssertOp>(
loc, else_builder.create<ConstantIntOp>(loc, 0, 1),
"Input for dynamic binary op lowering was of a rank greater than 6");
else_builder.create<scf::YieldOp>(loc, lhs);
// Return the result of the outermost if statement.
return if_op.getResult(0);
}
};
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
void PopulateForBinaryOp(MLIRContext *context,
OwningRewritePatternList *patterns) {
patterns
->insert<ConvertTrivialNonBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
context, 10);
patterns->insert<
ConvertRankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
context, 5);
patterns->insert<
ConvertUnrankedScalarDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy>,
ConvertUnrankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
context);
}
template <typename FromOpTy, typename ToOpTy>
struct HloBinaryElementwiseAdaptor {
static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<ToOpTy>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs);
}
};
struct HloComplexAdaptor {
static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs);
}
};
struct HloCompareAdaptor {
static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<mhlo::CompareOp>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs,
from_op.comparison_direction());
}
};
#include "generated_chlo_legalize_to_hlo.inc" #include "generated_chlo_legalize_to_hlo.inc"
} // namespace } // namespace
@ -521,32 +212,10 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
  // Instantiate conversion templates for conforming binary elementwise ops
  // that do not have different dtypes between operands and results and do
  // not have special attributes that need to be preserved.
#define POPULATE_BCAST(ChloOp, HloOp) \
  PopulateForBinaryOp<ChloOp, HloOp, \
                      HloBinaryElementwiseAdaptor<ChloOp, HloOp>>(context, \
                                                                  patterns);
POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);
// Broadcasting ops requiring special construction.
PopulateForBinaryOp<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>(
context, patterns);
PopulateForBinaryOp<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>(
context, patterns);
  PopulateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
      context, patterns, 10);
  PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
      context, patterns, 5);

  // Other patterns.
  patterns->insert<ConvertConstantLikeOp>(context);


@ -49,7 +49,7 @@ struct ChloLegalizeToHloPass
    chlo::PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns);
    if (failed(applyPartialConversion(getFunction(), conversionTarget,
                                      std::move(conversionPatterns)))) {
      return signalPassFailure();
    }
  }


@ -31,7 +31,8 @@ def : Pat<(HLOClient_AcosOp $input),
      (HLO_CompareOp
        $input,
        (HLO_ConstantLike<"-1"> $input),
        HLO_COMPARISON_DIRECTION_NE,
        (HLO_DEFAULT_COMPARISON_TYPE)
      ),
      (HLO_MulOp
        (HLO_ConstantLike<"2"> $input),
@ -67,7 +68,8 @@ def : Pat<(HLOClient_SinhOp $input),
      (HLO_CompareOp
        (HLO_AbsOp $input),
        (HLO_ConstantLike<"1"> $input),
        HLO_COMPARISON_DIRECTION_LT,
        (HLO_DEFAULT_COMPARISON_TYPE)
      ),
      (HLO_DivOp
        (HLO_SubOp


@ -42,7 +42,7 @@ namespace mhlo {
namespace {

template <typename T>
using BaseOpConversion = OpConversionPattern<T>;
Value InsertDynamicAllocAndDealloc(Location loc, Value result,
                                   Value shape_operand,
@ -206,7 +206,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
  // Inserts dynamic memref to change the layout of the memref to put 0-stride
  // and size of the target dimension if size-1 dimension expansion is
  // necessary.
  MemRefReinterpretCastOp InsertDynamicMemrefCastOp(
      mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
    auto loc = op.getLoc();
    auto operand_type = operand.getType().cast<MemRefType>();
@ -259,8 +259,13 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
        makeStridedLinearLayoutMap(dynamic_layout,
                                   /*offset=*/0, b->getContext()));
    SmallVector<int64_t, 2> static_sizes(sizes.size(),
                                         ShapedType::kDynamicSize);
    SmallVector<int64_t, 2> static_strides(strides.size(),
                                           ShapedType::kDynamicStrideOrOffset);
    auto transformed_operand = b->create<MemRefReinterpretCastOp>(
        loc, type_erased_memref_type, operand, /*offset=*/0, static_sizes,
        static_strides, llvm::None, sizes, strides);
    return transformed_operand;
  }
};
@ -284,7 +289,7 @@ struct HloToLhloDynamicReshapeConverter
      return failure();
    }
    mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
    rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
        op, result_type, adaptor.operand(), adaptor.output_shape());
    return success();
  }
@ -504,12 +509,7 @@ struct HloLegalizeToLhlo
 public:
  HloLegalizeToLhlo() = default;
  HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {
    this->results_escape_function = o.results_escape_function.getValue();
  }
  explicit HloLegalizeToLhlo(bool results_escape_function) {
    this->results_escape_function.setValue(results_escape_function);
  }
  HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {}
  void runOnOperation() override {
    OwningRewritePatternList patterns;
@ -541,33 +541,17 @@ struct HloLegalizeToLhlo
      return std::all_of(op.operand_type_begin(), op.operand_type_end(),
                         isMemRefType);
    });
target.addDynamicallyLegalOp<shape::AssumingOp>([&](shape::AssumingOp op) {
return std::all_of(op.result_type_begin(), op.result_type_end(),
isMemRefType);
});
auto kind = results_escape_function
? BufferizeTypeConverter::KeepAsFunctionResult
: BufferizeTypeConverter::AppendToArgumentsList;
converter.setResultConversionKind<UnrankedTensorType, UnrankedMemRefType>(
kind);
converter.setResultConversionKind<RankedTensorType, MemRefType>(kind);
    populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
    populateWithBufferizeOpConversionPatterns<mlir::ReturnOp, mlir::ReturnOp,
                                              lmhlo::CopyOp>(
        &context, converter, patterns);
    populateShapeStructuralTypeConversionsAndLegality(&context, converter,
                                                      patterns, target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
private:
Option<bool> results_escape_function{
*this, "results-escape-function",
llvm::cl::desc(
"Allocate the results of functions within the functions body"),
llvm::cl::init(false)};
};
}  // namespace
@ -623,13 +607,12 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
      HloToLhloReturnOpConverter,
      HloToLhloTensorLoadOpConverter,
      HloToLhloTensorStoreOpConverter
  >(context);
  // clang-format on
}
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
    bool results_escape_function) {
  return std::make_unique<HloLegalizeToLhlo>(results_escape_function);
}

std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass() {
  return std::make_unique<HloLegalizeToLhlo>();
}
}  // namespace mhlo


@ -17,8 +17,8 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/IR/Function.h" #include "mlir/IR/Function.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
@ -133,7 +133,7 @@ struct LegalizeGatherToTorchIndexSelectPass
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    PopulateGatherToTorchIndexSelectPatterns(&getContext(), &patterns);
    applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};
}  // namespace


@ -60,13 +60,13 @@ ShapedType getHloOpResultType(Operation* op) {
template <bool isLHLO = true> template <bool isLHLO = true>
bool verifyHloOpBufferOrTensorSemantics(Operation* op) { bool verifyHloOpBufferOrTensorSemantics(Operation* op) {
auto verifyType = [&](Value val) -> bool { auto verify_type = [&](Value val) -> bool {
return (isLHLO && val.getType().isa<MemRefType>()) || return (isLHLO && val.getType().isa<MemRefType>()) ||
(!isLHLO && val.getType().isa<RankedTensorType>()); (!isLHLO && val.getType().isa<RankedTensorType>());
}; };
if (!llvm::all_of(op->getOperands(), verifyType)) return false; if (!llvm::all_of(op->getOperands(), verify_type)) return false;
return isLHLO ? op->getResults().empty() return isLHLO ? op->getResults().empty()
: llvm::all_of(op->getResults(), verifyType); : llvm::all_of(op->getResults(), verify_type);
} }
template <typename OpTy, bool isLHLO = true> template <typename OpTy, bool isLHLO = true>
@ -99,51 +99,51 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
<< nloops << " parallel iterators: " << *(op.getOperation()); << nloops << " parallel iterators: " << *(op.getOperation());
// Construct the indexing maps needed for linalg.generic ops. // Construct the indexing maps needed for linalg.generic ops.
SmallVector<Type, 4> bodyArgTypes, bodyResultTypes, opResultTypes; SmallVector<Type, 4> body_arg_types, body_result_types, op_result_types;
// This doesnt account for implicit broadcast, but the working assumption // This doesnt account for implicit broadcast, but the working assumption
// in HLO/LHLO is that are broadcasts are made explicit. // in HLO/LHLO is that are broadcasts are made explicit.
if (isLHLO && !nloops) return failure(); if (isLHLO && !nloops) return failure();
int numInputs = (isLHLO ? args.size() - 1 : args.size()); int num_inputs = (isLHLO ? args.size() - 1 : args.size());
ValueRange inputs(args.take_front(numInputs)); ValueRange inputs(args.take_front(num_inputs));
for (Value in : inputs) for (Value in : inputs)
bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType())); body_arg_types.emplace_back(getElementTypeOrSelf(in.getType()));
ValueRange outputBuffers(args.take_back(args.size() - numInputs)); ValueRange output_buffers(args.take_back(args.size() - num_inputs));
for (Value out : outputBuffers) for (Value out : output_buffers)
bodyResultTypes.emplace_back(getElementTypeOrSelf(out.getType())); body_result_types.emplace_back(getElementTypeOrSelf(out.getType()));
if (!isLHLO) { if (!isLHLO) {
// HLO operations have return as tensor types. // HLO operations have return as tensor types.
assert(bodyResultTypes.empty() && assert(body_result_types.empty() &&
"When lowering HLO ops result can't be part of arguments"); "When lowering HLO ops result can't be part of arguments");
Value result = op.getOperation()->getResult(0); Value result = op.getOperation()->getResult(0);
bodyResultTypes.push_back(getElementTypeOrSelf(result)); body_result_types.push_back(getElementTypeOrSelf(result));
opResultTypes.push_back(result.getType()); op_result_types.push_back(result.getType());
} }
AffineMap commonIndexingMap = AffineMap common_indexing_map =
nloops ? rewriter.getMultiDimIdentityMap(nloops) nloops ? rewriter.getMultiDimIdentityMap(nloops)
: AffineMap::get(nloops, 0, rewriter.getContext()); : AffineMap::get(nloops, 0, rewriter.getContext());
SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1), SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
commonIndexingMap); common_indexing_map);
auto linalgOp = rewriter.create<linalg::GenericOp>( auto linalg_op = rewriter.create<linalg::GenericOp>(
loc, opResultTypes, inputs, outputBuffers, loc, op_result_types, inputs, output_buffers,
/*initTensors=*/ValueRange{}, indexing_maps, /*initTensors=*/ValueRange{}, indexing_maps,
GetNParallelLoopsAttrs(nloops), GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
// TODO(ravishankarm) : For now use the method in lmhlo namespace. // TODO(ravishankarm) : For now use the method in lmhlo namespace.
// That method needs to be moved out of there. // That method needs to be moved out of there.
Value opResult = lmhlo::HloOpToStdScalarOp::map<OpTy>( Value op_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
op, bodyResultTypes, op, body_result_types,
llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter); llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
nestedBuilder.create<linalg::YieldOp>(loc, opResult); nested_builder.create<linalg::YieldOp>(loc, op_result);
}); });
rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
return success(); return success();
} }
}; };
@ -157,10 +157,10 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
LhloOp lhlo_op, ArrayRef<Value> args, LhloOp lhlo_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto loc = lhlo_op.getLoc(); auto loc = lhlo_op.getLoc();
auto argType = auto arg_type =
lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>(); lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
if (!argType || !argType.getElementType().isSignlessIntOrFloat() || if (!arg_type || !arg_type.getElementType().isSignlessIntOrFloat() ||
(argType.getRank() != 0)) { (arg_type.getRank() != 0)) {
return failure(); return failure();
} }
@ -168,10 +168,10 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs()); auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs()); auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
// TODO(ravishankarm) : Move this method out of lmhlo namespace. // TODO(ravishankarm) : Move this method out of lmhlo namespace.
Value opResult = lmhlo::HloOpToStdScalarOp::map<LhloOp>( Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs}, lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
&rewriter); &rewriter);
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out()); rewriter.create<StoreOp>(loc, op_result, lhlo_op.out());
rewriter.eraseOp(lhlo_op); rewriter.eraseOp(lhlo_op);
return success(); return success();
} }
@ -192,52 +192,52 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
lmhlo::ConvOp op, ArrayRef<Value> args, lmhlo::ConvOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
// Check validity of dimension information. // Check validity of dimension information.
if (const mhlo::ConvDimensionNumbers& dimensionNumbers = if (const mhlo::ConvDimensionNumbers& dimension_numbers =
op.dimension_numbers()) { op.dimension_numbers()) {
const int inputSpatialRank = const int input_spatial_rank =
llvm::size(dimensionNumbers.input_spatial_dimensions()); llvm::size(dimension_numbers.input_spatial_dimensions());
// The dimensions for input should follow the order of // The dimensions for input should follow the order of
// batch_count, spatial_dims..., input_feature_count. // batch_count, spatial_dims..., input_feature_count.
if (dimensionNumbers.input_batch_dimension().getInt() != 0 || if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
dimensionNumbers.input_feature_dimension().getInt() != dimension_numbers.input_feature_dimension().getInt() !=
(inputSpatialRank + 1)) (input_spatial_rank + 1))
return failure(); return failure();
const int kernelSpatialRank = const int kernel_spatial_rank =
llvm::size(dimensionNumbers.kernel_spatial_dimensions()); llvm::size(dimension_numbers.kernel_spatial_dimensions());
// The dimensions for filter should follow the order of // The dimensions for filter should follow the order of
// spatial_dims..., input_feature_count, num_output_feature_count. // spatial_dims..., input_feature_count, num_output_feature_count.
if (dimensionNumbers.kernel_input_feature_dimension().getInt() != if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
kernelSpatialRank || kernel_spatial_rank ||
dimensionNumbers.kernel_output_feature_dimension().getInt() != dimension_numbers.kernel_output_feature_dimension().getInt() !=
(kernelSpatialRank + 1)) (kernel_spatial_rank + 1))
return failure(); return failure();
const int outputSpatialRank = const int output_spatial_rank =
llvm::size(dimensionNumbers.output_spatial_dimensions()); llvm::size(dimension_numbers.output_spatial_dimensions());
// The dimensions for output should follow the order of // The dimensions for output should follow the order of
// batch_count, spatial_dims.., output_feature_count. // batch_count, spatial_dims.., output_feature_count.
if (dimensionNumbers.output_batch_dimension().getInt() != 0 || if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
dimensionNumbers.output_feature_dimension().getInt() != dimension_numbers.output_feature_dimension().getInt() !=
(outputSpatialRank + 1)) (output_spatial_rank + 1))
return failure(); return failure();
if (inputSpatialRank != outputSpatialRank || if (input_spatial_rank != output_spatial_rank ||
inputSpatialRank != kernelSpatialRank) input_spatial_rank != kernel_spatial_rank)
return failure(); return failure();
auto inputSpatialDim = auto input_spatial_dim =
dimensionNumbers.input_spatial_dimensions().begin(); dimension_numbers.input_spatial_dimensions().begin();
auto kernelSpatialDim = auto kernel_spatial_dim =
dimensionNumbers.kernel_spatial_dimensions().begin(); dimension_numbers.kernel_spatial_dimensions().begin();
auto outputSpatialDim = auto output_spatial_dim =
dimensionNumbers.output_spatial_dimensions().begin(); dimension_numbers.output_spatial_dimensions().begin();
// Check if spatial dims are ordered correctly. // Check if spatial dims are ordered correctly.
for (int i = 0; i < inputSpatialRank; ++i) { for (int i = 0; i < input_spatial_rank; ++i) {
const int dim = i + 1; const int dim = i + 1;
if ((*inputSpatialDim++).getZExtValue() != dim || if ((*input_spatial_dim++).getZExtValue() != dim ||
(*outputSpatialDim++).getZExtValue() != dim || (*output_spatial_dim++).getZExtValue() != dim ||
(*kernelSpatialDim++).getZExtValue() != i) (*kernel_spatial_dim++).getZExtValue() != i)
return failure(); return failure();
} }
} }
@ -248,33 +248,33 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
} }
llvm::SmallVector<Attribute, 4> strides; llvm::SmallVector<Attribute, 4> strides;
if (auto windowStrides = op.window_strides()) { if (auto window_strides = op.window_strides()) {
auto range = windowStrides->getAttributeValues(); auto range = window_strides->getAttributeValues();
strides.assign(range.begin(), range.end()); strides.assign(range.begin(), range.end());
} }
auto stridesArg = ArrayAttr::get(strides, op.getContext()); auto strides_arg = ArrayAttr::get(strides, op.getContext());
llvm::SmallVector<Attribute, 2> dilation; llvm::SmallVector<Attribute, 2> dilation;
if (auto rhsDilation = op.rhs_dilation()) { if (auto rhs_dilation = op.rhs_dilation()) {
auto range = rhsDilation->getAttributeValues(); auto range = rhs_dilation->getAttributeValues();
dilation.assign(range.begin(), range.end()); dilation.assign(range.begin(), range.end());
} else { } else {
// Default dilation of 1. // Default dilation of 1.
dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1)); dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
} }
auto dilationArg = ArrayAttr::get(dilation, op.getContext()); auto dilation_arg = ArrayAttr::get(dilation, op.getContext());
// Set padding only if it is non-zero. // Set padding only if it is non-zero.
DenseIntElementsAttr padding = op.paddingAttr(); DenseIntElementsAttr padding = op.paddingAttr();
if (!padding || !llvm::any_of(padding.getValues<APInt>(), [](APInt intVal) { if (!padding ||
return !intVal.isNullValue(); !llvm::any_of(padding.getValues<APInt>(),
})) { [](APInt int_val) { return !int_val.isNullValue(); })) {
padding = nullptr; padding = nullptr;
} }
// The order of input and filter are switched with linalg.conv. // The order of input and filter are switched with linalg.conv.
rewriter.replaceOpWithNewOp<linalg::ConvOp>( rewriter.replaceOpWithNewOp<linalg::ConvOp>(
op, args[1], args[0], args[2], stridesArg, dilationArg, padding); op, args[1], args[0], args[2], strides_arg, dilation_arg, padding);
return success(); return success();
} }
}; };
@ -293,25 +293,25 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
OpTy op, ArrayRef<Value> args, OpTy op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure(); if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
auto resultType = getHloOpResultType<isLHLO>(op); auto result_type = getHloOpResultType<isLHLO>(op);
SmallVector<AffineMap, 2> indexing_maps = SmallVector<AffineMap, 2> indexing_maps =
Derived::getIndexingMaps(op, &rewriter); Derived::getIndexingMaps(op, &rewriter);
if (indexing_maps.empty()) return failure(); if (indexing_maps.empty()) return failure();
auto nloops = resultType.getRank(); auto nloops = result_type.getRank();
auto loc = op.getLoc(); auto loc = op.getLoc();
auto linalgOp = rewriter.create<linalg::GenericOp>( auto linalg_op = rewriter.create<linalg::GenericOp>(
loc, loc,
/*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : resultType, /*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : result_type,
/*inputs=*/args.front(), /*inputs=*/args.front(),
/*outputBuffers=*/isLHLO ? ValueRange{args.back()} : ValueRange{}, /*outputBuffers=*/isLHLO ? ValueRange{args.back()} : ValueRange{},
/*initTensor=*/ValueRange{}, indexing_maps, /*initTensor=*/ValueRange{}, indexing_maps,
GetNParallelLoopsAttrs(nloops), GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin()); nested_builder.create<linalg::YieldOp>(loc, *args.begin());
}); });
rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
return success(); return success();
} }
}; };
@ -325,32 +325,32 @@ class BroadcastConverter
using DataMovementOpConverter<BroadcastConverter, OpTy, using DataMovementOpConverter<BroadcastConverter, OpTy,
isLHLO>::DataMovementOpConverter; isLHLO>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcastOp, static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcast_op,
Builder* b) { Builder* b) {
ShapedType inputType = ShapedType input_type =
broadcastOp.operand().getType().template cast<ShapedType>(); broadcast_op.operand().getType().template cast<ShapedType>();
unsigned inputRank = inputType.getRank(); unsigned input_rank = input_type.getRank();
unsigned nloops = getHloOpResultType<isLHLO>(broadcastOp).getRank(); unsigned nloops = getHloOpResultType<isLHLO>(broadcast_op).getRank();
// BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
// the input's dimensions. // the input's dimensions.
unsigned numPrependedDims = llvm::size(broadcastOp.broadcast_sizes()); unsigned num_prepended_dims = llvm::size(broadcast_op.broadcast_sizes());
SmallVector<AffineExpr, 4> inputDimExprs; SmallVector<AffineExpr, 4> input_dim_exprs;
inputDimExprs.reserve(inputRank); input_dim_exprs.reserve(input_rank);
for (int i = 0; i < inputRank; ++i) { for (int i = 0; i < input_rank; ++i) {
inputDimExprs.push_back(b->getAffineDimExpr(numPrependedDims + i)); input_dim_exprs.push_back(b->getAffineDimExpr(num_prepended_dims + i));
} }
AffineMap inputMap; AffineMap input_map;
MLIRContext* context = b->getContext(); MLIRContext* context = b->getContext();
if (inputDimExprs.empty()) { if (input_dim_exprs.empty()) {
// The input is a scalar, i.e. this is a scalar broadcast op. // The input is a scalar, i.e. this is a scalar broadcast op.
inputMap = AffineMap::get(nloops, /*symbolCount=*/0, context); input_map = AffineMap::get(nloops, /*symbolCount=*/0, context);
} else { } else {
inputMap = input_map =
AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context); AffineMap::get(nloops, /*symbolCount=*/0, input_dim_exprs, context);
} }
return {inputMap, b->getMultiDimIdentityMap(nloops)}; return {input_map, b->getMultiDimIdentityMap(nloops)};
} }
}; };
@ -363,34 +363,34 @@ class HloBroadcastInDimConverter
false>::DataMovementOpConverter; false>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps( static SmallVector<AffineMap, 2> getIndexingMaps(
mhlo::BroadcastInDimOp broadcastOp, Builder* b) { mhlo::BroadcastInDimOp broadcast_op, Builder* b) {
auto resultType = getHloOpResultType<false>(broadcastOp); auto result_type = getHloOpResultType<false>(broadcast_op);
auto operandType = auto operand_type =
broadcastOp.operand().getType().template cast<ShapedType>(); broadcast_op.operand().getType().template cast<ShapedType>();
unsigned nloops = resultType.getRank(); unsigned nloops = result_type.getRank();
// The input is a scalar, i.e. this is a scalar broadcast op. // The input is a scalar, i.e. this is a scalar broadcast op.
if (operandType.getRank() == 0) { if (operand_type.getRank() == 0) {
return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()), return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
b->getMultiDimIdentityMap(nloops)}; b->getMultiDimIdentityMap(nloops)};
} }
auto operandShape = operandType.getShape(); auto operand_shape = operand_type.getShape();
SmallVector<AffineExpr, 4> dimExprs; SmallVector<AffineExpr, 4> dim_exprs;
dimExprs.reserve(nloops); dim_exprs.reserve(nloops);
if (broadcastOp.broadcast_dimensions()) { if (broadcast_op.broadcast_dimensions()) {
for (const auto& broadcastDim : for (const auto& broadcastDim :
enumerate(broadcastOp.broadcast_dimensions().getIntValues())) { enumerate(broadcast_op.broadcast_dimensions().getIntValues())) {
int size = broadcastDim.value().getSExtValue(); int size = broadcastDim.value().getSExtValue();
bool expansion_needed = operandShape[broadcastDim.index()] == 1 && bool expansion_needed = operand_shape[broadcastDim.index()] == 1 &&
resultType.getShape()[size] != 1; result_type.getShape()[size] != 1;
dimExprs.push_back(expansion_needed ? b->getAffineConstantExpr(0) dim_exprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
: b->getAffineDimExpr(size)); : b->getAffineDimExpr(size));
} }
} }
return { return {
AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()), AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)}; b->getMultiDimIdentityMap(nloops)};
} }
}; };
@ -430,8 +430,8 @@ class LhloBroadcastInDimConverter
/*outputBuffers=*/ValueRange{operand_adaptor.output()}, /*outputBuffers=*/ValueRange{operand_adaptor.output()},
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops), GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(loc, val); nested_builder.create<linalg::YieldOp>(loc, val);
}); });
} else { } else {
@ -441,8 +441,8 @@ class LhloBroadcastInDimConverter
loc, /*inputs=*/ValueRange{operand}, loc, /*inputs=*/ValueRange{operand},
/*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps, /*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps,
GetNParallelLoopsAttrs(nloops), GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin()); nested_builder.create<linalg::YieldOp>(loc, *args.begin());
}); });
} }
rewriter.replaceOp(op, llvm::None); rewriter.replaceOp(op, llvm::None);
@ -520,35 +520,35 @@ class LhloBroadcastInDimConverter
} }
SmallVector<AffineMap, 2> getIndexingMaps(lmhlo::BroadcastInDimOp op, SmallVector<AffineMap, 2> getIndexingMaps(lmhlo::BroadcastInDimOp op,
ArrayRef<int64_t> broadcastDims, ArrayRef<int64_t> broadcast_dims,
ArrayRef<int64_t> resultShape, ArrayRef<int64_t> result_shape,
MemRefType operandType, MemRefType operand_type,
Builder* b) const { Builder* b) const {
unsigned nloops = resultShape.size(); unsigned nloops = result_shape.size();
// The input is a scalar, i.e. this is a scalar broadcast op. // The input is a scalar, i.e. this is a scalar broadcast op.
if (operandType.getRank() == 0) { if (operand_type.getRank() == 0) {
return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()), return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
b->getMultiDimIdentityMap(nloops)}; b->getMultiDimIdentityMap(nloops)};
} }
auto operandShape = operandType.getShape(); auto operand_shape = operand_type.getShape();
SmallVector<AffineExpr, 4> dimExprs; SmallVector<AffineExpr, 4> dim_exprs;
dimExprs.reserve(nloops); dim_exprs.reserve(nloops);
for (const auto& broadcastDim : llvm::enumerate(broadcastDims)) { for (const auto& broadcast_dim : llvm::enumerate(broadcast_dims)) {
int size = broadcastDim.value(); int size = broadcast_dim.value();
bool expansion_needed = bool expansion_needed =
operandShape[broadcastDim.index()] == 1 && resultShape[size] != 1; operand_shape[broadcast_dim.index()] == 1 && result_shape[size] != 1;
if (expansion_needed) { if (expansion_needed) {
op.emitOpError( op.emitOpError(
"BroadcastInDimOp lowering to Linalg does not support size-1 " "BroadcastInDimOp lowering to Linalg does not support size-1 "
"dimensions expansion."); "dimensions expansion.");
} }
dimExprs.push_back(b->getAffineDimExpr(size)); dim_exprs.push_back(b->getAffineDimExpr(size));
} }
return { return {
AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()), AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)}; b->getMultiDimIdentityMap(nloops)};
} }
}; };
@ -561,17 +561,17 @@ class TransposeConverter
using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy, using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
isLHLO>::DataMovementOpConverter; isLHLO>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) { static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
auto resultType = auto result_type =
getHloOpResultType<isLHLO>(op).template cast<ShapedType>(); getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = resultType.getRank(); auto nloops = result_type.getRank();
SmallVector<AffineExpr, 2> inputExprs; SmallVector<AffineExpr, 2> input_exprs;
inputExprs.resize(resultType.getRank()); input_exprs.resize(result_type.getRank());
for (auto permutation : llvm::enumerate(op.permutation())) { for (auto permutation : llvm::enumerate(op.permutation())) {
inputExprs[permutation.value().getZExtValue()] = input_exprs[permutation.value().getZExtValue()] =
b->getAffineDimExpr(permutation.index()); b->getAffineDimExpr(permutation.index());
} }
return { return {
AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)}; b->getMultiDimIdentityMap(nloops)};
} }
}; };
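
TransposeConverter inverts the permutation when building the operand map: loop index i addresses result dimension i (identity output map) and operand dimension permutation[i]. A standalone C++ sketch of that placement, with an illustrative permutation and no MLIR dependencies:

#include <cstddef>
#include <iostream>
#include <vector>

// Sketch of the operand map built by TransposeConverter::getIndexingMaps:
// input_exprs[permutation[i]] = d_i, i.e. operand dimension permutation[i]
// is indexed by loop i while the result uses the identity map.
int main() {
  std::vector<std::size_t> permutation = {2, 0, 1};  // illustrative permutation
  std::vector<std::size_t> operand_loop(permutation.size());
  for (std::size_t i = 0; i < permutation.size(); ++i)
    operand_loop[permutation[i]] = i;

  for (std::size_t dim = 0; dim < operand_loop.size(); ++dim)
    std::cout << "operand dim " << dim << " is indexed by loop d"
              << operand_loop[dim] << "\n";
  return 0;
}
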
@ -584,101 +584,104 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern; using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
OpTy reshapeOp, ArrayRef<Value> args, OpTy reshape_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(reshapeOp)) if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
return failure(); return failure();
ShapedType operandType = ShapedType operand_type =
reshapeOp.operand().getType().template cast<ShapedType>(); reshape_op.operand().getType().template cast<ShapedType>();
ShapedType resultType = getHloOpResultType<isLHLO>(reshapeOp); ShapedType result_type = getHloOpResultType<isLHLO>(reshape_op);
if (!operandType.hasStaticShape() || !resultType.hasStaticShape()) if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
return failure(); return failure();
// Compute the reassociation maps for the linalg operation. // Compute the reassociation maps for the linalg operation.
ArrayRef<int64_t> srcShape = ArrayRef<int64_t> src_shape =
(operandType.getRank() > resultType.getRank() ? operandType.getShape() (operand_type.getRank() > result_type.getRank()
: resultType.getShape()); ? operand_type.getShape()
ArrayRef<int64_t> dstShape = : result_type.getShape());
(operandType.getRank() > resultType.getRank() ? resultType.getShape() ArrayRef<int64_t> dst_shape =
: operandType.getShape()); (operand_type.getRank() > result_type.getRank()
unsigned currSrcDim = 0, currDstDim = 0; ? result_type.getShape()
SmallVector<linalg::ReassociationExprs, 4> reassociationMap( : operand_type.getShape());
dstShape.size()); unsigned curr_src_dim = 0, curr_dst_dim = 0;
bool isExpandingOrCollapsing = true; SmallVector<linalg::ReassociationExprs, 4> reassociation_map(
while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) { dst_shape.size());
int64_t dstSize = dstShape[currDstDim]; bool is_expanding_or_collapsing = true;
int64_t srcSize = srcShape[currSrcDim]; while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) {
while (srcSize < dstSize && currSrcDim < srcShape.size()) { int64_t dst_size = dst_shape[curr_dst_dim];
reassociationMap[currDstDim].push_back( int64_t src_size = src_shape[curr_src_dim];
rewriter.getAffineDimExpr(currSrcDim++)); while (src_size < dst_size && curr_src_dim < src_shape.size()) {
srcSize *= srcShape[currSrcDim]; reassociation_map[curr_dst_dim].push_back(
rewriter.getAffineDimExpr(curr_src_dim++));
src_size *= src_shape[curr_src_dim];
} }
if (srcSize == dstSize) { if (src_size == dst_size) {
reassociationMap[currDstDim].push_back( reassociation_map[curr_dst_dim].push_back(
rewriter.getAffineDimExpr(currSrcDim++)); rewriter.getAffineDimExpr(curr_src_dim++));
// If the next dim in dstShape is not 1, treat subsequent dims in // If the next dim in dst_shape is not 1, treat subsequent dims in
// srcShape which are 1 as collapsed. // src_shape which are 1 as collapsed.
if (currDstDim == dstShape.size() - 1 || if (curr_dst_dim == dst_shape.size() - 1 ||
dstShape[currDstDim + 1] != 1) { dst_shape[curr_dst_dim + 1] != 1) {
while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) { while (curr_src_dim < src_shape.size() &&
reassociationMap[currDstDim].push_back( src_shape[curr_src_dim] == 1) {
rewriter.getAffineDimExpr(currSrcDim++)); reassociation_map[curr_dst_dim].push_back(
rewriter.getAffineDimExpr(curr_src_dim++));
} }
} }
} else { } else {
isExpandingOrCollapsing = false; is_expanding_or_collapsing = false;
break; break;
} }
currDstDim++; curr_dst_dim++;
} }
if (currSrcDim != srcShape.size() || currDstDim != dstShape.size()) if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size())
isExpandingOrCollapsing = false; is_expanding_or_collapsing = false;
if (!isExpandingOrCollapsing) { if (!is_expanding_or_collapsing) {
auto getIdentityExprs = [&rewriter](int n) { auto get_identity_exprs = [&rewriter](int n) {
SmallVector<AffineExpr, 4> exprs; SmallVector<AffineExpr, 4> exprs;
for (int i = 0; i < n; ++i) for (int i = 0; i < n; ++i)
exprs.push_back(rewriter.getAffineDimExpr(i)); exprs.push_back(rewriter.getAffineDimExpr(i));
return exprs; return exprs;
}; };
Location loc = reshapeOp.getLoc(); Location loc = reshape_op.getLoc();
int64_t totalElems = std::accumulate(srcShape.begin(), srcShape.end(), 1, int64_t total_elems = std::accumulate(src_shape.begin(), src_shape.end(),
std::multiplies<int64_t>()); 1, std::multiplies<int64_t>());
auto elemType = operandType.getElementType(); auto elem_type = operand_type.getElementType();
SmallVector<linalg::ReassociationExprs, 4> collapsingMap = { SmallVector<linalg::ReassociationExprs, 4> collapsing_map = {
getIdentityExprs(dstShape.size())}; get_identity_exprs(dst_shape.size())};
SmallVector<linalg::ReassociationExprs, 4> expandingMap = { SmallVector<linalg::ReassociationExprs, 4> expanding_map = {
getIdentityExprs(srcShape.size())}; get_identity_exprs(src_shape.size())};
if (isLHLO) { if (isLHLO) {
auto collapsedType = MemRefType::get({totalElems}, elemType); auto collapsed_type = MemRefType::get({total_elems}, elem_type);
Value collapsedOp = rewriter.create<linalg::ReshapeOp>( Value collapsed_op = rewriter.create<linalg::ReshapeOp>(
loc, collapsedType, args[0], collapsingMap); loc, collapsed_type, args[0], collapsing_map);
Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>( Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
loc, resultType, collapsedOp, expandingMap); loc, result_type, collapsed_op, expanding_map);
rewriter.replaceOpWithNewOp<linalg::CopyOp>( rewriter.replaceOpWithNewOp<linalg::CopyOp>(
reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr, reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr,
/*outputPermutation =*/nullptr); /*outputPermutation =*/nullptr);
} else { } else {
auto collapsedType = RankedTensorType::get({totalElems}, elemType); auto collapsed_type = RankedTensorType::get({total_elems}, elem_type);
Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>( Value collapsed_op = rewriter.create<linalg::TensorReshapeOp>(
loc, collapsedType, args[0], collapsingMap); loc, collapsed_type, args[0], collapsing_map);
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>( rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
reshapeOp, resultType, collapsedOp, expandingMap); reshape_op, result_type, collapsed_op, expanding_map);
} }
return success(); return success();
} }
if (isLHLO) { if (isLHLO) {
Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>( Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
reshapeOp.getLoc(), resultType, args[0], reassociationMap); reshape_op.getLoc(), result_type, args[0], reassociation_map);
rewriter.replaceOpWithNewOp<linalg::CopyOp>( rewriter.replaceOpWithNewOp<linalg::CopyOp>(
reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr, reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr,
/*outputPermutation =*/nullptr); /*outputPermutation =*/nullptr);
} else { } else {
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>( rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
reshapeOp, resultType, args[0], reassociationMap); reshape_op, result_type, args[0], reassociation_map);
} }
return success(); return success();
} }
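
The reassociation walk above groups dimensions of the larger-rank shape under each dimension of the smaller-rank shape by accumulating sizes until the products match, absorbing trailing unit dimensions, and falls back to a collapse-to-1-D-then-expand sequence when no such grouping exists. A standalone C++ sketch of the grouping loop with illustrative static shapes (no MLIR types; as in the pattern, the two shapes are assumed to hold the same number of elements):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the reassociation-map computation in ReshapeOpConverter: walk the
// higher-rank shape (src) and group its dimensions under the lower-rank shape
// (dst) while the accumulated sizes match.
int main() {
  std::vector<int64_t> src_shape = {2, 3, 4, 1};  // illustrative source shape
  std::vector<int64_t> dst_shape = {6, 4};        // illustrative destination shape

  std::vector<std::vector<int64_t>> reassociation(dst_shape.size());
  std::size_t curr_src = 0, curr_dst = 0;
  bool is_expanding_or_collapsing = true;
  while (curr_src < src_shape.size() && curr_dst < dst_shape.size()) {
    int64_t dst_size = dst_shape[curr_dst];
    int64_t src_size = src_shape[curr_src];
    while (src_size < dst_size && curr_src < src_shape.size()) {
      reassociation[curr_dst].push_back(curr_src++);
      src_size *= src_shape[curr_src];
    }
    if (src_size == dst_size) {
      reassociation[curr_dst].push_back(curr_src++);
      // Absorb trailing unit dims unless the next dst dim expects them.
      if (curr_dst == dst_shape.size() - 1 || dst_shape[curr_dst + 1] != 1) {
        while (curr_src < src_shape.size() && src_shape[curr_src] == 1)
          reassociation[curr_dst].push_back(curr_src++);
      }
    } else {
      is_expanding_or_collapsing = false;
      break;
    }
    ++curr_dst;
  }
  if (curr_src != src_shape.size() || curr_dst != dst_shape.size())
    is_expanding_or_collapsing = false;

  if (!is_expanding_or_collapsing) {
    std::cout << "fall back: collapse to 1-D, then expand\n";
    return 0;
  }
  for (std::size_t d = 0; d < reassociation.size(); ++d) {
    std::cout << "dst dim " << d << " <- src dims {";
    for (int64_t s : reassociation[d]) std::cout << " " << s;
    std::cout << " }\n";
  }
  return 0;
}

With these shapes the sketch groups src dims {0, 1} under dst dim 0 and {2, 3} under dst dim 1.
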
@ -690,42 +693,42 @@ class IotaConverter : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern; using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
OpTy iotaOp, ArrayRef<Value> args, OpTy iota_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
ShapedType resultShapedType = getHloOpResultType<isLHLO>(iotaOp); ShapedType result_shaped_type = getHloOpResultType<isLHLO>(iota_op);
if (!resultShapedType) return failure(); if (!result_shaped_type) return failure();
auto resultElementType = resultShapedType.getElementType(); auto result_element_type = result_shaped_type.getElementType();
if (!resultElementType.isSignlessIntOrFloat()) return failure(); if (!result_element_type.isSignlessIntOrFloat()) return failure();
// Construct the indexing maps needed for linalg.generic ops. // Construct the indexing maps needed for linalg.generic ops.
unsigned nloops = resultShapedType.getRank(); unsigned nloops = result_shaped_type.getRank();
auto linalgOp = rewriter.create<linalg::IndexedGenericOp>( auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
iotaOp.getLoc(), iota_op.getLoc(),
/*resultTensorTypes=*/ /*resultTensorTypes=*/
isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{resultShapedType}, isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
/*inputs=*/ValueRange{}, /*inputs=*/ValueRange{},
/*outputBuffers=*/isLHLO ? ValueRange{args} : ValueRange{}, /*outputBuffers=*/isLHLO ? ValueRange{args} : ValueRange{},
/*initTensors=*/ValueRange{}, /*initTensors=*/ValueRange{},
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops), GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs, [&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs,
ValueRange args) { ValueRange args) {
Value castOp = nestedBuilder.create<IndexCastOp>( Value cast_op = nested_builder.create<IndexCastOp>(
nestedLoc, ivs[iotaOp.iota_dimension()], nested_loc, ivs[iota_op.iota_dimension()],
nestedBuilder.getIntegerType( nested_builder.getIntegerType(
resultElementType.getIntOrFloatBitWidth())); result_element_type.getIntOrFloatBitWidth()));
if (resultElementType.template isa<FloatType>()) { if (result_element_type.template isa<FloatType>()) {
castOp = nestedBuilder.create<SIToFPOp>(nestedLoc, castOp, cast_op = nested_builder.create<SIToFPOp>(nested_loc, cast_op,
resultElementType); result_element_type);
} }
nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp); nested_builder.create<linalg::YieldOp>(nested_loc, cast_op);
}); });
if (isLHLO) if (isLHLO)
rewriter.replaceOp(iotaOp, llvm::None); rewriter.replaceOp(iota_op, llvm::None);
else else
rewriter.replaceOp(iotaOp, linalgOp.result_tensors()); rewriter.replaceOp(iota_op, linalg_op.result_tensors());
return success(); return success();
} }
}; };
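
The iota lowering above produces an indexed generic whose body ignores the output argument and yields the induction variable of the iota dimension, cast to the result element type (with an extra signed-to-float cast for float results). A standalone C++ sketch of the value each output element receives, for an illustrative 2x3 float result:

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of what the IotaConverter region computes: every output element is
// the loop index along iota_dimension, cast to the element type.
int main() {
  const int64_t rows = 2, cols = 3;   // illustrative result shape
  const int iota_dimension = 1;       // illustrative: count along columns

  std::vector<float> result(rows * cols);
  for (int64_t i = 0; i < rows; ++i)
    for (int64_t j = 0; j < cols; ++j) {
      int64_t ivs[] = {i, j};
      // IndexCastOp followed by SIToFPOp in the pattern; a plain cast here.
      result[i * cols + j] = static_cast<float>(ivs[iota_dimension]);
    }

  for (int64_t i = 0; i < rows; ++i) {
    for (int64_t j = 0; j < cols; ++j) std::cout << result[i * cols + j] << " ";
    std::cout << "\n";  // prints: 0 1 2 on each row
  }
  return 0;
}
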
@ -735,16 +738,106 @@ class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
using OpConversionPattern<lmhlo::ConstOp>::OpConversionPattern; using OpConversionPattern<lmhlo::ConstOp>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
lmhlo::ConstOp constOp, ArrayRef<Value> args, lmhlo::ConstOp const_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto loc = constOp.getLoc(); auto loc = const_op.getLoc();
auto valueAttr = constOp.value().cast<DenseElementsAttr>(); auto value_attr = const_op.value().cast<DenseElementsAttr>();
if (valueAttr.getType().getRank() != 0) return failure(); if (value_attr.getType().getRank() != 0) return failure();
auto stdConstOp = auto std_const_op =
rewriter.create<mlir::ConstantOp>(loc, valueAttr.getValue({})); rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
rewriter.create<mlir::AffineStoreOp>(loc, stdConstOp, constOp.getOperand(), rewriter.create<mlir::AffineStoreOp>(loc, std_const_op,
ValueRange()); const_op.getOperand(), ValueRange());
rewriter.eraseOp(constOp); rewriter.eraseOp(const_op);
return success();
}
};
class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
public:
using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
lmhlo::ReduceOp reduce_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = reduce_op.getLoc();
lmhlo::ReduceOp::Adaptor adaptor(args);
auto operand_shape =
adaptor.operands()[0].getType().template dyn_cast<ShapedType>();
if (!operand_shape || !operand_shape.hasRank()) {
emitError(loc, "lhlo to linalg conversion expects known-rank args");
return failure();
}
// First fill the output buffer with the init value.
Value init_value = rewriter.create<LoadOp>(loc, adaptor.init_values()[0]);
rewriter.create<linalg::FillOp>(loc, adaptor.out()[0], init_value);
DenseIntElementsAttr dimensions_attr = reduce_op.dimensions();
SmallVector<int, 4> reduction_dims;
for (const auto& dim : dimensions_attr.getIntValues()) {
reduction_dims.push_back(dim.getSExtValue());
}
SmallVector<AffineExpr, 2> src_exprs;
SmallVector<AffineExpr, 2> dst_exprs;
SmallVector<StringRef, 4> types;
for (int i = 0, rank = operand_shape.getRank(); i != rank; ++i) {
bool is_reduced = llvm::is_contained(reduction_dims, i);
types.push_back(is_reduced ? getReductionIteratorTypeName()
: getParallelIteratorTypeName());
src_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
if (!is_reduced) {
dst_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
}
}
auto maps = AffineMap::inferFromExprList({src_exprs, dst_exprs});
auto linalg_op = rewriter.create<linalg::GenericOp>(
loc, /*resultTensorTypes=*/ArrayRef<Type>{},
/*inputs=*/adaptor.operands(), /*outputBuffers=*/adaptor.out(),
/*initTensors=*/ValueRange{}, maps, types);
linalg_op.region().takeBody(reduce_op.body());
{
OpBuilder::InsertionGuard region_guard(rewriter);
Block* block = linalg_op.getBody();
rewriter.setInsertionPoint(&block->front());
// The incoming region is operating on buffers, while linalg.generic
// expects scalar SSA values. Add some allocs around the original op to
// make it compatible.
auto arg_type = block->getArgument(0).getType().cast<MemRefType>();
Value alloc_a = rewriter.create<AllocaOp>(loc, arg_type);
Value alloc_b = rewriter.create<AllocaOp>(loc, arg_type);
Value alloc_res = rewriter.create<AllocaOp>(loc, arg_type);
// Now turn the existing signature
// (memref<X>, memref<X>, memref<X>) -> ()
// into
// (X, X) -> X
TypeConverter::SignatureConversion signature_converter(3);
signature_converter.remapInput(0, alloc_a);
signature_converter.remapInput(1, alloc_b);
signature_converter.remapInput(2, alloc_res);
signature_converter.addInputs(
{arg_type.getElementType(), arg_type.getElementType()});
Block* entry_block = rewriter.applySignatureConversion(
&linalg_op.region(), signature_converter);
// Store the arguments into the newly allocated buffers.
rewriter.setInsertionPointAfter(alloc_res.getDefiningOp());
rewriter.create<StoreOp>(loc, entry_block->getArgument(0), alloc_a);
rewriter.create<StoreOp>(loc, entry_block->getArgument(1), alloc_b);
rewriter.replaceOp(entry_block->getTerminator(), {});
// Load & yield the result.
rewriter.setInsertionPointToEnd(entry_block);
auto load_res = rewriter.create<LoadOp>(loc, alloc_res);
rewriter.create<linalg::YieldOp>(loc, ValueRange{load_res});
}
rewriter.replaceOp(reduce_op, linalg_op.getOperation()->getResults());
return success(); return success();
} }
}; };
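
The new ReduceConverter builds its two indexing maps from the iterator list: the input map covers every loop while the output map keeps only the parallel (non-reduced) loops, and the output buffer is first filled with the init value so the region can accumulate in place. A standalone C++ sketch of that iteration structure, using illustrative shapes and addition as the combiner:

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the loop structure ReduceConverter sets up: initialise the output
// with the init value, then combine every input element into the output slot
// addressed by the non-reduced loop indices only.
int main() {
  const int64_t d0 = 2, d1 = 3;                   // illustrative input shape
  std::vector<int> input = {1, 2, 3, 4, 5, 6};    // row-major 2x3 input
  const int init_value = 0;                       // illustrative init value
  // reduction_dims = {1}: dim 1 is a reduction loop, dim 0 stays parallel.

  std::vector<int> output(d0, init_value);        // the linalg::FillOp step
  for (int64_t i = 0; i < d0; ++i)
    for (int64_t j = 0; j < d1; ++j)
      output[i] += input[i * d1 + j];             // region body: addition

  for (int v : output) std::cout << v << " ";
  std::cout << "\n";  // prints: 6 15
  return 0;
}
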
@ -758,21 +851,21 @@ class ReverseConverter
using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy, using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
isLHLO>::DataMovementOpConverter; isLHLO>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) { static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
auto resultType = auto result_type =
getHloOpResultType<isLHLO>(op).template cast<ShapedType>(); getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = resultType.getRank(); auto nloops = result_type.getRank();
SmallVector<AffineExpr, 2> inputExprs; SmallVector<AffineExpr, 2> input_exprs;
inputExprs.reserve(nloops); input_exprs.reserve(nloops);
for (int i = 0; i < nloops; ++i) for (int i = 0; i < nloops; ++i)
inputExprs.push_back(b->getAffineDimExpr(i)); input_exprs.push_back(b->getAffineDimExpr(i));
for (auto dim : op.dimensions()) { for (auto dim : op.dimensions()) {
int i = dim.getZExtValue(); int i = dim.getZExtValue();
if (resultType.isDynamicDim(i)) return {}; if (result_type.isDynamicDim(i)) return {};
int n = resultType.getShape()[i]; int n = result_type.getShape()[i];
inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i]; input_exprs[i] = b->getAffineConstantExpr(n - 1) - input_exprs[i];
} }
return { return {
AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)}; b->getMultiDimIdentityMap(nloops)};
} }
}; };
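
ReverseConverter rewrites the input expression of every reversed dimension to (n - 1) - i, which is why it bails out on dynamic dimensions: n must be a static size. A standalone C++ sketch of that index arithmetic on an illustrative 1-D input:

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the ReverseConverter indexing: output index i reads the input at
// (n - 1 - i) along each reversed dimension.
int main() {
  std::vector<int> input = {10, 20, 30, 40};  // illustrative input
  const int64_t n = static_cast<int64_t>(input.size());

  std::vector<int> output(n);
  for (int64_t i = 0; i < n; ++i) output[i] = input[n - 1 - i];

  for (int v : output) std::cout << v << " ";
  std::cout << "\n";  // prints: 40 30 20 10
  return 0;
}
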
@ -782,31 +875,31 @@ class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern; using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
lmhlo::SliceOp sliceOp, ArrayRef<Value> args, lmhlo::SliceOp slice_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto loc = sliceOp.getLoc(); auto loc = slice_op.getLoc();
auto argType = auto arg_type =
sliceOp.getOperand(0).getType().template dyn_cast<ShapedType>(); slice_op.getOperand(0).getType().template dyn_cast<ShapedType>();
if (!argType || !argType.hasRank()) { if (!arg_type || !arg_type.hasRank()) {
emitError(loc, "lhlo to linalg conversion expects known-rank args"); emitError(loc, "lhlo to linalg conversion expects known-rank args");
return failure(); return failure();
} }
SmallVector<Value, 3> ranges; SmallVector<Value, 3> ranges;
for (int i = 0, e = argType.getRank(); i < e; ++i) { for (int i = 0, e = arg_type.getRank(); i < e; ++i) {
Value start_index = rewriter.create<ConstantIndexOp>( Value start_index = rewriter.create<ConstantIndexOp>(
loc, sliceOp.start_indices().getValue<int64_t>(i)); loc, slice_op.start_indices().getValue<int64_t>(i));
Value limit_index = rewriter.create<ConstantIndexOp>( Value limit_index = rewriter.create<ConstantIndexOp>(
loc, sliceOp.limit_indices().getValue<int64_t>(i)); loc, slice_op.limit_indices().getValue<int64_t>(i));
Value stride = rewriter.create<ConstantIndexOp>( Value stride = rewriter.create<ConstantIndexOp>(
loc, sliceOp.strides().getValue<int64_t>(i)); loc, slice_op.strides().getValue<int64_t>(i));
ranges.push_back(rewriter.create<linalg::RangeOp>(loc, start_index, ranges.push_back(rewriter.create<linalg::RangeOp>(loc, start_index,
limit_index, stride)); limit_index, stride));
} }
auto linalg_slice = auto linalg_slice =
rewriter.create<linalg::SliceOp>(loc, sliceOp.getOperand(0), ranges); rewriter.create<linalg::SliceOp>(loc, slice_op.getOperand(0), ranges);
rewriter.create<linalg::CopyOp>(loc, linalg_slice, sliceOp.getOperand(1)); rewriter.create<linalg::CopyOp>(loc, linalg_slice, slice_op.getOperand(1));
rewriter.eraseOp(sliceOp); rewriter.eraseOp(slice_op);
return success(); return success();
} }
}; };
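
SliceConverter turns the static start/limit/stride attributes into one linalg.range per dimension, takes a linalg.slice of the operand, and copies the view into the output buffer. A standalone C++ sketch of the elements one such range selects, with illustrative attribute values:

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the per-dimension start/limit/stride selection that
// SliceConverter encodes as linalg.range + linalg.slice + linalg.copy.
int main() {
  std::vector<int> operand = {0, 1, 2, 3, 4, 5, 6, 7};  // illustrative operand
  const int64_t start = 1, limit = 7, stride = 2;       // illustrative attrs

  std::vector<int> result;
  for (int64_t i = start; i < limit; i += stride) result.push_back(operand[i]);

  for (int v : result) std::cout << v << " ";
  std::cout << "\n";  // prints: 1 3 5
  return 0;
}
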
@ -850,9 +943,11 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<lmhlo::SubOp>, PointwiseToLinalgConverter<lmhlo::SubOp>,
PointwiseToLinalgConverter<lmhlo::TanhOp>, PointwiseToLinalgConverter<lmhlo::TanhOp>,
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>, PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
ReduceConverter,
ReshapeOpConverter<lmhlo::ReshapeOp>, ReshapeOpConverter<lmhlo::ReshapeOp>,
ReverseConverter<lmhlo::ReverseOp>, ReverseConverter<lmhlo::ReverseOp>,
ScalarPointwiseToStandardConverter<lmhlo::AddOp>, ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
ScalarPointwiseToStandardConverter<lmhlo::MaxOp>,
SliceConverter, SliceConverter,
TransposeConverter<lmhlo::TransposeOp> TransposeConverter<lmhlo::TransposeOp>
>(context); >(context);
@ -890,7 +985,7 @@ struct LhloLegalizeToLinalgPass
auto func = getFunction(); auto func = getFunction();
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns); populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
if (failed(applyPartialConversion(func, target, patterns, nullptr))) { if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
signalPassFailure(); signalPassFailure();
} }
} }
@ -909,7 +1004,7 @@ struct HloLegalizeToLinalgPass
auto func = getFunction(); auto func = getFunction();
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns); mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
if (failed(applyPartialConversion(func, target, patterns, nullptr))) { if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
signalPassFailure(); signalPassFailure();
} }
} }

View File

@ -21,8 +21,8 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Function.h" #include "mlir/IR/Function.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir { namespace mlir {
namespace { namespace {
@ -201,7 +201,7 @@ void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
void LegalizeToStandardPass::runOnFunction() { void LegalizeToStandardPass::runOnFunction() {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext()); mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext());
applyPatternsAndFoldGreedily(getFunction(), patterns); applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
} }
} // end namespace mhlo } // end namespace mhlo

Some files were not shown because too many files have changed in this diff.