fix merge conflict

Zhoulong Jiang 2020-10-23 16:05:47 +00:00
commit 27dd6a502c
1783 changed files with 49420 additions and 19360 deletions

View File

@ -159,6 +159,7 @@ build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain
# environment variable "TF_MKL_ROOT" every time before build.
build:mkl --define=build_with_mkl=true --define=enable_mkl=true
build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl --define=build_with_openmp=true
build:mkl -c opt
# config to build OneDNN backend with a user specified threadpool.
@ -172,6 +173,7 @@ build:mkl_threadpool -c opt
build:mkl_opensource_only --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_opensource_only --define=build_with_mkl_opensource=true
build:mkl_opensource_only --define=build_with_openmp=true
build:mkl_opensource_only -c opt
# Config setting to build with oneDNN for Arm.
@ -283,7 +285,7 @@ build:ios --copt=-w
build:linux --copt=-w
build:linux --host_copt=-w
build:macos --copt=-w
-build:windows --copt=/w
+build:windows --copt=/W0
# Tensorflow uses M_* math constants that only get defined by MSVC headers if
# _USE_MATH_DEFINES is defined.
@ -294,9 +296,11 @@ build:windows --host_copt=/D_USE_MATH_DEFINES
build:linux --define=PREFIX=/usr
build:linux --define=LIBDIR=$(PREFIX)/lib
build:linux --define=INCLUDEDIR=$(PREFIX)/include
build:linux --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include
build:macos --define=PREFIX=/usr
build:macos --define=LIBDIR=$(PREFIX)/lib
build:macos --define=INCLUDEDIR=$(PREFIX)/include
build:macos --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include
# TF_SYSTEM_LIBS do not work on windows.
# By default, build TF in C++ 14 mode.

View File

@ -103,23 +103,22 @@ open-source software development:
### Official Builds
Build Type | Status | Artifacts
------------------------------ | ------ | ---------
**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
**Libtensorflow MacOS CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
**Libtensorflow Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
**Libtensorflow Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
**Libtensorflow Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
**Libtensorflow Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
### Community Supported Builds
@ -151,6 +150,7 @@ Build Type
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
* [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp)
* [TensorFlow Codelabs](https://codelabs.developers.google.com/?cat=TensorFlow)
* [TensorFlow Chat Room on StackOverflow (not actively monitored by the
TensorFlow team)](https://chat.stackoverflow.com/rooms/216694/tensorflow)
* [TensorFlow Blog](https://blog.tensorflow.org)

View File

@ -1,3 +1,35 @@
# Release 2.5.0
<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
## Breaking Changes
* <DOCUMENT BREAKING CHANGES HERE>
* <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>
## Known Caveats
* <CAVEATS REGARDING THE RELEASE (BUT NOT BREAKING CHANGES).>
* <ADDING/BUMPING DEPENDENCIES SHOULD GO HERE>
* <KNOWN LACK OF SUPPORT ON SOME PLATFORM, SHOULD GO HERE>
## Major Features and Improvements
* <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
## Bug Fixes and Other Changes
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
<INSERT>, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
# Release 2.4.0
<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
@ -6,6 +38,15 @@
* <DOCUMENT BREAKING CHANGES HERE>
* <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>
* Certain float32 ops run in lower precision on Ampere based GPUs, including
matmuls and convolutions, due to the use of
[TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/).
Specifically, inputs to such ops are rounded from 23 bits of precision to 10
bits of precision. This is unlikely to cause issues in practice for deep
learning models. In some cases, TensorFloat-32 is also used for complex64 ops.
TensorFloat-32 can be disabled by running
`config.experimental.enable_tensor_float_32_execution(False)`. The "Major
Features and Improvements" section has more details.
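  A minimal sketch of opting back into full float32 precision (the function name is taken from the "Major Features and Improvements" entry below; the setting only has an effect on Ampere GPUs):

  ```python
  import tensorflow as tf

  # Disable TensorFloat-32 so float32 matmuls/convolutions keep 23 bits of mantissa.
  tf.config.experimental.enable_tensor_float_32_execution(False)

  x = tf.random.normal([1024, 1024])
  y = tf.matmul(x, x)  # runs in full float32 precision again
  ```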
* The byte layout for string tensors across the C-API has been updated to match
TF Core/C++; i.e., a contiguous array of `tensorflow::tstring`/`TF_TString`s.
* C-API functions `TF_StringDecode`, `TF_StringEncode`, and
@ -54,6 +95,42 @@
tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...).
* `tf.distribute.Strategy.experimental_make_numpy_dataset` is removed. Please
use `tf.data.Dataset.from_tensor_slices` instead.
* `experimental_hints` in `tf.distribute.StrategyExtended.reduce_to`,
`tf.distribute.StrategyExtended.batch_reduce_to`,
`tf.distribute.ReplicaContext.all_reduce` are renamed to `options`.
`tf.distribute.experimental.CollectiveHints` is renamed
`tf.distribute.experimental.CommunicationOptions`.
`tf.distribute.experimental.CollectiveCommunication` is renamed
`tf.distribute.experimental.CommunicationImplementation`.
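  A sketch of the renamed keyword, assuming a single-machine `MirroredStrategy` (the constructor arguments shown for `CommunicationOptions` are illustrative):

  ```python
  import tensorflow as tf

  strategy = tf.distribute.MirroredStrategy()
  options = tf.distribute.experimental.CommunicationOptions(
      implementation=tf.distribute.experimental.CommunicationImplementation.AUTO)

  def step():
      ctx = tf.distribute.get_replica_context()
      # The keyword is now `options` rather than `experimental_hints`.
      return ctx.all_reduce(tf.distribute.ReduceOp.SUM, tf.constant(1.0),
                            options=options)

  print(strategy.run(step))
  ```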
* `tf.keras.mixed_precision.experimental`:
* `AutoCastVariable.dtype` now refers to the actual variable dtype, not the
dtype it will be casted to.
* When mixed precision is enabled, `tf.keras.layers.Embedding` now outputs a
float16 or bfloat16 tensor instead of a float32 tensor.
* The property
`tf.keras.mixed_precision.experimental.LossScaleOptimizer.loss_scale` is now
a tensor, not a `LossScale` object. This means to get a loss scale of a
`LossScaleOptimizer` as a tensor, you must now call `opt.loss_scale` instead
of `opt.loss_scale()`.
* The property `should_cast_variables` has been removed from
`tf.keras.mixed_precision.experimental.Policy`
* When passing a `tf.mixed_precision.experimental.DynamicLossScale` to
`tf.keras.mixed_precision.experimental.LossScaleOptimizer`, the
`DynamicLossScale`'s multiplier must be 2.
* When passing a `tf.mixed_precision.experimental.DynamicLossScale` to
`tf.keras.mixed_precision.experimental.LossScaleOptimizer`, the weights of
the `DynamicLossScale` are copied into the `LossScaleOptimizer` instead of
being reused. This means modifying the weights of the `DynamicLossScale`
will no longer affect the weights of the LossScaleOptimizer, and vice versa.
* The global policy can no longer be set to a non-floating point policy in
`tf.keras.mixed_precision.experimental.set_policy`
* In `Layer.call`, `AutoCastVariable`s will no longer be casted within
`MirroredStrategy.run` or `ReplicaContext.merge_call`. This is because a
thread local variable is used to determine whether `AutoCastVariable`s are
casted, and those two functions run with a different thread. Note this only
applies if one of these two functions is called within `Layer.call`; if one
of those two functions calls `Layer.call`, `AutoCastVariable`s will still be
casted.
## Known Caveats
@ -65,9 +142,40 @@
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
* A new module named `tf.experimental.numpy` is added, which is a NumPy-compatible API for writing TF programs. This module provides class `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are provided. Their inter-operation with TF facilities is seamless in most cases. See [tensorflow/python/ops/numpy_ops/README.md](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/numpy_ops/README.md) for details of what operations are supported and what are the differences from NumPy.
* A major refactoring of the internals of the Keras Functional API has been completed, that should improve the reliability, stability, and performance of constructing Functional models.
* Support for
[TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/)
on Ampere based GPUs has been added. TensorFloat-32, or TF32 for short, is a
math mode for NVIDIA Ampere GPUs which causes certain float32 ops, such as
matrix multiplications and convolutions, to run much faster on Ampere GPUs but
with reduced precision. This reduced precision has not been found to affect
convergence quality of deep learning models in practice. TensorFloat-32 is
enabled by default, but can be disabled with
`tf.config.experimental.enable_tensor_float_32_execution`.
* `tf.distribute`:
* `MultiWorkerMirroredStrategy` is graduated out of experimental.
* Peer failure will no longer cause the cluster to hang.
* Major issues with saving are fixed.
* See [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for a tutorial.
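  A small illustration of the graduated symbol (single worker shown; multi-worker use still needs a `TF_CONFIG` cluster spec):

  ```python
  import tensorflow as tf

  # No longer under tf.distribute.experimental.
  strategy = tf.distribute.MultiWorkerMirroredStrategy()
  with strategy.scope():
      model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
      model.compile(optimizer="sgd", loss="mse")
  ```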
* Deprecated `experimental_distribute_datasets_from_function` method and renamed it to `distribute_datasets_from_function` as it is no longer experimental.
* The `tf.keras.mixed_precision` API has been made non-experimental. The major
changes to the new non-experimental API are:
* `tf.keras.mixed_precision.Policy` no longer takes in a
`tf.mixed_precision.experimental.LossScale` in the constructor, and no
longer has a `LossScale` associated with it. Instead, `Model.compile` will
automatically wrap the optimizer with a `LossScaleOptimizer` using dynamic
loss scaling if `Policy.name` is "mixed_float16".
* `tf.keras.mixed_precision.LossScaleOptimizer`'s constructor takes in
different arguments. In particular, it no longer takes in a `LossScale`, and
there is no longer a `LossScale` associated with the `LossScaleOptimizer`.
Instead, `LossScaleOptimizer` directly implements fixed or dynamic loss
scaling. See the documentation of
`tf.keras.mixed_precision.experimental.LossScaleOptimizer` for details on
the differences between the experimental `LossScaleOptimizer` and the new
non-experimental `LossScaleOptimizer`.
* `tf.mixed_precision.experimental.LossScale` and its subclasses are
deprecated, as all of its functionality now exists within
`tf.keras.mixed_precision.LossScaleOptimizer`
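  A short sketch of the non-experimental flow described above (layer sizes are arbitrary; `set_global_policy` is assumed to be the new entry point replacing `experimental.set_policy`):

  ```python
  import tensorflow as tf

  # With the "mixed_float16" policy, Model.compile wraps the optimizer in a
  # LossScaleOptimizer with dynamic loss scaling automatically.
  tf.keras.mixed_precision.set_global_policy("mixed_float16")

  model = tf.keras.Sequential([
      tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
      tf.keras.layers.Dense(1, dtype="float32"),  # keep the output in float32
  ])
  model.compile(optimizer="adam", loss="mse")
  print(type(model.optimizer))  # a LossScaleOptimizer wrapping Adam
  ```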
## Bug Fixes and Other Changes
@ -117,6 +225,10 @@
([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212),
[CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
[CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
* Fixes a segfault in `tf.quantization.quantize_and_dequantize`
([CVE-2020-15265](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15265))
* Fixes an undefined behavior float cast causing a crash
([CVE-2020-15266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15266))
* TF Core:
* `tf.types.experimental.TensorLike` is a new `Union` type that can be
used as type annotation for variables representing a Tensor or a value
@ -138,6 +250,8 @@
stateful ops.
* Added `tf.config.experimental.get_memory_usage` to return total memory
usage of the device.
* Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`.
* Improve shape inference of nested function calls by supporting constant folding across Arg nodes which makes more static values available to shape inference functions.
* `tf.data`:
* tf.data service:
* Added new `tf.data.experimental.service.register_dataset` and
@ -182,7 +296,16 @@
how many times the function is called, and independent of global seed
settings.
* `tf.distribute`:
-* <ADD RELEASE NOTES HERE>
+* (Experimental) Parameter server training:
* Replaced the existing
`tf.distribute.experimental.ParameterServerStrategy` symbol with
a new class that is for parameter server training in TF2. Usage with
the old symbol, usually with Estimator, should be replaced with
`tf.compat.v1.distribute.experimental.ParameterServerStrategy`.
* Added `tf.distribute.experimental.coordinator.*` namespace,
including the main API `ClusterCoordinator` for coordinating the
training cluster, the related data structure `RemoteValue`
and `PerWorkerValue`.
* `tf.keras`:
* Improvements from the functional API refactoring:
* Functional model construction does not need to maintain a global
@ -217,6 +340,8 @@
* Improvements to Keras preprocessing layers:
* TextVectorization can now accept a vocabulary list or file as an
init arg.
* TextVectorization, StringLookup, and IntegerLookup can now accept a
vocabulary file via the `set_vocab_from_file` method.
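  A minimal sketch of the init-arg form (the `vocabulary` keyword and the sample values are illustrative):

  ```python
  import tensorflow as tf
  from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

  # Vocabulary supplied at construction time instead of calling adapt().
  vectorizer = TextVectorization(vocabulary=["cat", "dog", "bird"])
  print(vectorizer(tf.constant([["the cat chased the dog"]])))
  ```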
* Normalization can now accept mean and variance values as init args.
* In `Attention` and `AdditiveAttention` layers, the `call()` method now
accepts a `return_attention_scores` argument. When set to
@ -224,6 +349,15 @@
argument.
* Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints
with the same implementation as their `tf.losses` equivalent.
* For Keras models, an individual call to `Model.evaluate` does not use
cached data, while `Model.fit` caches data when the `validation_data` arg
is provided, for better performance.
* Added a `save_traces` argument to `model.save`/
`tf.keras.models.save_model` which determines whether the SavedModel
format stores the Keras model/layer call functions. The traced functions
allow Keras to revive custom models and layers without the original
class definition, but if this isn't required the tracing can be
disabled with the added option.
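  For example (the save path is illustrative):

  ```python
  import tensorflow as tf

  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
  # Skip serializing traced call functions; reloading custom layers then
  # requires their original class definitions to be importable.
  model.save("/tmp/model_without_traces", save_traces=False)
  reloaded = tf.keras.models.load_model("/tmp/model_without_traces")
  ```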
* `tf.function` / AutoGraph:
* Added `experimental_follow_type_hints` argument for `tf.function`. When
True, the function may use type annotations to optimize the tracing
@ -269,6 +403,7 @@
`TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
* `DynamicBuffer::AddJoinedString()` will now add a separator if the first
string to be joined is empty.
* Added support for cumulative sum (cumsum), both as builtin op and MLIR conversion.
* <ADD RELEASE NOTES HERE>
* `tf.random`:
@ -277,7 +412,7 @@
* Math and Linear Algebra:
-* <ADD RELEASE NOTES HERE>
+* Add `tf.math.erfcinv`, the inverse to `tf.math.erfc`.
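  A quick round-trip check of the new op (values are arbitrary):

  ```python
  import tensorflow as tf

  x = tf.constant([0.25, 0.5, 1.0])
  # erfcinv is the inverse of erfc, so the round trip recovers x (up to float error).
  print(tf.math.erfcinv(tf.math.erfc(x)))  # ~[0.25, 0.5, 1.0]
  ```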
* TPU Enhancements:
@ -323,6 +458,12 @@
didn't have the keys sorted, the keys and values were not being printed
in accordance with their correct mapping.
* `TensorRT`
* We now issue a warning when the `session_config` parameter for the TF1
converter is used or the `rewrite_config_template` field in the TF2
converter parameter object is used.
* Other:
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
@ -331,6 +472,8 @@
context.
* Add `tf.config.experimental.mlir_bridge_rollout` which will help us
rollout the new MLIR TPU bridge.
* Added `tf.experimental.register_filesystem_plugin` to load modular
filesystem plugins from Python
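  A sketch of loading such a plugin (the shared-object path is hypothetical):

  ```python
  import tensorflow as tf

  # Loads a modular filesystem plugin built against the C filesystem plugin API.
  tf.experimental.register_filesystem_plugin("/path/to/libmy_filesystem_plugin.so")
  ```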
* <ADD RELEASE NOTES HERE>
## Thanks to our Contributors
@ -703,6 +846,7 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
* Add `tf.saved_model.LoadOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/LoadOptions?hl=en) as arg with default value `None` to choose the I/O device for loading models and weights.
* Update `tf.saved_model.SaveOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/SaveOptions?hl=en) as arg with default value `None` to choose the I/O device for saving models and weights.
* Mutable tables now restore checkpointed values when loaded from SavedModel.
* The user object metadata field in the SavedModel proto has been deprecated as part of the updates to Keras SavedModel. Keras was the only consumer of this field prior to the update.
* GPU
* TF 2.3 includes PTX kernels only for [compute capability](https://developer.nvidia.com/cuda-gpus) 7.0 to reduce the TF pip binary size. Earlier releases included PTX for a variety of older compute capabilities.
* Remove environmental variable `TF_USE_CUDNN`.
@ -731,6 +875,7 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
* Fix the issue that `strategy.reduce()` inside `tf.function` may raise exceptions when the values to reduce are from loops or if-clauses.
* Fix the issue that `tf.distribute.MirroredStrategy` cannot be used together with `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
* Add a `tf.distribute.cluster_resolver.TPUClusterResolver.connect` API to simplify TPU initialization.
* Add `tf.distribute.Strategy.gather` and `tf.distribute.ReplicaContext.all_gather` methods to gather and concatenate `tf.distribute.DistributedValues` across workers and devices.
### `tf.keras`:
* Introduces experimental preprocessing layers API (`tf.keras.layers.experimental.preprocessing`) to handle data preprocessing operations such as categorical feature encoding, text vectorization, data normalization, and data discretization (binning). The newly added layers provide a replacement for the legacy feature column API, and support composite tensor inputs.

View File

@ -1163,12 +1163,9 @@ def set_system_libs_flag(environ_cp):
  syslibs = ','.join(sorted(syslibs.split()))
  write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs)
-  if 'PREFIX' in environ_cp:
-    write_to_bazelrc('build --define=PREFIX=%s' % environ_cp['PREFIX'])
-  if 'LIBDIR' in environ_cp:
-    write_to_bazelrc('build --define=LIBDIR=%s' % environ_cp['LIBDIR'])
-  if 'INCLUDEDIR' in environ_cp:
-    write_to_bazelrc('build --define=INCLUDEDIR=%s' % environ_cp['INCLUDEDIR'])
+  for varname in ('PREFIX', 'LIBDIR', 'INCLUDEDIR', 'PROTOBUF_INCLUDE_PATH'):
+    if varname in environ_cp:
+      write_to_bazelrc('build --define=%s=%s' % (varname, environ_cp[varname]))
def is_reduced_optimize_huge_functions_available(environ_cp):

View File

@ -3,6 +3,7 @@
# learning applications.
load("@bazel_skylib//lib:selects.bzl", "selects")
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
load(
"//tensorflow/core/platform:build_config.bzl",
@ -238,6 +239,12 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "linux_mips64",
values = {"cpu": "mips64"},
visibility = ["//visibility:public"],
)
config_setting(
name = "debug",
values = {
@ -563,18 +570,45 @@ selects.config_setting_group(
],
)
# 'enable_registration_v2' opts-in to a different implementation of op and
# kernel registration - REGISTER_OP, REGISTER_KERNEL_BUILDER, etc.
#
# This setting is currently experimental. The 'v2' implementation does _not_
# correspond to a particular, finalized design; rather, it relates to
# developing one.
#
# The current aim of the 'v2' implementation is to allow 'unused' ops and
# kernels to be discarded by the linker (to the benefit of binary size).
bool_flag(
name = "enable_registration_v2",
build_setting_default = False,
visibility = ["//visibility:public"],
)
config_setting(
name = "registration_v1",
flag_values = {":enable_registration_v2": "False"},
visibility = ["//visibility:public"],
)
config_setting(
name = "registration_v2",
flag_values = {":enable_registration_v2": "True"},
visibility = ["//visibility:public"],
)
# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
# Instead, please use public APIs or public build rules TF provides.
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
package_group(
    name = "internal",
-    packages = ["//tensorflow/..."],
+    packages = [
+        "//learning/lib/ami/simple_ml/...",
+        "//tensorflow/...",
+    ],
)
-package_group(
-    name = "ndarray_tensor_allow_list",
-    packages = ["//learning/pathways/..."],
-)
+package_group(name = "ndarray_tensor_allow_list")
# Packages that use private types symbols, until they are exported.
# TODO(b/154650521) Remove.
@ -606,6 +640,7 @@ bzl_library(
"//third_party/mkl:build_defs_bzl",
"//third_party/mkl_dnn:build_defs_bzl",
"//third_party/ngraph:build_defs_bzl",
"@bazel_skylib//rules:common_settings",
"@local_config_cuda//cuda:build_defs_bzl",
"@local_config_rocm//rocm:build_defs_bzl",
"@local_config_tensorrt//:build_defs_bzl",
@ -706,6 +741,9 @@ tf_cc_shared_object(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
"//tensorflow/c:kernels_hdrs",
"//tensorflow/c:ops_hdrs",
"//tensorflow/cc/saved_model:loader_lite_impl",
"//tensorflow/core/common_runtime:core_cpu_impl",
"//tensorflow/core:framework_internal_impl",

View File

@ -202,6 +202,7 @@ tf_cuda_library(
":tf_status",
":tf_tensor",
"@com_google_absl//absl/strings",
"//tensorflow/c/experimental/filesystem:modular_filesystem",
"//tensorflow/cc/saved_model:loader_lite",
"//tensorflow/cc:gradients",
"//tensorflow/cc:ops",
@ -511,6 +512,18 @@ tf_cuda_library(
],
)
cc_library(
name = "kernels_hdrs",
hdrs = ["kernels.h"],
visibility = ["//tensorflow:internal"],
deps = [
":c_api_internal",
":tf_datatype",
":tf_status",
":tf_tensor",
],
)
tf_cuda_library(
name = "kernels",
srcs = [
@ -565,6 +578,16 @@ tf_cuda_library(
alwayslink = 1,
)
cc_library(
name = "ops_hdrs",
hdrs = ["ops.h"],
visibility = ["//tensorflow:internal"],
deps = [
":tf_datatype",
":tf_status",
],
)
# -----------------------------------------------------------------------------
# Tests

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/platform/platform.h"  // NOLINT
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope_internal.h"
@ -2606,4 +2607,14 @@ void TF_RegisterLogListener(void (*listener)(const char*)) {
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
}
void TF_RegisterFilesystemPlugin(const char* plugin_filename,
TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
status->status = tensorflow::errors::Unimplemented(
"FileSystem plugin functionality is not supported on mobile");
#else
status->status = tensorflow::RegisterFilesystemPlugin(plugin_filename);
#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
}  // end extern "C"

View File

@ -1577,6 +1577,13 @@ TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server);
TF_CAPI_EXPORT extern void TF_RegisterLogListener(
void (*listener)(const char*));
// Register a FileSystem plugin from filename `plugin_filename`.
//
// On success, place OK in status.
// On failure, place an error status in status.
TF_CAPI_EXPORT extern void TF_RegisterFilesystemPlugin(
const char* plugin_filename, TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -563,15 +563,15 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
  collective_executor_handle->get()->StartAbort(status->status);
}
-TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
-                                                            const char* task,
+TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
+    TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
    TF_Status* status) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  auto collective_executor_handle = context->GetCollectiveExecutorHandle();
  tensorflow::Notification done;
  collective_executor_handle->get()->remote_access()->CheckPeerHealth(
-      task, [&done, status](const Status& s) {
+      task, timeout_in_ms, [&done, status](const Status& s) {
        status->status = s;
        done.Notify();
      });

View File

@ -241,9 +241,9 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
// Checks the health of collective ops peers. Explicit health check is needed in
// multi worker collective ops to detect failures in the cluster. If a peer is
// down, collective ops may hang.
-TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
-                                                            const char* task,
+TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
+    TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
    TF_Status* status);
// Information about the shape of a Tensor and its type.
struct TF_ShapeAndType {

View File

@ -10,6 +10,9 @@ load(
"tf_cuda_library",
)
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "filegroup")
@ -94,6 +97,7 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime:worker_interface",
"//tensorflow/core:gpu_runtime",
] + internal_tfrt_deps(),
alwayslink = 1,
@ -638,6 +642,19 @@ cc_library(
],
)
cc_header_only_library(
name = "tfe_tensorhandle_internal_hdrs_only",
extra_deps = [
"@com_google_absl//absl/strings",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":tfe_tensorhandle_internal",
],
)
tf_cuda_library(
name = "c_api_test_util",
testonly = 1,

View File

@ -70,6 +70,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#endif  // !IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/node_def_util.h"
@ -855,41 +856,42 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
#else  // !defined(IS_MOBILE_PLATFORM)
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  // TODO(yuefengz): support partially specified `worker_name`.
-  tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
-  status->status = context->GetClient(worker_name, &eager_client);
-  if (!status->status.ok()) {
+  tensorflow::GrpcServer* grpc_server =
+      dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
+  if (grpc_server == nullptr) {
+    status->status =
+        tensorflow::errors::Internal("Failed to get tensorflow::GrpcServer.");
+    return false;
+  }
+  tensorflow::WorkerInterface* wi =
+      grpc_server->master_env()->worker_cache->GetOrCreateWorker(worker_name);
+  if (wi == nullptr) {
+    status->status = tensorflow::errors::InvalidArgument(
+        "Unable to find worker interface corresponding to task ", worker_name);
    return false;
  }
-  // Send a rpc request to the worker to check aliveness.
-  tensorflow::eager::KeepAliveRequest request;
-  request.set_context_id(context->GetContextId());
-  tensorflow::eager::KeepAliveResponse response;
-  tensorflow::Status keep_alive_status;
+  tensorflow::GetStatusRequest request;
+  tensorflow::GetStatusResponse response;
+  tensorflow::Status remote_status;
  tensorflow::Notification done;
-  eager_client->KeepAliveAsync(
-      &request, &response,
-      [&keep_alive_status, &done](const tensorflow::Status& s) {
-        keep_alive_status = s;
-        done.Notify();
-      });
+  wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true,
+                     [&remote_status, &done](const tensorflow::Status& s) {
+                       remote_status = s;
+                       done.Notify();
+                     });
  done.WaitForNotification();
+  // We set OK status so the call does not raise any exceptions. Instead, caller
+  // users the return value to tell if the remote worker is alive.
  status->status = tensorflow::Status::OK();
-  // If `context_id` doesn't exist on the remote worker, an InvalidArgument
-  // error will return. But this still indicates that the remote worker is
-  // alive.
-  if (keep_alive_status.ok() ||
-      keep_alive_status.code() == tensorflow::error::INVALID_ARGUMENT) {
+  if (remote_status.ok()) {
    return true;
-  } else {
-    LOG(INFO) << "Remote worker " << worker_name
-              << " is not alive: " << keep_alive_status.error_message();
-    return false;
  }
+  LOG(INFO) << "Remote worker " << worker_name
+            << " is not alive: " << remote_status.error_message();
+  return false;
#endif  // !IS_MOBILE_PLATFORM
}

View File

@ -638,3 +638,19 @@ void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) {
tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
}
const char* TFE_TensorHandleDeviceType(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
return tensorflow::unwrap(h)->DeviceType(&status->status);
}
int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return -1;
}
return tensorflow::unwrap(h)->DeviceId(&status->status);
}

View File

@ -553,6 +553,14 @@ TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
unsigned char enable,
TF_Status* status);
// Returns the device type of the operation that produced `h`.
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType(
TFE_TensorHandle* h, TF_Status* status);
// Returns the device ID of the operation that produced `h`.
TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h,
TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -411,5 +411,109 @@ TEST(CAPI, TensorHandleOnDeviceMemory) {
TF_DeleteStatus(status);
}
TEST(CAPI, TensorHandleNullptr) {
TFE_TensorHandle* h = nullptr;
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
const char* device_type = TFE_TensorHandleDeviceType(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(device_type, nullptr);
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
TF_SetStatus(status.get(), TF_OK, "");
int device_id = TFE_TensorHandleDeviceID(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(device_id, -1);
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
}
TEST(CAPI, TensorHandleDevices) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
const char* device_type = TFE_TensorHandleDeviceType(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
int device_id = TFE_TensorHandleDeviceID(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(0, device_id) << device_id;
// Disable the test if no GPU is present.
string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_Op* shape_op = ShapeOp(ctx, hgpu);
TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
device_type = TFE_TensorHandleDeviceType(retvals[0], status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type;
device_id = TFE_TensorHandleDeviceID(retvals[0], status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(0, device_id) << device_id;
TFE_DeleteOp(shape_op);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(hgpu);
}
TFE_DeleteTensorHandle(hcpu);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
}
TEST(CAPI, TensorHandleDefaults) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* h_default = TestMatrixTensorHandle(ctx);
const char* device_type = TFE_TensorHandleDeviceType(h_default, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
int device_id = TFE_TensorHandleDeviceID(h_default, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(0, device_id) << device_id;
TFE_TensorHandle* h_cpu = TFE_TensorHandleCopyToDevice(
h_default, ctx, "/device:CPU:0", status.get());
const char* device_type_cpu = TFE_TensorHandleDeviceType(h_cpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_type_cpu, "CPU")) << device_type_cpu;
int device_id_cpu = TFE_TensorHandleDeviceID(h_cpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(0, device_id_cpu) << device_id_cpu;
TFE_DeleteTensorHandle(h_default);
TFE_DeleteTensorHandle(h_cpu);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
}
}  // namespace
}  // namespace tensorflow

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -56,6 +57,9 @@ Status RegisterGradients(GradientRegistry* registry) {
}
TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
// Computing numerical gradients with TensorFloat-32 is numerically unstable
enable_tensor_float_32_execution(false);
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;

View File

@ -62,10 +62,11 @@ Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
return Status::OK();
}
// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
@ -74,11 +75,11 @@ Status AddGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
-auto tape = new Tape(/*persistent=*/false);
+auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0]));  // Watch x.
tape->Watch(ToId(inputs[1]));  // Watch y.
std::vector<AbstractTensorHandle*> add_outputs(1);
-AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs,
absl::MakeSpan(add_outputs),
"Add"));  // Compute x+y.
@ -97,7 +98,6 @@ Status AddGradModel(AbstractContext* ctx,
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
-delete tape;
return Status::OK();
}
@ -109,10 +109,10 @@ Status ExpGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
-auto tape = new Tape(/*persistent=*/false);
+auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0]));  // Watch x.
std::vector<AbstractTensorHandle*> exp_outputs(1);
-AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(
ops::Exp(tape_ctx.get(), inputs, absl::MakeSpan(exp_outputs), "Exp"));
std::unordered_map<tensorflow::int64, TapeTensor>
@ -128,7 +128,6 @@ Status ExpGradModel(AbstractContext* ctx,
exp_output->Unref(); exp_output->Unref();
} }
outputs[0] = out_grads[0]; outputs[0] = out_grads[0];
delete tape;
return Status::OK(); return Status::OK();
} }
@ -140,10 +139,10 @@ Status SqrtGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> outputs, absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) { const GradientRegistry& registry) {
TapeVSpace vspace(ctx); TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false); auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x. tape->Watch(ToId(inputs[0])); // Watch x.
std::vector<AbstractTensorHandle*> sqrt_outputs(1); std::vector<AbstractTensorHandle*> sqrt_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt")); ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt"));
std::unordered_map<tensorflow::int64, TapeTensor> std::unordered_map<tensorflow::int64, TapeTensor>
@ -159,7 +158,6 @@ Status SqrtGradModel(AbstractContext* ctx,
sqrt_output->Unref(); sqrt_output->Unref();
} }
outputs[0] = out_grads[0]; outputs[0] = out_grads[0];
delete tape;
return Status::OK(); return Status::OK();
} }
@ -172,12 +170,12 @@ Status IdentityNGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> outputs, absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) { const GradientRegistry& registry) {
TapeVSpace vspace(ctx); TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false); auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); tape->Watch(ToId(inputs[0]));
tape->Watch(ToId(inputs[1])); tape->Watch(ToId(inputs[1]));
vector<AbstractTensorHandle*> identity_n_outputs(2); vector<AbstractTensorHandle*> identity_n_outputs(2);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(ops::IdentityN( TF_RETURN_IF_ERROR(ops::IdentityN(
tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN")); tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN"));
@ -195,7 +193,71 @@ Status IdentityNGradModel(AbstractContext* ctx,
} }
outputs[0] = out_grads[0]; outputs[0] = out_grads[0];
outputs[1] = out_grads[1]; outputs[1] = out_grads[1];
delete tape; return Status::OK();
}
// Computes
// y = - inputs[0]
// return grad(y, {inputs[0]})
Status NegGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0]));
std::vector<AbstractTensorHandle*> neg_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(
ops::Neg(tape_ctx.get(), inputs, absl::MakeSpan(neg_outputs), "Neg"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(neg_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto neg_output : neg_outputs) {
neg_output->Unref();
}
outputs[0] = out_grads[0];
return Status::OK();
}
// Computes
// y = inputs[0] - inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status SubGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
std::vector<AbstractTensorHandle*> sub_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs,
absl::MakeSpan(sub_outputs),
"Sub")); // Compute x-y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(sub_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto sub_output : sub_outputs) {
sub_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
return Status::OK(); return Status::OK();
} }
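For reference, the analytic gradients these two models should reproduce, assuming the tape seeds the scalar output's upstream gradient with 1 (which matches the values the tests below expect):

    \[ y = -x \Rightarrow \tfrac{\partial y}{\partial x} = -1, \qquad
       y = x_1 - x_2 \Rightarrow \tfrac{\partial y}{\partial x_1} = 1,\;
       \tfrac{\partial y}{\partial x_2} = -1 \]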
@ -536,6 +598,111 @@ TEST_P(CppGradients, TestIdentityNGrad) {
result_tensor = nullptr; result_tensor = nullptr;
} }
TEST_P(CppGradients, TestNegGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// y = - x
// outputs = tape.gradient(y, x)
std::vector<AbstractTensorHandle*> outputs(1);
s = RunModel(NegGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, -1.0);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
TEST_P(CppGradients, TestSubGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
y.reset(y_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// tape.watch(y)
// y = x - y
// outputs = tape.gradient(y, [x, y])
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(SubGradModel, ctx.get(), {x.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
s = getValue(outputs[1], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, -1.0);
outputs[1]->Unref();
TF_DeleteTensor(result_tensor);
}
TEST_P(CppGradients, TestSetAttrString) { TEST_P(CppGradients, TestSetAttrString) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);
@ -575,7 +742,7 @@ TEST_P(CppGradients, TestSetAttrString) {
int num_retvals = 1; int num_retvals = 1;
std::vector<AbstractTensorHandle*> outputs(1); std::vector<AbstractTensorHandle*> outputs(1);
GradientRegistry registry; GradientRegistry registry;
std::unique_ptr<Tape> tape(new Tape(/*persistent=*/false)); auto tape = std::make_unique<Tape>(/*persistent=*/false);
s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs), s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs),
&num_retvals, &forward_op, tape.get(), registry); &num_retvals, &forward_op, tape.get(), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message(); ASSERT_EQ(errors::OK, s.code()) << s.error_message();

View File

@ -44,6 +44,10 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
virtual const char* DeviceName(Status* status) const = 0; virtual const char* DeviceName(Status* status) const = 0;
// Returns the device where the tensor was placed. // Returns the device where the tensor was placed.
virtual const char* BackingDeviceName(Status* status) const = 0; virtual const char* BackingDeviceName(Status* status) const = 0;
// Returns the device type which created the handle.
virtual const char* DeviceType(Status* status) const = 0;
// Returns the device ID which created the handle.
virtual int DeviceId(Status* status) const = 0;
// Returns a tensor for the handle. If tensor is remote, it will be copied. // Returns a tensor for the handle. If tensor is remote, it will be copied.
virtual AbstractTensorInterface* Resolve(Status* status) = 0; virtual AbstractTensorInterface* Resolve(Status* status) = 0;
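A hypothetical call site for the two new accessors; the handle variable, the logging, and the device strings are illustrative only, with the signatures taken from the declarations above:

    // Given an ImmediateExecutionTensorHandle* handle:
    Status status;
    const char* device_type = handle->DeviceType(&status);  // e.g. "CPU" or "GPU"
    int device_id = handle->DeviceId(&status);               // ordinal within that type
    if (status.ok()) {
      VLOG(1) << "handle created on " << device_type << ":" << device_id;
    }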

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
namespace tensorflow { namespace tensorflow {
@ -43,6 +44,11 @@ class CppGradients
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get()); Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message(); CHECK_EQ(errors::OK, s.code()) << s.error_message();
// Computing numerical gradients with TensorFloat-32 is numerically
// unstable. Some forward pass tests also fail with TensorFloat-32 due to
// low tolerances.
enable_tensor_float_32_execution(false);
} }
}; };
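Note that enable_tensor_float_32_execution toggles a process-wide setting; a test that only wants TF32 off temporarily could save and restore the previous value. A hedged sketch, assuming the matching getter tensor_float_32_execution_enabled() from the same header:

    bool was_enabled = tensor_float_32_execution_enabled();
    enable_tensor_float_32_execution(false);
    // ... run the numerically sensitive checks ...
    enable_tensor_float_32_execution(was_enabled);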

View File

@ -58,7 +58,7 @@ using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
class DeviceThread { class DeviceThread {
public: public:
// Starts a background thread waiting for `StartExecute`. // Starts a background thread waiting for `StartExecute`.
explicit DeviceThread(const std::string& device) explicit DeviceThread(const std::string& device, const bool is_async)
: status_(TF_NewStatus()), : status_(TF_NewStatus()),
device_(device), device_(device),
// If the context's default exector is set to async, re-using that in // If the context's default exector is set to async, re-using that in
@ -67,7 +67,7 @@ class DeviceThread {
// //
// TODO(allenl): We should have an async API that works with the // TODO(allenl): We should have an async API that works with the
// parallel device. // parallel device.
executor_(TFE_NewExecutor(/*is_async=*/false)), executor_(TFE_NewExecutor(is_async)),
op_(nullptr), op_(nullptr),
thread_(tensorflow::Env::Default()->StartThread( thread_(tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "parallel_device_execute", tensorflow::ThreadOptions(), "parallel_device_execute",
@ -236,12 +236,13 @@ void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
} }
} }
ParallelDevice::ParallelDevice(const std::vector<std::string>& devices) ParallelDevice::ParallelDevice(const std::vector<std::string>& devices,
const bool is_async)
: underlying_devices_(devices) { : underlying_devices_(devices) {
device_threads_.reserve(devices.size()); device_threads_.reserve(devices.size());
for (int device_index = 0; device_index < devices.size(); ++device_index) { for (int device_index = 0; device_index < devices.size(); ++device_index) {
device_threads_.emplace_back( device_threads_.emplace_back(
new DeviceThread(devices[device_index].c_str())); new DeviceThread(devices[device_index].c_str(), is_async));
} }
} }

View File

@ -49,7 +49,10 @@ class DeviceThread;
// placed on each underlying device. // placed on each underlying device.
class ParallelDevice { class ParallelDevice {
public: public:
explicit ParallelDevice(const std::vector<std::string>& devices); // Eager async execution is only supported when remote eager is not in use
// (b/157523095).
explicit ParallelDevice(const std::vector<std::string>& devices,
const bool is_async = false);
~ParallelDevice(); ~ParallelDevice();
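An illustrative construction; the device names are placeholders, and because the new argument defaults to false, existing callers keep synchronous per-device executors:

    std::vector<std::string> devices = {
        "/job:localhost/replica:0/task:0/device:CPU:0",
        "/job:localhost/replica:0/task:0/device:CPU:1"};
    ParallelDevice sync_device(devices);                      // is_async = false
    ParallelDevice async_device(devices, /*is_async=*/true);  // see b/157523095 caveat above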

View File

@ -182,9 +182,8 @@ hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file,
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
std::string cacheKey(scheme); std::string cacheKey(scheme);
hdfsBuilder* builder = libhdfs->hdfsNewBuilder();
if (scheme == "file") { if (scheme == "file") {
libhdfs->hdfsBuilderSetNameNode(builder, nullptr); namenode = "";
} else if (scheme == "viewfs") { } else if (scheme == "viewfs") {
char* defaultFS = nullptr; char* defaultFS = nullptr;
libhdfs->hdfsConfGetStr("fs.defaultFS", &defaultFS); libhdfs->hdfsConfGetStr("fs.defaultFS", &defaultFS);
@ -200,21 +199,24 @@ hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file,
// The default NameNode configuration will be used (from the XML // The default NameNode configuration will be used (from the XML
// configuration files). See: // configuration files). See:
// https://github.com/tensorflow/tensorflow/blob/v1.0.0/third_party/hadoop/hdfs.h#L259 // https://github.com/tensorflow/tensorflow/blob/v1.0.0/third_party/hadoop/hdfs.h#L259
libhdfs->hdfsBuilderSetNameNode(builder, "default"); namenode = "default";
} else if (scheme == "har") { } else if (scheme == "har") {
std::string path_har = path; std::string path_har = path;
SplitArchiveNameAndPath(&path_har, &namenode, status); SplitArchiveNameAndPath(&path_har, &namenode, status);
if (TF_GetCode(status) != TF_OK) return nullptr; if (TF_GetCode(status) != TF_OK) return nullptr;
libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str());
cacheKey += namenode;
} else { } else {
libhdfs->hdfsBuilderSetNameNode( if (namenode.empty()) {
builder, namenode.empty() ? "default" : namenode.c_str()); namenode = "default";
cacheKey += namenode; }
} }
cacheKey += namenode;
absl::MutexLock l(&hadoop_file->connection_cache_lock); absl::MutexLock l(&hadoop_file->connection_cache_lock);
if (hadoop_file->connection_cache.find(cacheKey) == if (hadoop_file->connection_cache.find(cacheKey) ==
hadoop_file->connection_cache.end()) { hadoop_file->connection_cache.end()) {
hdfsBuilder* builder = libhdfs->hdfsNewBuilder();
libhdfs->hdfsBuilderSetNameNode(
builder, namenode.empty() ? nullptr : namenode.c_str());
auto cacheFs = libhdfs->hdfsBuilderConnect(builder); auto cacheFs = libhdfs->hdfsBuilderConnect(builder);
if (cacheFs == nullptr) { if (cacheFs == nullptr) {
TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno)); TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno));
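The reshuffle above resolves the namenode string first and then creates, configures, and connects the builder entirely under the connection-cache lock, keyed by scheme plus namenode. A generic C++ restatement of the pattern (standard library only, not the libhdfs API):

    #include <map>
    #include <mutex>
    #include <string>

    std::mutex connection_cache_lock;
    std::map<std::string, int> connection_cache;  // stand-in for cacheKey -> hdfsFS

    int Connect(const std::string& cache_key) {
      std::lock_guard<std::mutex> lock(connection_cache_lock);
      auto it = connection_cache.find(cache_key);
      if (it != connection_cache.end()) return it->second;  // reuse cached connection
      int fs = 42;  // placeholder for hdfsNewBuilder + hdfsBuilderConnect
      connection_cache[cache_key] = fs;
      return fs;
    }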

View File

@ -24,6 +24,7 @@ using std::vector;
using tensorflow::ops::Conj; using tensorflow::ops::Conj;
using tensorflow::ops::MatMul; using tensorflow::ops::MatMul;
using tensorflow::ops::Mul; using tensorflow::ops::Mul;
using tensorflow::ops::Neg;
using tensorflow::ops::SqrtGrad; using tensorflow::ops::SqrtGrad;
namespace tensorflow { namespace tensorflow {
@ -201,6 +202,56 @@ class MatMulGradientFunction : public GradientFunction {
AttrBuilder forward_attrs; AttrBuilder forward_attrs;
}; };
class NegGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
/* Given upstream grad U and a Neg op Y = -X, the gradients are:
*
* dX = -U
*
*/
grad_outputs->resize(1);
std::string name = "Neg_Grad";
TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(*grad_outputs), name.c_str()));
return Status::OK();
}
~NegGradientFunction() override {}
};
class SubGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
/* Given upstream grad U and a Sub op A-B, the gradients are:
*
* dA = U
* dB = -U
*
*/
grad_outputs->resize(2);
// Grad for A
DCHECK(grad_inputs[0]);
(*grad_outputs)[0] = grad_inputs[0];
(*grad_outputs)[0]->Ref();
// Grad for B
// negate the upstream grad
std::vector<AbstractTensorHandle*> neg_outputs(1);
std::string name = "Neg_Sub_Grad_B";
TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(neg_outputs), name.c_str()));
(*grad_outputs)[1] = neg_outputs[0];
return Status::OK();
}
~SubGradientFunction() override {}
};
} // namespace } // namespace
BackwardFunction* AddRegisterer(const ForwardOperation& op) { BackwardFunction* AddRegisterer(const ForwardOperation& op) {
@ -239,5 +290,23 @@ BackwardFunction* SqrtRegisterer(const ForwardOperation& op) {
return new BackwardFunction(gradient_function, default_gradients); return new BackwardFunction(gradient_function, default_gradients);
} }
BackwardFunction* NegRegisterer(const ForwardOperation& op) {
auto gradient_function = new NegGradientFunction;
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
BackwardFunction* SubRegisterer(const ForwardOperation& op) {
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto gradient_function = new SubGradientFunction;
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
} // namespace gradients } // namespace gradients
} // namespace tensorflow } // namespace tensorflow

View File

@ -24,6 +24,8 @@ BackwardFunction* AddRegisterer(const ForwardOperation& op);
BackwardFunction* ExpRegisterer(const ForwardOperation& op); BackwardFunction* ExpRegisterer(const ForwardOperation& op);
BackwardFunction* MatMulRegisterer(const ForwardOperation& op); BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
BackwardFunction* SqrtRegisterer(const ForwardOperation& op); BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
BackwardFunction* NegRegisterer(const ForwardOperation& op);
BackwardFunction* SubRegisterer(const ForwardOperation& op);
} // namespace gradients } // namespace gradients
} // namespace tensorflow } // namespace tensorflow

View File

@ -11,11 +11,21 @@ package(
licenses = ["notice"], # Apache 2.0 licenses = ["notice"], # Apache 2.0
) )
cc_library(
name = "stream_executor_hdrs",
hdrs = ["stream_executor.h"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status",
],
)
cc_library( cc_library(
name = "stream_executor", name = "stream_executor",
srcs = ["stream_executor.cc"], srcs = ["stream_executor.cc"],
hdrs = ["stream_executor.h"], hdrs = ["stream_executor.h"],
visibility = ["//visibility:public"], visibility = ["//tensorflow:internal"],
deps = [ deps = [
":stream_executor_internal", ":stream_executor_internal",
"//tensorflow/c:c_api_macros", "//tensorflow/c:c_api_macros",

View File

@ -404,10 +404,12 @@ Status RestoreSession(const RunOptions& run_options,
const uint64 read_start_microseconds = Env::Default()->NowMicros(); const uint64 read_start_microseconds = Env::Default()->NowMicros();
std::vector<AssetFileDef> asset_file_defs; std::vector<AssetFileDef> asset_file_defs;
TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(meta_graph, &asset_file_defs)); TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(meta_graph, &asset_file_defs));
TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir, if (meta_graph.has_saver_def()) {
meta_graph.saver_def().restore_op_name(), TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir,
meta_graph.saver_def().filename_tensor_name(), meta_graph.saver_def().restore_op_name(),
asset_file_defs, session->get())); meta_graph.saver_def().filename_tensor_name(),
asset_file_defs, session->get()));
}
// Record walltime spent in restoring graph from disk, but postpone metric // Record walltime spent in restoring graph from disk, but postpone metric
// increments until graph init finishes. // increments until graph init finishes.
const uint64 restore_graph_walltime = const uint64 restore_graph_walltime =

View File

@ -7,6 +7,9 @@ load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_
load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts") load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts")
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm") load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
# buildifier: disable=same-origin-load # buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow:tensorflow.bzl", "filegroup")
@ -283,6 +286,7 @@ cc_library(
"//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/base", "@com_google_absl//absl/base",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
@ -291,7 +295,7 @@ cc_library(
# Header-only version of "flags" library, for linking from the shared object # Header-only version of "flags" library, for linking from the shared object
# without ODR violations. # without ODR violations.
cc_library( cc_library(
name = "flags_headers_only", name = "flags_headers",
hdrs = ["flags.h"], hdrs = ["flags.h"],
visibility = [":friends"], visibility = [":friends"],
deps = [ deps = [
@ -302,6 +306,11 @@ cc_library(
], ],
) )
cc_header_only_library(
name = "flags_headers_only",
deps = [":flags_headers"],
)
cc_library( cc_library(
name = "common", name = "common",
srcs = [ srcs = [
@ -447,8 +456,8 @@ cc_library(
# Header-only version of "flags" library, for linking from the shared object # Header-only version of "flags" library, for linking from the shared object
# without ODR violations. # without ODR violations.
cc_library( cc_library(
name = "get_compiler_ir_hdrs_only", name = "get_compiler_ir_hdrs",
hdrs = ["get_compiler_ir.h"], textual_hdrs = ["get_compiler_ir.h"],
visibility = [ visibility = [
":internal", ":internal",
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__", "//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
@ -463,6 +472,23 @@ cc_library(
], ],
) )
cc_header_only_library(
name = "get_compiler_ir_hdrs_only",
deps = [":get_compiler_ir_hdrs"],
)
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",
visibility = ["//visibility:public"],
deps = [
":xla_cpu_device",
":xla_cpu_jit",
":xla_gpu_device",
":xla_gpu_jit",
],
)
cc_library( cc_library(
name = "xla_kernel_creator", name = "xla_kernel_creator",
srcs = [ srcs = [
@ -842,9 +868,12 @@ tf_cc_test(
"partially_decluster_pass_test.cc", "partially_decluster_pass_test.cc",
"rearrange_function_argument_pass_test.cc", "rearrange_function_argument_pass_test.cc",
], ],
# TODO(b/141643254) Re-enable msan after fixing use-of-uninitialized-value tags = [
# error. # TODO(b/141643254) Re-enable msan after fixing
tags = ["nomsan"] + tf_cuda_tests_tags(), # use-of-uninitialized-value error.
"nomsan",
"no_cuda_asan", # TODO(b/171317460): re-enable.
] + tf_cuda_tests_tags(),
deps = [ deps = [
":common", ":common",
":compilability_check_util", ":compilability_check_util",
@ -1075,15 +1104,3 @@ cc_library(
], ],
alwayslink = 1, alwayslink = 1,
) )
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",
visibility = ["//visibility:public"],
deps = [
":xla_cpu_device",
":xla_cpu_jit",
":xla_gpu_device",
":xla_gpu_jit",
],
)

View File

@ -167,8 +167,16 @@ void AllocateAndParseFlags() {
jitter_flags = new IntroduceFloatingPointJitterPassFlags; jitter_flags = new IntroduceFloatingPointJitterPassFlags;
jitter_flags->jitter_amount = 1e-5; jitter_flags->jitter_amount = 1e-5;
mlir_flags = new MlirCommonFlags; // The `enable_mlir_bridge` flag allows the user to explicitly request that
mlir_flags->tf_mlir_enable_mlir_bridge = false; // their program is (or isn't) compiled using the MLIR-based TF-to-XLA bridge.
//
// The `enable_mlir_bridge_is_explicit` variable tracks whether or not the
// user has made an explicit request. That is, if this variable is set to
// true, the program honors the user's request as per `enable_mlir_bridge`; if
// it's set to false, the default behavior is used (which may run either
// bridge, on a per-graph basis).
bool enable_mlir_bridge = false;
bool enable_mlir_bridge_is_explicit = false;
auto setter_for_jitter_tensor_names = [](string sequence) { auto setter_for_jitter_tensor_names = [](string sequence) {
jitter_flags->tensor_names = absl::StrSplit(sequence, ','); jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
@ -217,12 +225,24 @@ void AllocateAndParseFlags() {
"The amount of jitter to introduce. This amount is added to each " "The amount of jitter to introduce. This amount is added to each "
"element in the tensors named in `tensor_names."), "element in the tensors named in `tensor_names."),
Flag("tf_mlir_enable_mlir_bridge", Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge,
&mlir_flags->tf_mlir_enable_mlir_bridge, "Enables experimental MLIR-Based TensorFlow Compiler Bridge.",
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.")}); &enable_mlir_bridge_is_explicit)});
AppendMarkForCompilationPassFlagsInternal(flag_list); AppendMarkForCompilationPassFlagsInternal(flag_list);
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list); xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
mlir_flags = new MlirCommonFlags;
if (!enable_mlir_bridge_is_explicit) {
mlir_flags->tf_mlir_enable_mlir_bridge =
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
} else if (enable_mlir_bridge) {
mlir_flags->tf_mlir_enable_mlir_bridge =
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
} else {
mlir_flags->tf_mlir_enable_mlir_bridge =
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
}
} }
} // namespace } // namespace
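The flag handling above now distinguishes an unset flag from an explicit true/false, e.g. TF_XLA_FLAGS=--tf_mlir_enable_mlir_bridge=true versus leaving the flag out entirely. A condensed restatement of that branch as a hypothetical helper (enum and names as in this diff):

    ConfigProto::Experimental::MlirBridgeRollout ToRollout(bool is_explicit,
                                                           bool enabled) {
      if (!is_explicit)
        return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
      return enabled ? ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED
                     : ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
    }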

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow { namespace tensorflow {
@ -135,7 +136,7 @@ struct IntroduceFloatingPointJitterPassFlags {
// Flags for common MLIR configurations. // Flags for common MLIR configurations.
struct MlirCommonFlags { struct MlirCommonFlags {
bool tf_mlir_enable_mlir_bridge; ConfigProto::Experimental::MlirBridgeRollout tf_mlir_enable_mlir_bridge;
}; };
// Return a pointer to the DumpGraphFlags struct; // Return a pointer to the DumpGraphFlags struct;

View File

@ -274,18 +274,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
run_options.set_allocator(allocator); run_options.set_allocator(allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(GetXLARandomSeed()); run_options.set_rng_seed(GetXLARandomSeed());
xla::ThenExecuteFunction then_execute;
if (ctx->op_device_context()) {
then_execute = [&](se::Stream* stream, std::function<void()> fn) {
Status status = ctx->op_device_context()->ThenExecute(
down_cast<Device*>(ctx->device()), stream, std::move(fn));
if (!status.ok()) {
// This should never happen.
LOG(ERROR) << "ThenExecute failed " << status;
}
};
run_options.set_then_execute_function(&then_execute);
}
Env* env = Env::Default(); Env* env = Env::Default();
auto start_time = env->NowMicros(); auto start_time = env->NowMicros();
@ -522,18 +510,6 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
run_options.set_allocator(allocator); run_options.set_allocator(allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(GetXLARandomSeed()); run_options.set_rng_seed(GetXLARandomSeed());
xla::ThenExecuteFunction then_execute;
if (ctx->op_device_context()) {
then_execute = [&](se::Stream* stream, std::function<void()> fn) {
Status status = ctx->op_device_context()->ThenExecute(
down_cast<Device*>(ctx->device()), stream, std::move(fn));
if (!status.ok()) {
// This should never happen.
LOG(ERROR) << "ThenExecute failed " << status;
}
};
run_options.set_then_execute_function(&then_execute);
}
Env* env = Env::Default(); Env* env = Env::Default();
auto start_time = env->NowMicros(); auto start_time = env->NowMicros();

View File

@ -283,25 +283,29 @@ Status XlaCompilationCache::CompileSingleOp(
const NodeDef& node_def = ctx->op_kernel().def(); const NodeDef& node_def = ctx->op_kernel().def();
TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes)); TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
// TODO(b/155596779): Support TensorList args.
bool has_tensor_list_arg = bool has_tensor_list_arg =
absl::c_any_of(args, [](const XlaCompiler::Argument arg) { absl::c_any_of(args, [](const XlaCompiler::Argument arg) {
return arg.kind == XlaCompiler::Argument::kTensorList; return arg.kind == XlaCompiler::Argument::kTensorList;
}); });
const ConfigProto* config = ctx->function_library()->config_proto(); const ConfigProto* config = ctx->function_library()->config_proto();
bool use_mlir = config && config->experimental().enable_mlir_bridge(); // TODO(b/171039585): Support tf.VarIsInitializedOp using MLIR.
bool use_mlir = config && config->experimental().enable_mlir_bridge() &&
!has_tensor_list_arg &&
node_def.op() != "VarIsInitializedOp";
#ifdef LIBTPU_ON_GCE #ifdef LIBTPU_ON_GCE
if (use_mlir && has_tensor_list_arg) { if (use_mlir) {
LOG(WARNING) << "MLIR is not supported in this environment."; LOG(WARNING) << "MLIR is not supported in this environment.";
} }
return compiler->CompileGraph(compile_options, node_def.name(), return compiler->CompileGraph(compile_options, node_def.name(),
std::move(graph), args, result); std::move(graph), args, result);
#else #else
// TODO(b/155596779): Support TensorList args. if (!use_mlir) {
if (!use_mlir || !has_tensor_list_arg) {
return compiler->CompileGraph(compile_options, node_def.name(), return compiler->CompileGraph(compile_options, node_def.name(),
std::move(graph), args, result); std::move(graph), args, result);
} }
VLOG(1) << "Using MLIR bridge";
GraphDebugInfo debug_info; GraphDebugInfo debug_info;
std::vector<std::string> control_rets; std::vector<std::string> control_rets;
if (result_dtypes.empty()) { if (result_dtypes.empty()) {

View File

@ -89,7 +89,8 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
XlaOpRegistry::RegisterCompilationKernels(); XlaOpRegistry::RegisterCompilationKernels();
// Only check for compilability if the MLIR bridge is not enabled. // Only check for compilability if the MLIR bridge is not enabled.
if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge !=
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map; RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) { if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo> std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>

View File

@ -45,6 +45,7 @@ filegroup(
"include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td", "include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td",
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td",
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td",
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td",
"include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td", "include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td",
"include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td", "include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td",
@ -122,8 +123,6 @@ gentbl(
tbl_outs = [ tbl_outs = [
("-gen-op-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"), ("-gen-op-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"),
("-gen-op-defs", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"), ("-gen-op-defs", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"),
("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.h.inc"),
("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.cc.inc"),
], ],
tblgen = "@llvm-project//mlir:mlir-tblgen", tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td",
@ -150,6 +149,24 @@ gentbl(
], ],
tblgen = "@llvm-project//mlir:mlir-tblgen", tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td", td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td",
td_relative_includes = [
"include",
],
td_srcs = [":hlo_ops_td_files"],
)
gentbl(
name = "hlo_ops_base_structs_inc_gen",
compatible_with = get_compatible_with_cloud(),
tbl_outs = [
("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc"),
("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.cc.inc"),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td",
td_relative_includes = [
"include",
],
td_srcs = [":hlo_ops_td_files"], td_srcs = [":hlo_ops_td_files"],
) )
@ -194,6 +211,63 @@ gentbl(
td_srcs = [":hlo_ops_td_files"], td_srcs = [":hlo_ops_td_files"],
) )
gentbl(
name = "lhlo_gpu_ops_structs_inc_gen",
compatible_with = get_compatible_with_cloud(),
strip_include_prefix = "include",
tbl_outs = [
("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc"),
("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc"),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td",
td_relative_includes = [
"include",
],
td_srcs = [
":hlo_ops_td_files",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td",
],
)
cc_library(
name = "lhlo_gpu_ops_structs",
srcs = [
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc",
"lib/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc",
],
hdrs = [
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h",
],
includes = ["include"],
deps = [
":lhlo_gpu_ops_structs_inc_gen",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)
gentbl(
name = "lhlo_gpu_ops_inc_gen",
compatible_with = get_compatible_with_cloud(),
strip_include_prefix = "include",
tbl_outs = [
("-gen-op-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc"),
("-gen-op-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc"),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td",
td_relative_includes = [
"include",
],
td_srcs = [
":hlo_ops_td_files",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td",
],
)
#TODO(aminim): revisit the naming and grouping of these rules post-move. #TODO(aminim): revisit the naming and grouping of these rules post-move.
gentbl( gentbl(
name = "canonicalize_inc_gen", name = "canonicalize_inc_gen",
@ -251,6 +325,23 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "hlo_ops_base_structs",
srcs = [
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc",
"lib/Dialect/mhlo/IR/hlo_ops_base_structs.cc",
],
hdrs = [
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h",
],
includes = ["include"],
deps = [
":hlo_ops_base_structs_inc_gen",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)
cc_library( cc_library(
name = "convert_op_folder", name = "convert_op_folder",
srcs = ["lib/utils/convert_op_folder.cc"], srcs = ["lib/utils/convert_op_folder.cc"],
@ -284,6 +375,7 @@ cc_library(
":chlo_ops_inc_gen", ":chlo_ops_inc_gen",
":convert_op_folder", ":convert_op_folder",
":hlo_ops_base_inc_gen", ":hlo_ops_base_inc_gen",
":hlo_ops_base_structs",
":hlo_ops_inc_gen", ":hlo_ops_inc_gen",
":infer_fusibility_op_interface", ":infer_fusibility_op_interface",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
@ -314,6 +406,7 @@ cc_library(
includes = ["include"], includes = ["include"],
deps = [ deps = [
":hlo_ops_base_inc_gen", ":hlo_ops_base_inc_gen",
":hlo_ops_base_structs",
":lhlo_ops_inc_gen", ":lhlo_ops_inc_gen",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
@ -330,6 +423,39 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "lhlo_gpu",
srcs = [
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc",
"lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc",
],
hdrs = [
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h",
],
includes = ["include"],
deps = [
":hlo",
":hlo_ops_base_structs",
":infer_fusibility_op_interface",
":lhlo_gpu_ops_inc_gen",
":lhlo_gpu_ops_structs",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:CopyOpInterface",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:ViewLikeInterface",
],
alwayslink = 1,
)
cc_library( cc_library(
name = "hlo_dialect_registration", name = "hlo_dialect_registration",
srcs = ["lib/Dialect/mhlo/IR/init.cc"], srcs = ["lib/Dialect/mhlo/IR/init.cc"],
@ -337,6 +463,7 @@ cc_library(
deps = [ deps = [
":hlo", ":hlo",
":lhlo", ":lhlo",
":lhlo_gpu",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
], ],
) )
@ -385,6 +512,7 @@ cc_library(
":lhlo", ":lhlo",
":map_hlo_to_lhlo_op", ":map_hlo_to_lhlo_op",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
], ],
) )
@ -522,6 +650,7 @@ cc_library(
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:ViewLikeInterface",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -878,6 +1007,7 @@ cc_binary(
":all_passes", ":all_passes",
":hlo", ":hlo",
":lhlo", ":lhlo",
":lhlo_gpu",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",

View File

@ -25,7 +25,22 @@ function(add_mlir_hlo_dialect dialect dialect_namespace)
endfunction() endfunction()
add_mlir_hlo_dialect(chlo_ops chlo) add_mlir_hlo_dialect(chlo_ops chlo)
add_mlir_hlo_dialect(hlo_ops mhlo)
add_mlir_hlo_dialect(lhlo_ops lmhlo) add_mlir_hlo_dialect(lhlo_ops lmhlo)
set(LLVM_TARGET_DEFINITIONS hlo_ops.td)
mlir_tablegen(hlo_ops.h.inc -gen-op-decls)
mlir_tablegen(hlo_ops.cc.inc -gen-op-defs)
mlir_tablegen(hlo_ops_base_structs.h.inc -gen-struct-attr-decls)
mlir_tablegen(hlo_ops_base_structs.cc.inc -gen-struct-attr-defs)
add_public_tablegen_target(MLIRhlo_opsIncGen)
set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops.td)
mlir_tablegen(lhlo_gpu_ops.h.inc -gen-op-decls)
mlir_tablegen(lhlo_gpu_ops.cc.inc -gen-op-defs)
set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops_structs.td)
mlir_tablegen(lhlo_gpu_ops_structs.h.inc -gen-struct-attr-decls)
mlir_tablegen(lhlo_gpu_ops_structs.cc.inc -gen-struct-attr-defs)
add_public_tablegen_target(MLIRlhlo_gpu_opsIncGen)
add_dependencies(mlir-headers MLIRlhlo_gpu_opsIncGen)
add_mlir_interface(infer_fusibility_op_interface) add_mlir_interface(infer_fusibility_op_interface)

View File

@ -19,6 +19,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectImplementation.h"
@ -32,7 +33,7 @@ limitations under the License.
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"
// clang-format off // clang-format off
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.h.inc" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
// clang-format on // clang-format on

View File

@ -25,11 +25,6 @@ include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
def HLO_Dialect : Dialect {
let name = "mhlo";
let cppNamespace = "::mlir::mhlo";
}
class HLO_Op<string mnemonic, list<OpTrait> traits> : class HLO_Op<string mnemonic, list<OpTrait> traits> :
Op<HLO_Dialect, mnemonic, traits> { Op<HLO_Dialect, mnemonic, traits> {
// Whether this operation has a custom conversion to HLO or not. // Whether this operation has a custom conversion to HLO or not.
@ -136,8 +131,8 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
} }
LogicalResult reifyReturnTypeShapes( LogicalResult reifyReturnTypeShapes(
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) { OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
return deriveShapeFromFirstOperand(&builder, getOperation(), return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
&reifiedReturnShapes); &reifiedReturnShapes);
} }
bool inferInputOutputShapeEquality(int input, int output) { bool inferInputOutputShapeEquality(int input, int output) {
return true; return true;
@ -153,7 +148,7 @@ def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
[NoSideEffect, SameOperandsAndResultShape], [NoSideEffect, SameOperandsAndResultShape],
TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp { TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp {
let builders = [OpBuilder< let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value operand" "Value operand"
>]; >];
} }
@ -168,8 +163,7 @@ def HLO_ConvertOp : HLO_UnaryElementwiseOp<
BASE_HLO_ConvertOp { BASE_HLO_ConvertOp {
let builders = [OpBuilder< let builders = [OpBuilder<
"OpBuilder &, OperationState &tblgen_state, Value operand, " "Value operand, Type result_element_ty"
"Type result_element_ty"
>]; >];
let hasFolder = 1; let hasFolder = 1;
@ -247,7 +241,9 @@ def HLO_RealOp: HLO_UnaryElementwiseOp<"real",
} }
def HLO_RoundOp: HLO_UnaryElementwiseOp<"round_nearest_afz", def HLO_RoundOp: HLO_UnaryElementwiseOp<"round_nearest_afz",
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp; [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp {
let hasFolder = 1;
}
def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt", def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt",
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
@ -293,8 +289,8 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
} }
LogicalResult reifyReturnTypeShapes( LogicalResult reifyReturnTypeShapes(
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) { OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
return deriveShapeFromFirstOperand(&builder, getOperation(), return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
&reifiedReturnShapes); &reifiedReturnShapes);
} }
bool inferInputsShapeEquality(int lhs, int rhs) { bool inferInputsShapeEquality(int lhs, int rhs) {
return true; return true;
@ -458,7 +454,7 @@ def HLO_SendOp : HLO_Op<"send", []> {
let arguments = (ins let arguments = (ins
HLO_TensorOrTuple:$operand, HLO_TensorOrTuple:$operand,
HLO_Token:$token, HLO_Token:$token,
ChannelHandle<HLO_Dialect>:$channel_id, ChannelHandle:$channel_id,
DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
); );
@ -483,7 +479,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> {
let arguments = (ins let arguments = (ins
HLO_Token:$token, HLO_Token:$token,
ChannelHandle<HLO_Dialect>:$channel_id, ChannelHandle:$channel_id,
DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
); );
@ -587,7 +583,7 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce",
let arguments = (ins let arguments = (ins
HLO_Tensor:$operand, HLO_Tensor:$operand,
I64ElementsAttr:$replica_groups, I64ElementsAttr:$replica_groups,
OptionalAttr<ChannelHandle<HLO_Dialect>>:$channel_id OptionalAttr<ChannelHandle>:$channel_id
); );
let regions = (region SizedRegion<1>:$computation); let regions = (region SizedRegion<1>:$computation);
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);
@ -959,15 +955,6 @@ def HLO_DotOp: HLO_Op<"dot", [NoSideEffect]>, BASE_HLO_DotOp {
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);
} }
def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [
StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>,
StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>,
StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>,
StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr>
]> {
let description = "Structure of dimension information for dot product";
}
def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneralOp { def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneralOp {
let arguments = (ins let arguments = (ins
HLO_Tensor:$lhs, HLO_Tensor:$lhs,
@ -1029,14 +1016,6 @@ def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp {
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);
} }
def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect,
[StructFieldAttr<"offset_dims", I64ElementsAttr>,
StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>,
StructFieldAttr<"start_index_map", I64ElementsAttr>,
StructFieldAttr<"index_vector_dim", I64Attr>]> {
let description = "Structure of dimension information for gather";
}
def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp { def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp {
let arguments = (ins let arguments = (ins
HLO_Tensor:$operand, HLO_Tensor:$operand,
@ -1114,7 +1093,7 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>,
HLO_Tensor:$operand, HLO_Tensor:$operand,
HLO_Tensor:$scatter_indices, HLO_Tensor:$scatter_indices,
HLO_Tensor:$updates, HLO_Tensor:$updates,
ScatterDimensionNumbers<HLO_Dialect>:$scatter_dimension_numbers, ScatterDimensionNumbers:$scatter_dimension_numbers,
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted, DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
DefaultValuedAttr<BoolAttr, "false">:$unique_indices DefaultValuedAttr<BoolAttr, "false">:$unique_indices
); );
@ -1124,6 +1103,8 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>,
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);
let hasCustomHLOConverter = 1; let hasCustomHLOConverter = 1;
let hasFolder = 1;
} }
// TODO(jpienaar): Add broadcastable trait. // TODO(jpienaar): Add broadcastable trait.
@ -1220,6 +1201,8 @@ def HLO_PadOp: HLO_Op<"pad",
// TODO(b/129422361): PadOp has a custom constructor for HLO. // TODO(b/129422361): PadOp has a custom constructor for HLO.
let hasCustomHLOConverter = 1; let hasCustomHLOConverter = 1;
let hasFolder = 1;
} }
def HLO_TraceOp: HLO_Op<"trace", []>, BASE_HLO_TraceOp { def HLO_TraceOp: HLO_Op<"trace", []>, BASE_HLO_TraceOp {

View File

@ -18,6 +18,13 @@ limitations under the License.
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
def HLO_Dialect : Dialect {
let name = "mhlo";
let cppNamespace = "::mlir::mhlo";
}
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td"
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">; def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
// TODO(hinsu): Use signed integers instead of signless integer which is being // TODO(hinsu): Use signed integers instead of signless integer which is being
@ -614,15 +621,6 @@ class BASE_HLO_CaseOp {
// XLA parallelism related op definitions. // XLA parallelism related op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Represents a unique identifier for each Send/Recv instruction pair or
// optionally for collective instructions (AllReduce, CollectivePermute,
// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
class ChannelHandle<Dialect dialect> : StructAttr<"ChannelHandle", dialect, [
StructFieldAttr<"handle", I64Attr>,
StructFieldAttr<"type", I64Attr>]> {
let description = "two 64-bit integers 'handle' and 'type'";
}
class BASE_HLO_ReplicaIdOp { class BASE_HLO_ReplicaIdOp {
string summary = "ReplicaId operator"; string summary = "ReplicaId operator";
@ -712,6 +710,7 @@ def HLO_PrecisionConfigAttr:
OptionalAttr< OptionalAttr<
TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>; TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Fast Fourier Transform Type enum definitions. // Fast Fourier Transform Type enum definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1011,21 +1010,6 @@ class BASE_HLO_ConcatenateOp {
// Common convolution attributes // Common convolution attributes
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class ConvDimensionNumbersBase<Dialect dialect>
: StructAttr<"ConvDimensionNumbers", dialect, [
StructFieldAttr<"input_batch_dimension",I64Attr>,
StructFieldAttr<"input_feature_dimension", I64Attr>,
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"output_batch_dimension", I64Attr>,
StructFieldAttr<"output_feature_dimension", I64Attr>,
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
let description = "Structure of dimension information for conv op";
}
class ConvolutionAttributes<Dialect dialect> { class ConvolutionAttributes<Dialect dialect> {
dag attributes = (ins dag attributes = (ins
// Default value: one for each of the spatial dimension. // Default value: one for each of the spatial dimension.
@ -1036,7 +1020,7 @@ class ConvolutionAttributes<Dialect dialect> {
OptionalAttr<I64ElementsAttr>:$lhs_dilation, OptionalAttr<I64ElementsAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension. // Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$rhs_dilation, OptionalAttr<I64ElementsAttr>:$rhs_dilation,
ConvDimensionNumbersBase<dialect>:$dimension_numbers, ConvDimensionNumbers:$dimension_numbers,
I64Attr:$feature_group_count, I64Attr:$feature_group_count,
I64Attr:$batch_group_count, I64Attr:$batch_group_count,
HLO_PrecisionConfigAttr:$precision_config HLO_PrecisionConfigAttr:$precision_config
@ -1164,15 +1148,6 @@ class BASE_HLO_ReshapeOp {
}]; }];
} }
class ScatterDimensionNumbers<Dialect dialect> : StructAttr<
"ScatterDimensionNumbers", dialect, [
StructFieldAttr<"update_window_dims", I64ElementsAttr>,
StructFieldAttr<"inserted_window_dims", I64ElementsAttr>,
StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>,
StructFieldAttr<"index_vector_dim", I64Attr>]> {
let description = "Structure of dimension information for scatter";
}
class BASE_HLO_ScatterOp { class BASE_HLO_ScatterOp {
string summary = "Scatter operator"; string summary = "Scatter operator";

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines structures used in MHLO and LMHLO.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
// Order matters, this .inc header is not self-contained, and relies on the
// #includes above.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc"
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_

View File

@ -0,0 +1,73 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef HLO_OPS_BASE_STRUCTS
#define HLO_OPS_BASE_STRUCTS
//===----------------------------------------------------------------------===//
// Dot dimensions enum definitions.
//===----------------------------------------------------------------------===//
def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [
StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>,
StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>,
StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>,
StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr>
]> {
let description = "Structure of dimension information for dot product";
}
def ScatterDimensionNumbers : StructAttr<
"ScatterDimensionNumbers", HLO_Dialect, [
StructFieldAttr<"update_window_dims", I64ElementsAttr>,
StructFieldAttr<"inserted_window_dims", I64ElementsAttr>,
StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>,
StructFieldAttr<"index_vector_dim", I64Attr>]> {
let description = "Structure of dimension information for scatter";
}
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [
StructFieldAttr<"input_batch_dimension",I64Attr>,
StructFieldAttr<"input_feature_dimension", I64Attr>,
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"output_batch_dimension", I64Attr>,
StructFieldAttr<"output_feature_dimension", I64Attr>,
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
let description = "Structure of dimension information for conv op";
}
def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect,
[StructFieldAttr<"offset_dims", I64ElementsAttr>,
StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>,
StructFieldAttr<"start_index_map", I64ElementsAttr>,
StructFieldAttr<"index_vector_dim", I64Attr>]> {
let description = "Structure of dimension information for gather";
}
// Represents a unique identifier for each Send/Recv instruction pair or
// optionally for collective instructions (AllReduce, CollectivePermute,
// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
def ChannelHandle : StructAttr<"ChannelHandle", HLO_Dialect, [
StructFieldAttr<"handle", I64Attr>,
StructFieldAttr<"type", I64Attr>]> {
let description = "two 64-bit integers 'handle' and 'type'";
}
#endif // HLO_OPS_BASE_STRUCTS

View File

@ -0,0 +1,59 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations used in the LHLO dialect.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_
#include "llvm/ADT/StringRef.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
namespace mlir {
class OpBuilder;
} // namespace mlir
namespace mlir {
namespace lmhlo_gpu {
class LmhloGpuDialect : public Dialect {
public:
explicit LmhloGpuDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "lmhlo_gpu"; }
};
} // namespace lmhlo_gpu
} // end namespace mlir
#define GET_OP_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc"
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_

View File

@ -0,0 +1,210 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the operation definition file for LMHLO-level GPU operations.
// Because these are LMHLO level operations, they operate on memrefs.
#ifndef LHLO_GPU_OPS
#define LHLO_GPU_OPS
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td"
class LHLOGPU_Op<string mnemonic, list<OpTrait> traits = []> :
Op<LHLO_GPU_Dialect, mnemonic,
!listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>;
// Type for scratch buffers used by GPU library calls (memref<?xi8>)
def UntypedBuffer : MemRefRankOf<[I8], [1]>;
// Cholesky info output buffer type.
def I32Buffer : MemRefOf<[I32]>;
//===----------------------------------------------------------------------===//
// LMHLO ops representing batch norm library functions.
//===----------------------------------------------------------------------===//
// Note: these are semantically different from the similar LHLO ops, as the GPU
// library calls generate or consume standard deviation, whereas the LHLO ops
// generate or consume variance (= std-dev ^ 2).
def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad">,
BASE_HLO_BatchNormGradOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$mean,
Arg<LHLO_Buffer, "", [MemRead]>:$stddev,
Arg<LHLO_Buffer, "", [MemRead]>:$grad_output,
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_operand, // gradient of $operand.
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_scale,
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_offset,
F32Attr:$epsilon,
I64Attr:$feature_index
);
}
def LHLOGPU_BatchNormInferenceOp : LHLOGPU_Op<"batch_norm_inference">,
BASE_HLO_BatchNormInferenceOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$offset,
Arg<LHLO_Buffer, "", [MemRead]>:$mean,
Arg<LHLO_Buffer, "", [MemRead]>:$stddev,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
F32Attr:$epsilon,
I64Attr:$feature_index);
}
def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">,
BASE_HLO_BatchNormTrainingOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$offset,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<LHLO_Buffer, "", [MemWrite]>:$batch_mean,
Arg<LHLO_Buffer, "", [MemWrite]>:$batch_stddev,
F32Attr:$epsilon,
I64Attr:$feature_index
);
}
//===----------------------------------------------------------------------===//
// LMHLO ops representing convolution library functions.
//===----------------------------------------------------------------------===//
def ActivationModeNone : StrEnumAttrCase<"None">;
def ActivationModeSigmoid : StrEnumAttrCase<"Sigmoid">;
def ActivationModeTanh : StrEnumAttrCase<"Tanh">;
def ActivationModeRelu : StrEnumAttrCase<"Relu">;
def ActivationModeRelu6 : StrEnumAttrCase<"Relu6">;
def ActivationModeReluX : StrEnumAttrCase<"ReluX">;
def ActivationModeBandPass : StrEnumAttrCase<"BandPass">;
def ActivationAttr : StrEnumAttr<"Activation",
"Activation applied with fused convolution",
[ActivationModeNone, ActivationModeSigmoid, ActivationModeTanh,
ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX,
ActivationModeBandPass]>;
def GpuConvolutionAttributes {
dag attributes = !con(
ConvolutionAttributes<LHLO_GPU_Dialect>.attributes,
(ins F64Attr:$result_scale),
(ins ConvolutionBackendConfigAttr:$backend_config));
}
def GpuFusedConvolutionAttributes {
dag attributes = !con(
ConvolutionAttributes<LHLO_GPU_Dialect>.attributes,
(ins F64Attr:$result_scale,
ActivationAttr:$activation_mode,
F64Attr:$side_input_scale),
(ins ConvolutionBackendConfigAttr:$backend_config));
}
def LHLOGPU_ConvForwardOp : LHLOGPU_Op<"conv_forward"> {
let arguments = !con(
(ins
Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemRead]>:$filter,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
GpuConvolutionAttributes.attributes);
}
def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> {
let arguments = !con(
(ins
Arg<LHLO_Buffer, "", [MemRead]>:$d_output,
Arg<LHLO_Buffer, "", [MemRead]>:$filter,
Arg<LHLO_Buffer, "", [MemWrite]>:$d_input,
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
GpuConvolutionAttributes.attributes);
}
def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> {
let arguments = !con(
(ins
Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemRead]>:$d_output,
Arg<LHLO_Buffer, "", [MemWrite]>:$d_filter,
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
GpuConvolutionAttributes.attributes);
}
// output = activation(result_scale * conv(input, filter) +
// side_input * side_input_scale +
// bias)
def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
let arguments = !con(
(ins
Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemRead]>:$filter,
Arg<LHLO_Buffer, "", [MemRead]>:$bias,
Arg<LHLO_Buffer, "", [MemRead]>:$side_input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
GpuFusedConvolutionAttributes.attributes);
}
//===----------------------------------------------------------------------===//
// LMHLO ops representing other library functions.
//===----------------------------------------------------------------------===//
// output = alpha * (lhs * rhs)
// Verify: beta = 0.0
def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemRead]>:$output,
DotDimensionNumbers:$dot_dimension_numbers,
F64Attr:$alpha,
I64Attr:$batch_size,
I64Attr:$algorithm);
}
// output = alpha * (lhs * rhs) + beta * bias
def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemRead]>:$bias,
Arg<LHLO_Buffer, "", [MemRead]>:$output,
DotDimensionNumbers:$dot_dimension_numbers,
F64Attr:$alpha,
F64Attr:$beta,
I64Attr:$batch_size,
I64Attr:$algorithm);
}
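// Assumption (not stated in this file): the `info` output below follows the
// LAPACK/cuSOLVER potrf convention of one i32 status value per batch element,
// with 0 indicating success.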
def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<UntypedBuffer, "", [MemWrite]>:$scratch,
Arg<I32Buffer, "", [MemWrite]>:$info,
BoolAttr:$is_upper);
}
#endif // LHLO_GPU_OPS

View File

@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// We define the dialect here so that both structs and ops can refer to it.
#ifndef LHLO_GPU_OPS_BASE
#define LHLO_GPU_OPS_BASE
include "mlir/IR/OpBase.td"
def LHLO_GPU_Dialect : Dialect {
let name = "lmhlo_gpu";
let cppNamespace = "::mlir::lmhlo_gpu";
}
#endif // LHLO_GPU_OPS_BASE

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ==============================================================================*/
// This file defines structures used in the LMHLO_GPU dialect.
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_
#define THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
// Order matters, this .inc header is not self-contained, and relies on the
// #includes above.
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc"
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_

View File

@ -0,0 +1,29 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef LHLO_GPU_OPS_STRUCTS
#define LHLO_GPU_OPS_STRUCTS
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td"
def ConvolutionBackendConfigAttr : StructAttr<"ConvolutionBackendConfig",
LHLO_GPU_Dialect, [
StructFieldAttr<"algorithm", I64Attr>,
StructFieldAttr<"tensor_ops_enabled", BoolAttr>]> {
let description = "GPU Convolution backend configuration";
}
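// Usage sketch, assuming the C++ accessors that StructAttr generates (as used
// elsewhere in this change, e.g. dimension_numbers().index_vector_dim()):
//   ConvolutionBackendConfig config = op.backend_config();
//   int64_t algo = config.algorithm().getInt();
//   bool tensor_ops = config.tensor_ops_enabled().getValue();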
#endif // LHLO_GPU_OPS_STRUCTS

View File

@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
// This file defines the operations used in the LXLA dialect. // This file defines the operations used in the LHLO dialect.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
@ -33,11 +34,6 @@ limitations under the License.
namespace mlir { namespace mlir {
class OpBuilder; class OpBuilder;
} // namespace mlir
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc"
namespace mlir {
namespace lmhlo { namespace lmhlo {
class LmhloDialect : public Dialect { class LmhloDialect : public Dialect {

View File

@ -592,6 +592,7 @@ def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp {
let arguments = (ins let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs, Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs, Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
DotDimensionNumbers:$dot_dimension_numbers,
HLO_PrecisionConfigAttr:$precision_config, HLO_PrecisionConfigAttr:$precision_config,
Arg<LHLO_Buffer, "", [MemWrite]>:$output Arg<LHLO_Buffer, "", [MemWrite]>:$output
); );
@ -601,11 +602,8 @@ def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp {
let arguments = (ins let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand, Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices, Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices,
I64Attr:$index_vector_dim, GatherDimensionNumbers:$dimension_numbers,
I64ElementsAttr:$offset_dims,
I64ElementsAttr:$slice_sizes, I64ElementsAttr:$slice_sizes,
I64ElementsAttr:$collapsed_slice_dims,
I64ElementsAttr:$start_index_map,
Arg<LHLO_Buffer, "", [MemWrite]>:$output Arg<LHLO_Buffer, "", [MemWrite]>:$output
); );
} }
@ -623,7 +621,7 @@ def LHLO_ScatterOp: LHLO_Op<"scatter", []>, BASE_HLO_ScatterOp {
Arg<LHLO_Buffer, "", [MemRead]>:$scatter_indices, Arg<LHLO_Buffer, "", [MemRead]>:$scatter_indices,
Arg<LHLO_Buffer, "", [MemRead]>:$updates, Arg<LHLO_Buffer, "", [MemRead]>:$updates,
Arg<LHLO_Buffer, "", [MemWrite]>:$output, Arg<LHLO_Buffer, "", [MemWrite]>:$output,
ScatterDimensionNumbers<LHLO_Dialect>:$scatter_dimension_numbers, ScatterDimensionNumbers:$scatter_dimension_numbers,
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted, DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
DefaultValuedAttr<BoolAttr, "false">:$unique_indices DefaultValuedAttr<BoolAttr, "false">:$unique_indices
); );
@ -699,7 +697,7 @@ def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>,
Arg<LHLO_Buffer, "", [MemWrite]>:$output, Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$replica_groups, I64ElementsAttr:$replica_groups,
DefaultValuedAttr<BoolAttr, "false">:$constrain_layout, DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id, OptionalAttr<ChannelHandle>:$channel_id,
DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
); );
let regions = (region SizedRegion<1>:$computation); let regions = (region SizedRegion<1>:$computation);
@ -712,7 +710,7 @@ def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>,
Arg<LHLO_Buffer, "", [MemRead]>:$operand, Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output, Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$source_target_pairs, I64ElementsAttr:$source_target_pairs,
OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id OptionalAttr<ChannelHandle>:$channel_id
); );
} }

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/TypeUtilities.h"
namespace mlir { namespace mlir {
namespace lmhlo { namespace lmhlo {
@ -96,7 +97,7 @@ template <typename SupportedType, typename StdScalarOp, typename... Args>
struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> { struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
Value operator()(Location loc, ArrayRef<Type> result_types, Value operator()(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) { ArrayRef<Value> args, OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = getElementTypeOrSelf(args.front().getType());
if (element_type.isa<SupportedType>()) { if (element_type.isa<SupportedType>()) {
return b->template create<StdScalarOp>(loc, result_types, args, return b->template create<StdScalarOp>(loc, result_types, args,
mlir::None); mlir::None);
@ -120,7 +121,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
ArrayRef<Type> result_types, ArrayRef<Type> result_types,
ArrayRef<Value> args, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = getElementTypeOrSelf(args.front().getType());
if (element_type.isa<FloatType>()) { if (element_type.isa<FloatType>()) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}( return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
loc, result_types, args, b); loc, result_types, args, b);
@ -130,8 +131,11 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
Value lhs = args[0]; Value lhs = args[0];
auto integer_type = element_type.dyn_cast<IntegerType>(); auto integer_type = element_type.dyn_cast<IntegerType>();
auto zero_intval = Value zero_intval =
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
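// For vector operands, splat the scalar zero constant so that the compare and
// subtract below remain elementwise over the whole vector.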
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval);
}
auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge, auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge,
lhs, zero_intval); lhs, zero_intval);
auto neg_val = b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs); auto neg_val = b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
@ -196,7 +200,7 @@ inline Value MapCompareOpToStdScalarOp(Location loc,
ArrayRef<Value> args, OpBuilder* b) { ArrayRef<Value> args, OpBuilder* b) {
const auto& lhs = args[0]; const auto& lhs = args[0];
const auto& rhs = args[1]; const auto& rhs = args[1];
Type element_type = lhs.getType(); Type element_type = getElementTypeOrSelf(lhs.getType());
if (element_type.isSignlessInteger()) { if (element_type.isSignlessInteger()) {
Optional<CmpIPredicate> predicate = Optional<CmpIPredicate> predicate =
getCmpPredicate<CmpIPredicate>(comparison_direction); getCmpPredicate<CmpIPredicate>(comparison_direction);
@ -268,8 +272,8 @@ template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type sourceType = args.front().getType(); Type sourceType = getElementTypeOrSelf(args.front().getType());
Type targetType = result_types.front(); Type targetType = getElementTypeOrSelf(result_types.front());
if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) { if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) {
return b->create<mlir::SIToFPOp>(loc, result_types, args, mlir::None); return b->create<mlir::SIToFPOp>(loc, result_types, args, mlir::None);
@ -390,7 +394,7 @@ struct CompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
static Value map(Location loc, StringRef comparison_direction, static Value map(Location loc, StringRef comparison_direction,
ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = getElementTypeOrSelf(args.front().getType());
if (element_type.isa<SupportedType>()) { if (element_type.isa<SupportedType>()) {
auto predicate = getCmpPredicate<Predicate>(comparison_direction); auto predicate = getCmpPredicate<Predicate>(comparison_direction);
assert(predicate.hasValue() && "expected valid comparison direction"); assert(predicate.hasValue() && "expected valid comparison direction");
@ -439,7 +443,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
ArrayRef<Type> result_types, ArrayRef<Type> result_types,
ArrayRef<Value> args, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = getElementTypeOrSelf(args.front().getType());
if (element_type.isa<FloatType>()) { if (element_type.isa<FloatType>()) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}( return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
loc, result_types, args, b); loc, result_types, args, b);
@ -449,8 +453,11 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
Value lhs = args[0]; Value lhs = args[0];
auto integer_type = element_type.dyn_cast<IntegerType>(); auto integer_type = element_type.dyn_cast<IntegerType>();
auto zero_intval = Value zero_intval =
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval);
}
return b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs); return b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
} }
return nullptr; return nullptr;
@ -461,11 +468,14 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
ArrayRef<Type> result_types, ArrayRef<Type> result_types,
ArrayRef<Value> args, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = getElementTypeOrSelf(args.front().getType());
if (auto integer_type = element_type.dyn_cast<IntegerType>()) { if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
// lmhlo.not(x) -> x ^ -1 // lmhlo.not(x) -> x ^ -1
auto all_ones = Value all_ones =
b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth()); b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth());
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
all_ones = b->create<::mlir::SplatOp>(loc, vec_type, all_ones);
}
return b->create<::mlir::XOrOp>(loc, all_ones, args[0]); return b->create<::mlir::XOrOp>(loc, all_ones, args[0]);
} }
return nullptr; return nullptr;
@ -493,26 +503,35 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
ArrayRef<Type> result_types, ArrayRef<Type> result_types,
ArrayRef<Value> args, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = getElementTypeOrSelf(args.front().getType());
if (auto float_type = element_type.dyn_cast<FloatType>()) { if (auto float_type = element_type.dyn_cast<FloatType>()) {
bool ignored; bool ignored;
APFloat one_apfloat(1.0f); APFloat one_apfloat(1.0f);
one_apfloat.convert(float_type.getFloatSemantics(), one_apfloat.convert(float_type.getFloatSemantics(),
APFloat::rmNearestTiesToEven, &ignored); APFloat::rmNearestTiesToEven, &ignored);
Value one = b->create<mlir::ConstantFloatOp>(loc, one_apfloat, float_type); Value one = b->create<mlir::ConstantFloatOp>(loc, one_apfloat, float_type);
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
one = b->create<::mlir::SplatOp>(loc, vec_type, one);
}
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]); return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
} else if (auto integer_type = element_type.dyn_cast<IntegerType>()) { } else if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
// sign(x) = x == 0 ? 0 : ((x s>> 31) | 1) // sign(x) = x == 0 ? 0 : ((x s>> 31) | 1)
Value zero = Value zero =
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
Value cmp =
b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero);
Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>( Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>(
loc, integer_type.getWidth() - 1, integer_type.getWidth()); loc, integer_type.getWidth() - 1, integer_type.getWidth());
Value ashr =
b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one);
Value one = Value one =
b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth()); b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth());
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
zero = b->create<::mlir::SplatOp>(loc, vec_type, zero);
bitwidth_minus_one =
b->create<::mlir::SplatOp>(loc, vec_type, bitwidth_minus_one);
one = b->create<::mlir::SplatOp>(loc, vec_type, one);
}
Value cmp =
b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero);
Value ashr =
b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one);
Value or_op = b->create<::mlir::OrOp>(loc, ashr, one); Value or_op = b->create<::mlir::OrOp>(loc, ashr, one);
return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op); return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op);
} }
@ -583,6 +602,27 @@ struct HloOpToStdScalarOp {
return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>( return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
op.getLoc(), comparison_direction, result_types, args, b); op.getLoc(), comparison_direction, result_types, args, b);
} }
// Implementation for LHLO ops except lmhlo::CompareOp.
template <typename LhloOpTy,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
std::false_type>::value>>
static Value map(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(loc, result_types, args, b);
}
// Implementation for lmhlo::CompareOp.
template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
LhloOpTy, lmhlo::CompareOp>::value>>
static Value map(Location loc, StringRef comparison_direction,
ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
loc, comparison_direction, result_types, args, b);
}
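// Hypothetical call site (sketch; `rewriter` and the result/arg values are
// assumed): a lowering pattern that already has a location can dispatch as
//   Value res = lmhlo::HloOpToStdScalarOp::map<lmhlo::AddOp>(
//       loc, result_types, args, &rewriter);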
}; };
} // namespace lmhlo } // namespace lmhlo

View File

@ -27,7 +27,6 @@ namespace mlir {
class LLVMTypeConverter; class LLVMTypeConverter;
class LowerToLLVMOptions; class LowerToLLVMOptions;
class OwningRewritePatternList; class OwningRewritePatternList;
class BufferAssignmentPlacer;
// Populates a collection of rewrite patterns to realize element-wise operations // Populates a collection of rewrite patterns to realize element-wise operations
// on ranked tensors where possible. // on ranked tensors where possible.
@ -56,9 +55,9 @@ void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
MLIRContext *ctx); MLIRContext *ctx);
// Collection of rewrite patterns for lowering of HLO to LHLO dialect. // Collection of rewrite patterns for lowering of HLO to LHLO dialect.
void populateHLOToLHLOConversionPattern( void populateHLOToLHLOConversionPattern(MLIRContext *context,
MLIRContext *context, BufferAssignmentTypeConverter *converter, BufferizeTypeConverter *converter,
OwningRewritePatternList *patterns); OwningRewritePatternList *patterns);
// Collection of rewrite patterns for lowering of HLO to Linalg dialect. // Collection of rewrite patterns for lowering of HLO to Linalg dialect.
void populateHLOToLinalgConversionPattern(MLIRContext *context, void populateHLOToLinalgConversionPattern(MLIRContext *context,

View File

@ -43,6 +43,7 @@ add_mlir_library(MhloInferFusibilityOpInterface
add_mlir_dialect_library(MhloDialect add_mlir_dialect_library(MhloDialect
hlo_ops.cc hlo_ops.cc
hlo_ops_base_structs.cc
DEPENDS DEPENDS
MLIRhlo_opsIncGen MLIRhlo_opsIncGen
@ -66,6 +67,15 @@ add_mlir_dialect_library(LmhloDialect
) )
target_link_libraries(LmhloDialect PUBLIC MLIRIR) target_link_libraries(LmhloDialect PUBLIC MLIRIR)
add_mlir_dialect_library(LmhloGPUDialect
lhlo_gpu_ops.cc
lhlo_gpu_ops_structs.cc
DEPENDS
MLIRlhlo_gpu_opsIncGen
)
target_link_libraries(LmhloGPUDialect PUBLIC MLIRIR)
add_mlir_dialect_library(MhloRegisterDialects add_mlir_dialect_library(MhloRegisterDialects
init.cc init.cc
@ -73,10 +83,12 @@ DEPENDS
MLIRchlo_opsIncGen MLIRchlo_opsIncGen
MLIRhlo_opsIncGen MLIRhlo_opsIncGen
MLIRlhlo_opsIncGen MLIRlhlo_opsIncGen
MLIRlhlo_gpu_opsIncGen
) )
target_link_libraries(MhloRegisterDialects target_link_libraries(MhloRegisterDialects
PUBLIC PUBLIC
ChloDialect ChloDialect
MhloDialect MhloDialect
LmhloDialect LmhloDialect
LmhloGPUDialect
) )

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "llvm/ADT/APFloat.h" #include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h" #include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
@ -62,8 +63,6 @@ namespace mlir {
#include "hlo_patterns.cc.inc" #include "hlo_patterns.cc.inc"
} // namespace mlir } // namespace mlir
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.cc.inc"
namespace mlir { namespace mlir {
namespace mhlo { namespace mhlo {
@ -1054,6 +1053,9 @@ LogicalResult ConcatenateOp::inferReturnTypes(
return success(); return success();
} }
if (first_type.getRank() == 0)
return emitOptionalError(location, "rank-0 values cannot be concatenated");
auto out_shape = llvm::to_vector<6>(first_type.getShape()); auto out_shape = llvm::to_vector<6>(first_type.getShape());
// Determine what the non-concatenate dimensions should be. // Determine what the non-concatenate dimensions should be.
@ -1785,6 +1787,61 @@ static LogicalResult Verify(PadOp op) {
return success(); return success();
} }
OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
// If all padding is zero then it is an identity pad.
auto is_zero = [](const APInt& i) { return i == 0; };
if (llvm::all_of(edge_padding_low().getIntValues(), is_zero) &&
llvm::all_of(edge_padding_high().getIntValues(), is_zero) &&
llvm::all_of(interior_padding().getIntValues(), is_zero))
return operand();
// If any padding is negative then it isn't supported by the folder (yet).
auto is_negative = [](const APInt& i) { return i.slt(0); };
if (llvm::any_of(edge_padding_low().getIntValues(), is_negative) ||
llvm::any_of(edge_padding_high().getIntValues(), is_negative) ||
llvm::any_of(interior_padding().getIntValues(), is_negative))
return {};
DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
DenseElementsAttr padding = operands[1].dyn_cast_or_null<DenseElementsAttr>();
RankedTensorType return_type = getType().dyn_cast_or_null<RankedTensorType>();
if (!input || !input.getType().hasRank() || !padding || !return_type ||
!return_type.hasStaticShape())
return {};
// Fill the full result tensor with the padding value.
llvm::SmallVector<Attribute, 4> result(return_type.getNumElements(),
padding.getValue({}));
auto next_index = [](llvm::SmallVector<uint64_t, 8>& index,
llvm::ArrayRef<int64_t> shape) {
for (int64_t i = index.size() - 1; i >= 0; --i) {
++index[i];
if (index[i] < shape[i]) return true;
index[i] = 0;
}
return false;
};
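// Index mapping sketch: along each dimension, input element i lands at output
// position edge_padding_low + i * (interior_padding + 1). For an assumed 1-D
// example with low = 1, high = 1, interior = 1, padding [a, b, c] with value x
// yields [x, a, x, b, x, c, x].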
// Iterate over all elements of the input tensor and copy each one to its
// correct location in the output tensor.
llvm::SmallVector<uint64_t, 8> index(input.getType().getRank(), 0);
do {
uint64_t linear_index = 0;
uint64_t linear_index_multiplier = 1;
for (int64_t i = index.size() - 1; i >= 0; --i) {
linear_index +=
(edge_padding_low().getValue<int64_t>({uint64_t(i)}) +
index[i] *
(interior_padding().getValue<int64_t>({uint64_t(i)}) + 1)) *
linear_index_multiplier;
linear_index_multiplier *= return_type.getShape()[i];
}
result[linear_index] = input.getValue(index);
} while (next_index(index, input.getType().getShape()));
return DenseElementsAttr::get(return_type, result);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ReshapeOp // ReshapeOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1931,6 +1988,14 @@ static Attribute UnaryFolder(Op* op, ArrayRef<Attribute> attrs) {
return DenseElementsAttr::get(type, values); return DenseElementsAttr::get(type, values);
} }
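// RoundOp folding note: roundToIntegral with NearestTiesToAway rounds halfway
// cases away from zero, e.g. 2.5 -> 3.0 and -2.5 -> -3.0, matching HLO's
// round-nearest-afz semantics.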
struct round {
APFloat operator()(const APFloat& f) {
APFloat r = f;
r.roundToIntegral(llvm::RoundingMode::NearestTiesToAway);
return r;
}
};
#define UNARY_FOLDER(Op, Func) \ #define UNARY_FOLDER(Op, Func) \
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \ OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \ if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
@ -1940,7 +2005,15 @@ static Attribute UnaryFolder(Op* op, ArrayRef<Attribute> attrs) {
return {}; \ return {}; \
} }
#define UNARY_FOLDER_FLOAT(Op, Func) \
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
return UnaryFolder<Op, FloatType, APFloat, Func>(this, attrs); \
return {}; \
}
UNARY_FOLDER(NegOp, std::negate); UNARY_FOLDER(NegOp, std::negate);
UNARY_FOLDER_FLOAT(RoundOp, round);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// BinaryOps // BinaryOps
@ -2645,6 +2718,145 @@ OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands) {
return {}; return {};
} }
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
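// Evaluates the given region on constant inputs by folding each operation in
// turn; returns the operands of the terminating return op, or an empty vector
// if any operation fails to fold to an attribute.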
llvm::SmallVector<Attribute, 4> evaluateMhloRegion(Region& region,
ArrayRef<Attribute> inputs) {
if (region.getNumArguments() != inputs.size()) return {};
llvm::DenseMap<Value, Attribute> values;
values.reserve(region.getNumArguments());
for (auto it : llvm::zip(region.getArguments(), inputs)) {
values.try_emplace(std::get<0>(it), std::get<1>(it));
}
for (auto& op : region.getOps()) {
llvm::SmallVector<Attribute, 4> operand_attrs;
for (auto& operand : op.getOpOperands()) {
operand_attrs.push_back(values.lookup(operand.get()));
}
if (isa<ReturnOp>(op)) return operand_attrs;
llvm::SmallVector<OpFoldResult, 4> results;
if (failed(op.fold(operand_attrs, results))) return {};
for (auto it : llvm::zip(op.getResults(), results)) {
if (!std::get<1>(it).is<Attribute>()) return {};
values.insert({std::get<0>(it), std::get<1>(it).get<Attribute>()});
}
}
return {};
}
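// Folding example (sketch, values assumed): with operand [0, 0, 0, 0],
// scatter_indices [[1], [3]], updates [10, 20], update_window_dims = [],
// inserted_window_dims = [0], scatter_dims_to_operand_dims = [0],
// index_vector_dim = 1 and an addition region, the fold below produces
// [0, 10, 0, 20].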
OpFoldResult ScatterOp::fold(ArrayRef<Attribute> operands) {
auto base = operands[0].dyn_cast_or_null<DenseElementsAttr>();
auto index = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
auto update = operands[2].dyn_cast_or_null<DenseElementsAttr>();
if (!base || !index || !update) return {};
auto base_type = base.getType().dyn_cast<RankedTensorType>();
auto index_type = index.getType().dyn_cast<RankedTensorType>();
auto update_type = update.getType().dyn_cast<RankedTensorType>();
if (!base_type || !index_type || !update_type) return {};
// Add the virtual trailing dimension of size 1 if index_vector_dim equals to
// index_type.rank.
const int64_t index_vector_dim =
scatter_dimension_numbers().index_vector_dim().getInt();
if (index_vector_dim == index_type.getRank()) {
auto index_shape = index_type.getShape().vec();
index_shape.push_back(1);
index_type =
RankedTensorType::get(index_shape, index_type.getElementType());
index = index.reshape(index_type).cast<DenseIntElementsAttr>();
}
// Increment the multi-dimensional index vector based on the limits for each
// dimension specified by shape; returns false once the index rolls around,
// true otherwise.
auto next_index = [](llvm::SmallVector<uint64_t, 8>& index,
llvm::ArrayRef<int64_t> shape) {
for (int64_t i = index.size() - 1; i >= 0; --i) {
++index[i];
if (index[i] < shape[i]) return true;
index[i] = 0;
}
return false;
};
// Iterate over all elements of the update tensor, then find the corresponding
// value in the indices tensor to determine which location we have to update
// in the base/result tensor.
llvm::SmallVector<Attribute, 8> results(base.getValues<Attribute>());
llvm::SmallVector<uint64_t, 8> update_index(update_type.getRank(), 0);
llvm::SmallVector<uint64_t, 8> index_index;
index_index.reserve(index_type.getRank());
llvm::SmallVector<uint64_t, 8> base_index;
base_index.reserve(base_type.getRank());
do {
// Compute the index for the slice of the indices tensor for this update
// value.
index_index.clear();
if (index_vector_dim == 0) index_index.push_back(0);
for (int64_t i = 0; i < update_index.size(); ++i) {
if (llvm::count(scatter_dimension_numbers().update_window_dims(), i) == 0)
index_index.push_back(update_index[i]);
if (index_index.size() == index_vector_dim) index_index.push_back(0);
}
// Compute the index for the given update value in the base tensor.
base_index.assign(base_type.getRank(), 0);
uint64_t index_count = index_type.getShape()[index_vector_dim];
for (uint64_t i = 0; i < index_count; ++i) {
uint64_t operand_dim = scatter_dimension_numbers()
.scatter_dims_to_operand_dims()
.getValue<APInt>({i})
.getSExtValue();
index_index[index_vector_dim] = i;
base_index[operand_dim] +=
index.getValue<APInt>(index_index).getSExtValue();
}
uint64_t update_window_dim_index = 0;
for (uint64_t i = 0; i < base_index.size(); ++i) {
if (llvm::count(scatter_dimension_numbers().inserted_window_dims(), i))
continue;
base_index[i] +=
update_index[scatter_dimension_numbers()
.update_window_dims()
.getValue<APInt>({update_window_dim_index})
.getSExtValue()];
update_window_dim_index++;
}
// Compute the linear index for the index into the base tensor.
int64_t linear_base_index = 0;
int64_t linear_base_index_multiplier = 1;
for (int64_t i = base_index.size() - 1; i >= 0; --i) {
// Out-of-bounds indices have backend-specific behavior, so avoid folding them.
if (base_index[i] < 0 || base_index[i] >= base_type.getShape()[i])
return {};
linear_base_index += base_index[i] * linear_base_index_multiplier;
linear_base_index_multiplier *= base_type.getShape()[i];
}
// Evaluate update computation and update the value with the newly computed
// attribute in the base tensor.
auto lhs = DenseElementsAttr::get(
RankedTensorType::get({}, base_type.getElementType()),
results[linear_base_index]);
auto rhs = DenseElementsAttr::get(
RankedTensorType::get({}, base_type.getElementType()),
update.getValue<Attribute>(update_index));
auto new_value = evaluateMhloRegion(update_computation(), {lhs, rhs});
if (new_value.size() != 1 || !new_value[0]) return {};
results[linear_base_index] =
new_value[0].cast<DenseElementsAttr>().getValue<Attribute>({});
} while (next_index(update_index, update_type.getShape()));
return DenseElementsAttr::get(base_type, results);
}
} // namespace mhlo } // namespace mhlo
} // namespace mlir } // namespace mlir

View File

@ -0,0 +1,18 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.cc.inc"

View File

@ -18,15 +18,13 @@ limitations under the License.
include "mlir/Dialect/Shape/IR/ShapeOps.td" include "mlir/Dialect/Shape/IR/ShapeOps.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>;
// Canonicalization patterns. // Canonicalization patterns.
def DynamicBroadcastToOwnShape_1 : Pat< def DynamicBroadcastToOwnShape_1 : Pat<
(HLO_DynamicBroadcastInDimOp:$op $arg0, (HLO_DynamicBroadcastInDimOp:$op $x,
(Shape_ToExtentTensorOp (Shape_ShapeOfOp $arg1)), $attr), (Shape_ToExtentTensorOp (Shape_ShapeOfOp $x)), $attr),
(replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>; (replaceWithValue $x)>;
def DynamicBroadcastToOwnShape_2 : Pat< def DynamicBroadcastToOwnShape_2 : Pat<
(HLO_DynamicBroadcastInDimOp:$op $arg0, (Shape_ShapeOfOp $arg1), $attr), (HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr),
(replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>; (replaceWithValue $x)>;

View File

@ -15,13 +15,15 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/register.h" #include "mlir-hlo/Dialect/mhlo/IR/register.h"
void mlir::mhlo::registerAllMhloDialects(mlir::DialectRegistry &registry) { void mlir::mhlo::registerAllMhloDialects(mlir::DialectRegistry &registry) {
// clang-format off // clang-format off
registry.insert<mlir::chlo::HloClientDialect, registry.insert<mlir::chlo::HloClientDialect,
mlir::mhlo::MhloDialect,
mlir::lmhlo::LmhloDialect, mlir::lmhlo::LmhloDialect,
mlir::mhlo::MhloDialect>(); mlir::lmhlo_gpu::LmhloGpuDialect>();
// clang-format on // clang-format on
} }

View File

@ -0,0 +1,64 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations used in the LMHLO GPU dialect.
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
namespace mlir {
namespace lmhlo_gpu {
LmhloGpuDialect::LmhloGpuDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context, TypeID::get<LmhloGpuDialect>()) {
addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc"
>();
}
// TODO(jurahul): Add verification for operand shapes and ranks.
} // namespace lmhlo_gpu
} // namespace mlir
#define GET_OP_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc"

View File

@ -0,0 +1,18 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc"

View File

@ -29,7 +29,6 @@ limitations under the License.
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"

View File

@ -0,0 +1,17 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h"

View File

@ -42,7 +42,7 @@ namespace mhlo {
namespace { namespace {
template <typename T> template <typename T>
using BaseOpConversion = BufferAssignmentOpConversionPattern<T>; using BaseOpConversion = OpConversionPattern<T>;
Value InsertDynamicAllocAndDealloc(Location loc, Value result, Value InsertDynamicAllocAndDealloc(Location loc, Value result,
Value shape_operand, Value shape_operand,
@ -126,6 +126,60 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
} }
}; };
// This specialization exists so that LMHLO's Dot can be given a specific set of
// dimension numbers, when lowering from MHLO's Dot, which does not have
// dimension numbers (it uses DotGeneral for this generalized notion of dot
// products). When these two dialects are in sync with respect to the
// Dot/DotGeneral issue, this specialization should be deleted.
template <>
class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
public:
using BaseOpConversion<mhlo::DotOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
mhlo::DotOp hloOp, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
Operation* op = hloOp.getOperation();
const auto& original_results = op->getResults();
SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
for (auto result : llvm::enumerate(original_results)) {
RankedTensorType resultType =
result.value().getType().dyn_cast<RankedTensorType>();
if (!resultType) {
return failure();
}
if (resultType.hasStaticShape()) {
buffer_args.push_back(
InsertAlloc(op->getLoc(), result.value(), &rewriter));
} else {
SmallVector<Value, 1> results_shape;
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
if (!shape_type_op) return failure();
if (failed(
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
return failure();
buffer_args.push_back(InsertDynamicAllocAndDealloc(
op->getLoc(), result.value(), results_shape.front(), &rewriter));
}
}
// TODO(silvasean): Move this helper to MLIR core.
auto make_elements_attr = [&rewriter](ArrayRef<int64_t> integers) {
auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
rewriter.getIntegerType(64));
return DenseIntElementsAttr::get(type, integers);
};
auto dotOp = rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None,
buffer_args, op->getAttrs());
// MHLO's Dot uses rank-2 operands, of the form ([N, M], [M, O]) -> [N, O].
auto dimension_numbers = mhlo::DotDimensionNumbers::get(
make_elements_attr({}), make_elements_attr({}), make_elements_attr({1}),
make_elements_attr({0}), rewriter.getContext());
dotOp.dot_dimension_numbersAttr(dimension_numbers);
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
return success();
}
};
struct HloToLhloDynamicBroadcastInDimOpConverter struct HloToLhloDynamicBroadcastInDimOpConverter
: public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> { : public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
public: public:
@ -236,6 +290,43 @@ struct HloToLhloDynamicReshapeConverter
} }
}; };
struct HloToLhloDotGeneralOpConverter
: public BaseOpConversion<mhlo::DotGeneralOp> {
using BaseOpConversion<mhlo::DotGeneralOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
mhlo::DotGeneralOp dotGeneralOp, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
Operation* op = dotGeneralOp.getOperation();
if (op->getResults().empty()) return failure();
OpResult result = op->getResults()[0];
RankedTensorType resultType = result.getType().dyn_cast<RankedTensorType>();
if (!resultType) return failure();
// The third buffer argument will be filled with what used to be the return
// type of the DotGeneral.
if (operands.size() != 2) return failure();
std::array<Value, 3> bufferArgs = {operands[0], operands[1], {}};
if (resultType.hasStaticShape()) {
bufferArgs[2] = InsertAlloc(op->getLoc(), result, &rewriter);
} else {
SmallVector<Value, 1> results_shape;
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
if (!shape_type_op) return failure();
if (failed(shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
return failure();
bufferArgs[2] = InsertDynamicAllocAndDealloc(
op->getLoc(), result, results_shape.front(), &rewriter);
}
rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None, bufferArgs,
op->getAttrs());
rewriter.replaceOp(op, bufferArgs[2]);
return success();
}
};
struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> { struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
public: public:
using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion; using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;
@ -433,7 +524,7 @@ struct HloLegalizeToLhlo
target.addLegalOp<TensorFromElementsOp>(); target.addLegalOp<TensorFromElementsOp>();
target.addIllegalDialect<mhlo::MhloDialect>(); target.addIllegalDialect<mhlo::MhloDialect>();
BufferAssignmentTypeConverter converter; BufferizeTypeConverter converter;
auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); }; auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
auto inputs = op.getType().getInputs(); auto inputs = op.getType().getInputs();
@ -456,16 +547,16 @@ struct HloLegalizeToLhlo
}); });
auto kind = results_escape_function auto kind = results_escape_function
? BufferAssignmentTypeConverter::KeepAsFunctionResult ? BufferizeTypeConverter::KeepAsFunctionResult
: BufferAssignmentTypeConverter::AppendToArgumentsList; : BufferizeTypeConverter::AppendToArgumentsList;
converter.setResultConversionKind<UnrankedTensorType, UnrankedMemRefType>( converter.setResultConversionKind<UnrankedTensorType, UnrankedMemRefType>(
kind); kind);
converter.setResultConversionKind<RankedTensorType, MemRefType>(kind); converter.setResultConversionKind<RankedTensorType, MemRefType>(kind);
populateHLOToLHLOConversionPattern(&context, &converter, &patterns); populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
populateWithBufferAssignmentOpConversionPatterns< populateWithBufferizeOpConversionPatterns<mlir::ReturnOp, mlir::ReturnOp,
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, converter, lmhlo::CopyOp>(
patterns); &context, converter, patterns);
populateShapeTypeConversionPatterns(&context, converter, patterns); populateShapeTypeConversionPatterns(&context, converter, patterns);
if (failed(applyPartialConversion(getOperation(), target, patterns))) if (failed(applyPartialConversion(getOperation(), target, patterns)))
signalPassFailure(); signalPassFailure();
@ -480,11 +571,12 @@ struct HloLegalizeToLhlo
}; };
} // namespace } // namespace
void populateHLOToLHLOConversionPattern( void populateHLOToLHLOConversionPattern(MLIRContext* context,
MLIRContext* context, BufferAssignmentTypeConverter* converter, BufferizeTypeConverter* converter,
OwningRewritePatternList* patterns) { OwningRewritePatternList* patterns) {
// clang-format off // clang-format off
patterns->insert< patterns->insert<
HloToLhloDotGeneralOpConverter,
HloToLhloDynamicBroadcastInDimOpConverter, HloToLhloDynamicBroadcastInDimOpConverter,
HloToLhloDynamicReshapeConverter, HloToLhloDynamicReshapeConverter,
HloToLhloOpConverter<mhlo::AbsOp>, HloToLhloOpConverter<mhlo::AbsOp>,
@ -531,7 +623,7 @@ void populateHLOToLHLOConversionPattern(
HloToLhloReturnOpConverter, HloToLhloReturnOpConverter,
HloToLhloTensorLoadOpConverter, HloToLhloTensorLoadOpConverter,
HloToLhloTensorStoreOpConverter HloToLhloTensorStoreOpConverter
>(context, *converter); >(context);
// clang-format on // clang-format on
} }

View File

@ -192,7 +192,7 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
lmhlo::ConvOp op, ArrayRef<Value> args, lmhlo::ConvOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
// Check validity of dimension information. // Check validity of dimension information.
if (const lmhlo::ConvDimensionNumbers& dimensionNumbers = if (const mhlo::ConvDimensionNumbers& dimensionNumbers =
op.dimension_numbers()) { op.dimension_numbers()) {
const int inputSpatialRank = const int inputSpatialRank =
llvm::size(dimensionNumbers.input_spatial_dimensions()); llvm::size(dimensionNumbers.input_spatial_dimensions());

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/FoldUtils.h"
@ -73,6 +74,24 @@ class LhloFuseLinalgPass
result_buffers.insert(operand); result_buffers.insert(operand);
} }
} }
    // Resolve aliasing operations (like casts) on the results to also identify
    // the buffers they are views of. This only handles escaping results.
// TODO(herhut): Use BufferizeAliasAnalysis for this.
llvm::SmallVector<Value, 4> worklist(result_buffers.begin(),
result_buffers.end());
while (!worklist.empty()) {
Value result = worklist.pop_back_val();
auto definingOp = result.getDefiningOp();
if (!definingOp) {
continue;
}
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(definingOp)) {
auto alias = viewLike.getViewSource();
if (result_buffers.insert(alias).second) {
worklist.push_back(alias);
}
}
}
MLIRContext* ctx = func.getContext(); MLIRContext* ctx = func.getContext();
OpBuilder b(func); OpBuilder b(func);
OperationFolder folder(ctx); OperationFolder folder(ctx);

View File

@ -59,6 +59,20 @@ struct DotOpConverter : public OpRewritePattern<DotOp> {
return failure(); return failure();
} }
// We don't currently support batching dimensions, or multiple contraction
// dimensions.
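    // In other words, only the canonical 2-D matrix multiply (lhs contracting
    // dimension 1 against rhs contracting dimension 0) is lowered here.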
mhlo::DotDimensionNumbers dot_dimension_numbers =
op.dot_dimension_numbers();
if (dot_dimension_numbers.lhs_batching_dimensions().size() > 0 ||
dot_dimension_numbers.rhs_batching_dimensions().size() > 0)
return failure();
if (dot_dimension_numbers.lhs_contracting_dimensions().size() != 1 ||
*dot_dimension_numbers.lhs_contracting_dimensions().begin() != 1 ||
dot_dimension_numbers.rhs_contracting_dimensions().size() != 1 ||
*dot_dimension_numbers.rhs_contracting_dimensions().begin() != 0) {
return failure();
}
LogicalResult map_status = success(); LogicalResult map_status = success();
auto body_builder = [&](OpBuilder& builder, Location loc, ValueRange ivs) { auto body_builder = [&](OpBuilder& builder, Location loc, ValueRange ivs) {
SmallVector<Value, 2> lhs_indices{ivs[0], ivs[2]}, SmallVector<Value, 2> lhs_indices{ivs[0], ivs[2]},

View File

@ -81,6 +81,14 @@ func @remainder_fold_float() -> tensor<4xf32> {
return %2 : tensor<4xf32> return %2 : tensor<4xf32>
} }
// CHECK-LABEL: round_fold
func @round_fold() -> tensor<4xf32> {
%0 = mhlo.constant dense<[-1.5, -0.1, 1.1, 2.5]> : tensor<4xf32>
%1 = "mhlo.round_nearest_afz"(%0) : (tensor<4xf32>) -> tensor<4xf32>
return %1 : tensor<4xf32>
// CHECK: mhlo.constant dense<[-2.000000e+00, -0.000000e+00, 1.000000e+00, 3.000000e+00]>
}
// CHECK-LABEL: max_scalar_fold // CHECK-LABEL: max_scalar_fold
func @max_scalar_fold() -> tensor<4xi64> { func @max_scalar_fold() -> tensor<4xi64> {
%0 = mhlo.constant dense<7> : tensor<4xi64> %0 = mhlo.constant dense<7> : tensor<4xi64>
@ -1167,3 +1175,291 @@ func @not_fold_sqrt_neg_constants() -> tensor<4xf32> {
// CHECK: mhlo.sqrt // CHECK: mhlo.sqrt
return %1 : tensor<4xf32> return %1 : tensor<4xf32>
} }
// CHECK-LABEL: @tensor_flow_scatter_v1_update
func @tensor_flow_scatter_v1_update() -> tensor<3x3xi32> {
%0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
%1 = constant dense<[0, 2]> : tensor<2xi32>
%2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32>
%3 = "mhlo.scatter"(%0, %1, %2) ( {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
"mhlo.return"(%arg1) : (tensor<i32>) -> ()
}) {indices_are_sorted = false,
scatter_dimension_numbers = {
index_vector_dim = 1 : i64,
inserted_window_dims = dense<0> : tensor<1xi64>,
scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
update_window_dims = dense<[1]> : tensor<1xi64>
},
unique_indices = false
} : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32>
return %3 : tensor<3x3xi32>
// CHECK: mhlo.constant dense<[
// CHECK-SAME: [10, 20, 30], [4, 5, 6], [70, 80, 90]
// CHECK-SAME: ]> : tensor<3x3xi32>
}
// CHECK-LABEL: @tensor_flow_scatter_v2_update
func @tensor_flow_scatter_v2_update() -> tensor<3x3xi32> {
%0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
%1 = constant dense<[0, 2]> : tensor<2xi32>
%2 = constant dense<[[10, 30], [40, 60], [70, 90]]> : tensor<3x2xi32>
%3 = "mhlo.scatter"(%0, %1, %2) ( {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
"mhlo.return"(%arg1) : (tensor<i32>) -> ()
}) {indices_are_sorted = false,
scatter_dimension_numbers = {
index_vector_dim = 1 : i64,
inserted_window_dims = dense<1> : tensor<1xi64>,
scatter_dims_to_operand_dims = dense<1> : tensor<1xi64>,
update_window_dims = dense<[0]> : tensor<1xi64>
},
unique_indices = false
} : (tensor<3x3xi32>, tensor<2xi32>, tensor<3x2xi32>) -> tensor<3x3xi32>
return %3 : tensor<3x3xi32>
// CHECK: mhlo.constant dense<[
// CHECK-SAME: [10, 2, 30], [40, 5, 60], [70, 8, 90]
// CHECK-SAME: ]> : tensor<3x3xi32>
}
// CHECK-LABEL: @tensor_flow_scatter_add
func @tensor_flow_scatter_add() -> tensor<3x3xi32> {
%0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
%1 = constant dense<[0, 2]> : tensor<2xi32>
%2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32>
%3 = "mhlo.scatter"(%0, %1, %2) ( {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%4 = "mhlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
"mhlo.return"(%4) : (tensor<i32>) -> ()
}) {indices_are_sorted = false,
scatter_dimension_numbers = {
index_vector_dim = 1 : i64,
inserted_window_dims = dense<0> : tensor<1xi64>,
scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
update_window_dims = dense<[1]> : tensor<1xi64>
},
unique_indices = false
} : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32>
return %3 : tensor<3x3xi32>
// CHECK: mhlo.constant dense<[
// CHECK-SAME: [11, 22, 33], [4, 5, 6], [77, 88, 99]
// CHECK-SAME: ]> : tensor<3x3xi32>
}
// CHECK-LABEL: @tensor_flow_scatter_repeated
func @tensor_flow_scatter_repeated() -> tensor<3x3xi32> {
%0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
%1 = constant dense<[1, 1]> : tensor<2xi32>
%2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32>
%3 = "mhlo.scatter"(%0, %1, %2) ( {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%4 = "mhlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
"mhlo.return"(%4) : (tensor<i32>) -> ()
}) {indices_are_sorted = false,
scatter_dimension_numbers = {
index_vector_dim = 1 : i64,
inserted_window_dims = dense<0> : tensor<1xi64>,
scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
update_window_dims = dense<[1]> : tensor<1xi64>
},
unique_indices = false
} : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32>
return %3 : tensor<3x3xi32>
// CHECK: mhlo.constant dense<[
// CHECK-SAME: [1, 2, 3], [84, 105, 126], [7, 8, 9]
// CHECK-SAME: ]> : tensor<3x3xi32>
}
// CHECK-LABEL: @tensor_flow_scatter_multiple_batch
func @tensor_flow_scatter_multiple_batch() -> tensor<3x3xi32> {
%0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
%1 = constant dense<[[0, 2], [2, 1]]> : tensor<2x2xi32>
%2 = constant dense<[[[10, 30], [40, 60], [70, 90]], [[5, 5], [5, 5], [5, 5]]]> : tensor<2x3x2xi32>
%3 = "mhlo.scatter"(%0, %1, %2) ( {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%4 = "mhlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
"mhlo.return"(%4) : (tensor<i32>) -> ()
}) {indices_are_sorted = false,
scatter_dimension_numbers = {
index_vector_dim = 2 : i64,
inserted_window_dims = dense<1> : tensor<1xi64>,
scatter_dims_to_operand_dims = dense<1> : tensor<1xi64>,
update_window_dims = dense<[1]> : tensor<1xi64>
},
unique_indices = false
} : (tensor<3x3xi32>, tensor<2x2xi32>, tensor<2x3x2xi32>) -> tensor<3x3xi32>
return %3 : tensor<3x3xi32>
// CHECK: mhlo.constant dense<[
// CHECK-SAME: [11, 7, 38], [44, 10, 71], [77, 13, 104]
// CHECK-SAME: ]> : tensor<3x3xi32>
}
// CHECK-LABEL: @tensor_flow_scatter_nd
func @tensor_flow_scatter_nd() -> tensor<3x3x2xi32> {
%0 = constant dense<[[[-1, 1], [-2, 2], [-3, 3]], [[-4, 4], [-5, 5], [-6, 6]], [[-7, 7], [-8, 8], [-9, 9]]]> : tensor<3x3x2xi32>
%1 = constant dense<[[0, 0], [1, 0]]> : tensor<2x2xi32>
%2 = constant dense<[[-10, 10], [-40, 40]]> : tensor<2x2xi32>
%3 = "mhlo.scatter"(%0, %1, %2) ( {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
"mhlo.return"(%arg1) : (tensor<i32>) -> ()
}) {indices_are_sorted = false,
scatter_dimension_numbers = {
index_vector_dim = 1 : i64,
inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
update_window_dims = dense<[1]> : tensor<1xi64>
},
unique_indices = false
} : (tensor<3x3x2xi32>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<3x3x2xi32>
return %3 : tensor<3x3x2xi32>
// CHECK: mhlo.constant dense<[
// CHECK-SAME: [-10, 10], [-2, 2], [-3, 3]
// CHECK-SAME: [-40, 40], [-5, 5], [-6, 6]
// CHECK-SAME: [-7, 7], [-8, 8], [-9, 9]
// CHECK-SAME: ]> : tensor<3x3x2xi32>
}
// CHECK-LABEL: @tensor_flow_scatter_nd_index_vector
func @tensor_flow_scatter_nd_index_vector() -> tensor<3x3x2xi32> {
%0 = constant dense<[[[-1, 1], [-2, 2], [-3, 3]], [[-4, 4], [-5, 5], [-6, 6]], [[-7, 7], [-8, 8], [-9, 9]]]> : tensor<3x3x2xi32>
%1 = constant dense<[[0, 0], [1, 0]]> : tensor<2x2xi32>
%2 = constant dense<[[-10, 10], [-20, 20]]> : tensor<2x2xi32>
%3 = "mhlo.scatter"(%0, %1, %2) ( {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
"mhlo.return"(%arg1) : (tensor<i32>) -> ()
}) {indices_are_sorted = false,
scatter_dimension_numbers = {
index_vector_dim = 0 : i64,
inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
update_window_dims = dense<[1]> : tensor<1xi64>
},
unique_indices = false
} : (tensor<3x3x2xi32>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<3x3x2xi32>
return %3 : tensor<3x3x2xi32>
// CHECK: mhlo.constant dense<[
// CHECK-SAME: [-20, 20], [-10, 10], [-3, 3]
// CHECK-SAME: [-4, 4], [-5, 5], [-6, 6]
// CHECK-SAME: [-7, 7], [-8, 8], [-9, 9]
// CHECK-SAME: ]> : tensor<3x3x2xi32>
}
// CHECK-LABEL: @scatter_batch_dus
func @scatter_batch_dus() -> tensor<3x3xi32> {
%0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
%1 = constant dense<[[2, 1], [1, 1]]> : tensor<2x2xi32>
%2 = constant dense<[[[10]], [[20]]]> : tensor<2x1x1xi32>
%3 = "mhlo.scatter"(%0, %1, %2) ( {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
"mhlo.return"(%arg1) : (tensor<i32>) -> ()
}) {indices_are_sorted = false,
scatter_dimension_numbers = {
index_vector_dim = 0 : i64,
inserted_window_dims = dense<> : tensor<0xi64>,
scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
update_window_dims = dense<[1, 2]> : tensor<2xi64>
},
unique_indices = false
} : (tensor<3x3xi32>, tensor<2x2xi32>, tensor<2x1x1xi32>) -> tensor<3x3xi32>
return %3 : tensor<3x3xi32>
// CHECK: mhlo.constant dense<[
// CHECK-SAME: [1, 2, 3], [4, 20, 6], [7, 10, 9]
// CHECK-SAME: ]> : tensor<3x3xi32>
}
// CHECK-LABEL: @scatter_no_update_window_dim
func @scatter_no_update_window_dim() -> tensor<3xi32> {
%0 = constant dense<[0, 1, 2]> : tensor<3xi32>
%1 = constant dense<[[[0], [1]], [[2], [1]]]> : tensor<2x2x1xi32>
%2 = constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi32>
%3 = "mhlo.scatter"(%0, %1, %2) ( {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%4 = "mhlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
"mhlo.return"(%4) : (tensor<i32>) -> ()
}) {indices_are_sorted = false,
scatter_dimension_numbers = {
index_vector_dim = 2 : i64,
inserted_window_dims = dense<0> : tensor<1xi64>,
scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
update_window_dims = dense<> : tensor<0xi64>
},
unique_indices = false
} : (tensor<3xi32>, tensor<2x2x1xi32>, tensor<2x2xi32>) -> tensor<3xi32>
return %3 : tensor<3xi32>
// CHECK: mhlo.constant dense<[10, 61, 32]> : tensor<3xi32>
}
// CHECK-LABEL: @scatter_negative_index
func @scatter_negative_index() -> tensor<3x3xi32> {
%0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
%1 = constant dense<[0, -1]> : tensor<2xi32>
%2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32>
%3 = "mhlo.scatter"(%0, %1, %2) ( {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
"mhlo.return"(%arg1) : (tensor<i32>) -> ()
}) {indices_are_sorted = false,
scatter_dimension_numbers = {
index_vector_dim = 1 : i64,
inserted_window_dims = dense<0> : tensor<1xi64>,
scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
update_window_dims = dense<[1]> : tensor<1xi64>
},
unique_indices = false
} : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32>
return %3 : tensor<3x3xi32>
// CHECK: constant dense<[
// CHECK-SAME: [1, 2, 3], [4, 5, 6], [7, 8, 9]
// CHECK-SAME: ]> : tensor<3x3xi32>
// CHECK: "mhlo.scatter"
}
// CHECK-LABEL: @scatter_out_of_bound
func @scatter_out_of_bound() -> tensor<3x3xi32> {
%0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
%1 = constant dense<[1, 5]> : tensor<2xi32>
%2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32>
%3 = "mhlo.scatter"(%0, %1, %2) ( {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
"mhlo.return"(%arg1) : (tensor<i32>) -> ()
}) {indices_are_sorted = false,
scatter_dimension_numbers = {
index_vector_dim = 1 : i64,
inserted_window_dims = dense<0> : tensor<1xi64>,
scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
update_window_dims = dense<[1]> : tensor<1xi64>
},
unique_indices = false
} : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32>
return %3 : tensor<3x3xi32>
// CHECK: constant dense<[
// CHECK-SAME: [1, 2, 3], [4, 5, 6], [7, 8, 9]
// CHECK-SAME: ]> : tensor<3x3xi32>
// CHECK: "mhlo.scatter"
}
// CHECK-LABEL: @pad_identity_fold
func @pad_identity_fold(%arg0: tensor<5x7xf32>) -> tensor<5x7xf32> {
%0 = constant dense<0.0> : tensor<f32>
%1 = "mhlo.pad"(%arg0, %0) {
edge_padding_low = dense<0> : tensor<2xi64>,
edge_padding_high = dense<0> : tensor<2xi64>,
interior_padding = dense<0> : tensor<2xi64>
} : (tensor<5x7xf32>, tensor<f32>) -> tensor<5x7xf32>
return %1 : tensor<5x7xf32>
// CHECK: return %arg0 : tensor<5x7xf32>
}
// CHECK-LABEL: @pad_fold
func @pad_fold() -> tensor<4x5xi32> {
%0 = constant dense<[[2, 3], [4, 5]]> : tensor<2x2xi32>
%1 = constant dense<1> : tensor<i32>
%3 = "mhlo.pad"(%0, %1) {
edge_padding_low = dense<[1, 0]> : tensor<2xi64>,
edge_padding_high = dense<[1, 2]> : tensor<2xi64>,
interior_padding = dense<[0, 1]> : tensor<2xi64>
} : (tensor<2x2xi32>, tensor<i32>) -> tensor<4x5xi32>
return %3 : tensor<4x5xi32>
// CHECK: constant dense<[
// CHECK-SAME: [1, 1, 1, 1, 1], [2, 1, 3, 1, 1], [4, 1, 5, 1, 1], [1, 1, 1, 1, 1]
// CHECK-SAME: ]> : tensor<4x5xi32>
}

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement %s -o - | FileCheck %s // RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting -buffer-deallocation %s -o - | FileCheck %s
// CHECK-LABEL: func @func_op_unranked_arg_result // CHECK-LABEL: func @func_op_unranked_arg_result
func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {

View File

@ -1,5 +1,5 @@
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=PRE,BOTH %s // RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=PRE,BOTH %s
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=ESC,BOTH %s // RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=ESC,BOTH %s
// BOTH-LABEL: func @attrs // BOTH-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
@ -287,6 +287,28 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
// ----- // -----
// BOTH-LABEL: func @gather
func @gather(%operand: memref<13x7xf32>, %idxs: memref<5xi32>, %result: memref<5x7xf32>) {
%tensor_operand = tensor_load %operand : memref<13x7xf32>
%tensor_idxs = tensor_load %idxs : memref<5xi32>
%tensor_result =
"mhlo.gather"(%tensor_operand, %tensor_idxs)
{ dimension_numbers =
{ collapsed_slice_dims = dense<0> : tensor<1xi64>
, index_vector_dim = 1 : i64
, offset_dims = dense<1> : tensor<1xi64>
, start_index_map = dense<0> : tensor<1xi64> }
, indices_are_sorted = false
, name = "gather.71"
, slice_sizes = dense<[1, 7]> : tensor<2xi64> }
: (tensor<13x7xf32>, tensor<5xi32>) -> tensor<5x7xf32>
// BOTH: "lmhlo.gather"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<5x7xf32>
return
}
// -----
// BOTH-LABEL: func @imag_dyn // BOTH-LABEL: func @imag_dyn
func @imag_dyn(%operand: memref<?xcomplex<f32>>, %result: memref<?xf32>) { func @imag_dyn(%operand: memref<?xcomplex<f32>>, %result: memref<?xf32>) {
%tensor_operand = tensor_load %operand : memref<?xcomplex<f32>> %tensor_operand = tensor_load %operand : memref<?xcomplex<f32>>
@ -511,7 +533,13 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]]) // PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] // ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// BOTH-NEXT: %[[ALLOC:.*]] = alloc // BOTH-NEXT: %[[ALLOC:.*]] = alloc
// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () // BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
// dot_dimension_numbers = {
// lhs_batching_dimensions = dense<> : tensor<0xi64>,
// lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
// rhs_batching_dimensions = dense<> : tensor<0xi64>,
// rhs_contracting_dimensions = dense<0> : tensor<1xi64>}}
// : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
%dot = "mhlo.dot"(%arg0, %arg0) %dot = "mhlo.dot"(%arg0, %arg0)
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32> : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]]) // PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]])
@ -632,4 +660,4 @@ func @shape_assuming_memref(%arg0: tensor<?xf16>) -> tensor<?xf16> {
shape.assuming_yield %7 : tensor<?xf16> shape.assuming_yield %7 : tensor<?xf16>
} }
return %2 : tensor<?xf16> return %2 : tensor<?xf16>
} }

View File

@ -3,7 +3,8 @@
// RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP // RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
#map0 = affine_map<(d0, d1) -> (d0, d1)> #map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} #pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0],
iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) { %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
%temp_result = alloc() : memref<6x6xf32> %temp_result = alloc() : memref<6x6xf32>
@ -73,7 +74,9 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
} }
%1 = alloc() : memref<100x10xf32> %1 = alloc() : memref<100x10xf32>
linalg.generic { linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]} iterator_types = ["parallel", "parallel"]}
ins(%arg0, %0 : memref<100x10xf32>, memref<100x10xf32>) ins(%arg0, %0 : memref<100x10xf32>, memref<100x10xf32>)
outs(%1 : memref<100x10xf32>) { outs(%1 : memref<100x10xf32>) {
@ -83,7 +86,8 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
} }
dealloc %0 : memref<100x10xf32> dealloc %0 : memref<100x10xf32>
linalg.generic { linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]} iterator_types = ["parallel", "parallel"]}
ins(%1 : memref<100x10xf32>) ins(%1 : memref<100x10xf32>)
outs(%arg2 : memref<100x10xf32>) { outs(%arg2 : memref<100x10xf32>) {
@ -132,7 +136,9 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
// ----- // -----
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} #pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0],
iterator_types = ["parallel", "parallel", "parallel",
"parallel"]}
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>, func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) { %summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
%temp_result = alloc() : memref<6x6x6x6xf32> %temp_result = alloc() : memref<6x6x6x6xf32>
@ -190,7 +196,8 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
// ----- // -----
#map0 = affine_map<(d0, d1) -> (d0, d1)> #map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} #pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0],
iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%summand_2: memref<6x6xf32>) -> memref<6x6xf32> { %summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
%temp_result = alloc() : memref<6x6xf32> %temp_result = alloc() : memref<6x6xf32>
@ -244,3 +251,51 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
// PLOOP: addf // PLOOP: addf
// PLOOP: linalg.generic // PLOOP: linalg.generic
// PLOOP: mulf // PLOOP: mulf
// -----
func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
-> memref<*xf32> {
%c1 = constant 1 : index
%c0 = constant 0 : index
%1 = alloc(%arg2) : memref<?xf32>
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
ins(%arg0 : memref<?xf32>) outs(%1 : memref<?xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%13 = absf %arg3 : f32
linalg.yield %13 : f32
}
%2 = lmhlo.reshape_memref_cast %1(%arg1)
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
return %2 : memref<*xf32>
}
// CHECK-LABEL: func @view_result
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: absf
// CHECK: reshape_memref_cast
// TILED-LABEL: func @view_result
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: absf
// TILED: reshape_memref_cast
// PLOOP-LABEL: func @view_result
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: absf
// PLOOP: reshape_memref_cast

View File

@ -158,7 +158,14 @@ func @float_dot_op(%lhs: memref<7x3xf32>, %rhs:
// CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32 // CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32> // CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32>
// CHECK: return // CHECK: return
"lmhlo.dot"(%lhs, %rhs, %result) : "lmhlo.dot"(%lhs, %rhs, %result) {
dot_dimension_numbers = {
lhs_batching_dimensions = dense<> : tensor<0xi64>,
rhs_batching_dimensions = dense<> : tensor<0xi64>,
lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
rhs_contracting_dimensions = dense<0> : tensor<1xi64>
}
} :
(memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> () (memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> ()
return return
} }
@ -175,7 +182,14 @@ func @int_dot_op(%lhs: memref<7x3xi32>, %rhs:
// CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32 // CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32> // CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32>
// CHECK: return // CHECK: return
"lmhlo.dot"(%lhs, %rhs, %result) : "lmhlo.dot"(%lhs, %rhs, %result) {
dot_dimension_numbers = {
lhs_batching_dimensions = dense<> : tensor<0xi64>,
rhs_batching_dimensions = dense<> : tensor<0xi64>,
lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
rhs_contracting_dimensions = dense<0> : tensor<1xi64>
}
} :
(memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> () (memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> ()
return return
} }

View File

@ -621,10 +621,10 @@ func @sign_i16(%input: memref<2x2xi16>, %result: memref<2x2xi16>) {
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %[[RESULT_OUT:.*]]): // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[C0:.*]] = constant 0 : i16 // CHECK-NEXT: %[[C0:.*]] = constant 0 : i16
// CHECK-NEXT: %[[CMP:.*]] = cmpi "eq", %[[OPERAND_IN]], %[[C0]] : i16
// CHECK-NEXT: %[[C15:.*]] = constant 15 : i16 // CHECK-NEXT: %[[C15:.*]] = constant 15 : i16
// CHECK-NEXT: %[[ASHR:.*]] = shift_right_signed %[[OPERAND_IN]], %[[C15]] : i16
// CHECK-NEXT: %[[C1:.*]] = constant 1 : i16 // CHECK-NEXT: %[[C1:.*]] = constant 1 : i16
// CHECK-NEXT: %[[CMP:.*]] = cmpi "eq", %[[OPERAND_IN]], %[[C0]] : i16
// CHECK-NEXT: %[[ASHR:.*]] = shift_right_signed %[[OPERAND_IN]], %[[C15]] : i16
// CHECK-NEXT: %[[OR:.*]] = or %[[ASHR]], %[[C1]] : i16 // CHECK-NEXT: %[[OR:.*]] = or %[[ASHR]], %[[C1]] : i16
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[C0]], %[[OR]] : i16 // CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[C0]], %[[OR]] : i16
// CHECK-NEXT: linalg.yield %[[RESULT]] : i16 // CHECK-NEXT: linalg.yield %[[RESULT]] : i16

View File

@ -0,0 +1,99 @@
// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s
// CHECK-LABEL: func @batch_norm_grad_memrefs
func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
%arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>,
%grad_operand: memref<8x8x8x8xf32>, %grad_scale: memref<8xf32>,
%grad_offset: memref<8xf32>) -> () {
"lmhlo_gpu.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>,
memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> ()
return
}
// CHECK-LABEL: func @batch_norm_inference_memrefs
func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
%arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () {
"lmhlo_gpu.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> ()
return
}
// CHECK-LABEL: func @batch_norm_training_memrefs
func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
%output: memref<8x8x8x8xf32>, %batch_mean: memref<8xf32>,
%batch_var: memref<8xf32>) -> () {
"lmhlo_gpu.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> ()
return
}
// CHECK-LABEL: func @conv_forward
func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) {
%scratch = alloc() : memref<32xi8>
  // This defines a 2D convolution over an 8x8 single-channel input using a 2x2
  // filter, producing a 7x7xf16 output. The 1x1x8x8 shape is (N, C, H, W).
"lmhlo_gpu.conv_forward"(%input, %filter, %output, %scratch)
{ dimension_numbers = {input_batch_dimension = 0 : i64,
input_feature_dimension = 1 : i64,
input_spatial_dimensions = dense<[2,3]> : tensor<2xi64>,
kernel_input_feature_dimension = 0 : i64,
kernel_output_feature_dimension = 1 : i64,
kernel_spatial_dimensions = dense<[2,3]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 1 : i64,
output_spatial_dimensions = dense<[2,3]> : tensor<2xi64>},
window_strides = dense<[1, 1]> : tensor<2xi64>,
padding = dense<[0,0]> : tensor<2xi64>,
lhs_dilation = dense<[1,1]> : tensor<2xi64>,
rhs_dilation = dense<[1,1]> : tensor<2xi64>,
feature_group_count = 1,
batch_group_count = 1,
result_scale = 1.0,
backend_config = {algorithm=0, tensor_ops_enabled = true }
}
: (memref<1x1x8x8xf16>, memref<1x1x2x2xf16>, memref<1x1x7x7xf16>, memref<32xi8>) -> ()
return
}
// -----
// CHECK-LABEL: func @gemm
func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>) {
"lmhlo_gpu.gemm"(%lhs, %rhs, %output) { dot_dimension_numbers = {
lhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
alpha = 0.5,
batch_size = 1,
algorithm = 0}
: (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> ()
return
}
// CHECK-LABEL: func @gemm_bias
func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
%bias: memref<5x5xf32>, %output:memref<5x5xf32>) {
"lmhlo_gpu.gemm_bias"(%lhs, %rhs, %bias, %output) { dot_dimension_numbers = {
lhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
alpha = 0.5,
beta = 1.0,
batch_size = 1,
algorithm = 0}
: (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>, memref<5x5xf32>) -> ()
return
}
// CHECK-LABEL: func @cholesky
func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
%scratch = alloc() : memref<32xi8>
%info = alloc() : memref<32xi32>
"lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_upper = true }
: (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
return
}

View File

@ -328,6 +328,14 @@ func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor<
// ----- // -----
func @concat_0D(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<2xi32> {
// expected-error@+1 {{rank-0 values cannot be concatenated}}
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}
// -----
// CHECK-LABEL: @concat_1D // CHECK-LABEL: @concat_1D
func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> { func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> {
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32>

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
#include "mlir/InitAllDialects.h" #include "mlir/InitAllDialects.h"
@ -31,6 +32,7 @@ int main(int argc, char **argv) {
registry.insert<mlir::mhlo::MhloDialect>(); registry.insert<mlir::mhlo::MhloDialect>();
registry.insert<mlir::chlo::HloClientDialect>(); registry.insert<mlir::chlo::HloClientDialect>();
registry.insert<mlir::lmhlo::LmhloDialect>(); registry.insert<mlir::lmhlo::LmhloDialect>();
registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
return failed( return failed(
mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry)); mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry));

View File

@ -326,6 +326,21 @@ static Optional<TfLitePoolParams> GetTflitePoolParams(Operation* inst,
namespace { namespace {
// Helper struct that wraps inputs/outputs of a single SignatureDef.
struct SignatureDefData {
  // Note: we use maps here to make the ordering deterministic, purely to
  // simplify testing.
// Inputs defined in the signature def mapped to tensor names.
std::map<std::string, std::string> inputs;
// Outputs defined in the signature def mapped to tensor names.
std::map<std::string, std::string> outputs;
// Method name exported by the signature def.
std::string method_name;
// SignatureDef key.
std::string signature_def_key;
};
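// Example contents (hypothetical values): for a SavedModel signature whose
// input "x" is bound to graph tensor "serving_default_x:0", `inputs` would
// hold {"x" -> "serving_default_x:0"}; `method_name` comes from the
// tf_saved_model.exported_names attribute (often "serving_default"), and
// `signature_def_key` is the SavedModel tag (e.g. "serve").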
// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer. // Translates an MLIR module in TFLite dialect to TFLite FlatBuffer.
class Translator { class Translator {
public: public:
@ -334,16 +349,19 @@ class Translator {
// internal error. // internal error.
static Optional<std::string> Translate( static Optional<std::string> Translate(
ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper); bool emit_custom_ops, const std::unordered_set<std::string>& tags,
OpOrArgNameMapper* op_or_arg_name_mapper);
private: private:
enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp }; enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };
explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops, explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops,
bool emit_select_tf_ops, bool emit_custom_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& saved_model_tags,
OpOrArgNameMapper* op_or_arg_name_mapper) OpOrArgNameMapper* op_or_arg_name_mapper)
: module_(module), : module_(module),
name_mapper_(*op_or_arg_name_mapper), name_mapper_(*op_or_arg_name_mapper),
builder_(kInitialBufferSize) { builder_(kInitialBufferSize),
saved_model_tags_(saved_model_tags) {
// The first buffer must be empty according to the schema definition. // The first buffer must be empty according to the schema definition.
empty_buffer_ = tflite::CreateBuffer(builder_); empty_buffer_ = tflite::CreateBuffer(builder_);
buffers_.push_back(empty_buffer_); buffers_.push_back(empty_buffer_);
@ -450,6 +468,17 @@ class Translator {
Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>> Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
CreateMetadataVector(); CreateMetadataVector();
  // Builds and returns a list of tfl.SignatureDef sections in the model.
Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
CreateSignatureDefs(const std::vector<SignatureDefData>& signature_defs);
  // Returns a list of offsets to the passed 'items', encoded as TensorMap
  // structures inside the flatbuffer.
  // 'items' is a map from tensor name in the SignatureDef to tensor name in
  // the model.
std::vector<BufferOffset<tflite::TensorMap>> GetList(
const std::map<std::string, std::string>& items);
// Uses the tf.entry_function attribute (if set) to initialize the op to name // Uses the tf.entry_function attribute (if set) to initialize the op to name
// mapping. // mapping.
void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr); void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr);
@ -472,6 +501,8 @@ class Translator {
BufferOffset<tflite::Buffer> empty_buffer_; BufferOffset<tflite::Buffer> empty_buffer_;
std::vector<BufferOffset<tflite::Buffer>> buffers_; std::vector<BufferOffset<tflite::Buffer>> buffers_;
// Maps tensor name in the graph to the tensor index.
absl::flat_hash_map<std::string, int> tensor_index_map_;
// Maps op name to index of the corresponding OperatorCode in opcodes_ vector. // Maps op name to index of the corresponding OperatorCode in opcodes_ vector.
absl::flat_hash_map<std::string, uint32_t> opcode_index_map_; absl::flat_hash_map<std::string, uint32_t> opcode_index_map_;
@ -490,6 +521,9 @@ class Translator {
// The failed ops during legalization. // The failed ops during legalization.
std::set<std::string> failed_flex_ops_; std::set<std::string> failed_flex_ops_;
std::set<std::string> failed_custom_ops_; std::set<std::string> failed_custom_ops_;
// Set of saved model tags, if any.
const std::unordered_set<std::string> saved_model_tags_;
}; };
std::string Translator::UniqueName(mlir::Value val) { std::string Translator::UniqueName(mlir::Value val) {
@ -1131,6 +1165,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
} }
tensor_index_map.insert({value, tensors.size()}); tensor_index_map.insert({value, tensors.size()});
tensor_index_map_[name] = tensors.size();
auto tensor_or = BuildTensor(value, name, buffers_.size()); auto tensor_or = BuildTensor(value, name, buffers_.size());
if (!tensor_or) return false; if (!tensor_or) return false;
tensors.push_back(*tensor_or); tensors.push_back(*tensor_or);
@ -1286,6 +1321,149 @@ Translator::CreateMetadataVector() {
return builder_.CreateVector(metadata); return builder_.CreateVector(metadata);
} }
// Helper method that returns the list of all strings in the StringAttr
// identified by 'attr_key'; the individual values are separated by commas.
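// For example (hypothetical attribute contents), a DictionaryAttr with
// {inputs = "serving_default_x:0,serving_default_y:0"} queried with
// attr_key = "inputs" yields {"serving_default_x:0", "serving_default_y:0"}.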
llvm::SmallVector<llvm::StringRef, 2> GetStringsFromAttrWithSeparator(
mlir::DictionaryAttr attr, const std::string& attr_key) {
llvm::SmallVector<llvm::StringRef, 2> result;
if (auto str = attr.get(attr_key).dyn_cast_or_null<mlir::StringAttr>()) {
str.getValue().split(result, ',', /*MaxSplit=*/-1,
/*KeepEmpty=*/false);
}
return result;
}
// Helper method that returns a list of strings for all the StringAttrs in the
// Attribute identified by 'attr_name'.
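// For example (hypothetical attributes), an argument attribute list like
// [{tf_saved_model.index_path = ["x"]}] queried with
// attr_name = "tf_saved_model.index_path" yields {"x"}.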
std::vector<std::string> GetStringsFromDictionaryAttr(
const llvm::SmallVector<mlir::MutableDictionaryAttr, 4>& dict_attrs,
const std::string& attr_name) {
std::vector<std::string> result;
for (const auto& arg_attr : dict_attrs) {
auto attrs = arg_attr.getAttrs();
for (const auto attr : attrs) {
if (attr.first.str() == attr_name) {
auto array_attr = attr.second.dyn_cast_or_null<mlir::ArrayAttr>();
if (!array_attr || array_attr.empty()) continue;
auto string_attr = array_attr[0].dyn_cast_or_null<mlir::StringAttr>();
if (!string_attr) continue;
result.push_back(string_attr.getValue().str());
}
}
}
return result;
}
std::vector<SignatureDefData> BuildSignaturedef(
FuncOp main_op, const std::string& saved_model_tag) {
static const char kSignatureDefIndexPath[] = "tf_saved_model.index_path";
static const char kEntryFunctionAttributes[] = "tf.entry_function";
// Fetch inputs and outputs from the signature.
llvm::SmallVector<mlir::MutableDictionaryAttr, 4> arg_attrs, res_attrs;
main_op.getAllArgAttrs(arg_attrs);
main_op.getAllResultAttrs(res_attrs);
std::vector<std::string> sig_def_inputs =
GetStringsFromDictionaryAttr(arg_attrs, kSignatureDefIndexPath);
std::vector<std::string> sig_def_outputs =
GetStringsFromDictionaryAttr(res_attrs, kSignatureDefIndexPath);
  // If there is no SavedModel signature defined, return an empty list.
  // This can happen when the model being converted does not come from a
  // SavedModel.
if (sig_def_inputs.empty() || sig_def_outputs.empty()) return {};
// Fetch function inputs and outputs tensor names.
auto dict_attr =
main_op.getAttrOfType<mlir::DictionaryAttr>(kEntryFunctionAttributes);
if (!dict_attr) return {};
  // Get input and output tensor names from the attribute.
llvm::SmallVector<llvm::StringRef, 2> input_names =
GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"inputs");
llvm::SmallVector<llvm::StringRef, 2> output_names =
GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"outputs");
  // Verify that the number of input names matches the number of arguments.
if (input_names.size() != main_op.getNumArguments()) {
main_op.emitWarning() << "invalid entry function specification";
return {};
}
  // Verify that the number of output names matches the number of terminator
  // operands.
auto term = main_op.back().getTerminator();
if (output_names.size() != term->getNumOperands()) {
main_op.emitWarning() << "output names (" << output_names.size()
<< ") != terminator operands ("
<< term->getNumOperands() << ")";
return {};
}
  // Verify that the number of input and output tensors matches the size of
  // the corresponding lists in the signature def.
if (input_names.size() != sig_def_inputs.size() ||
output_names.size() != sig_def_outputs.size()) {
main_op.emitWarning(
"Mismatch between signature def inputs/outputs and main function "
"arguments.");
return {};
}
// Exported method name.
auto exported_name =
main_op.getAttrOfType<mlir::ArrayAttr>("tf_saved_model.exported_names");
if (exported_name.empty()) {
main_op.emitError("Empty exported names for main Function");
return {};
}
  // Fill the SignatureDefData container.
  // We create a vector of size 1, as TFLite currently supports only one
  // SignatureDef.
std::vector<SignatureDefData> result(1);
for (int i = 0; i < input_names.size(); ++i) {
result[0].inputs[sig_def_inputs[i]] = input_names[i].str();
}
for (int i = 0; i < output_names.size(); ++i) {
result[0].outputs[sig_def_outputs[i]] = output_names[i].str();
}
if (auto name_attr = exported_name[0].dyn_cast_or_null<StringAttr>())
result[0].method_name = name_attr.getValue().str();
result[0].signature_def_key = saved_model_tag;
return result;
}
std::vector<BufferOffset<tflite::TensorMap>> Translator::GetList(
const std::map<std::string, std::string>& items) {
std::vector<BufferOffset<tflite::TensorMap>> result;
for (const auto& item : items) {
auto name_buf = builder_.CreateString(item.first);
tflite::TensorMapBuilder tensor_map_builder(builder_);
tensor_map_builder.add_name(name_buf);
tensor_map_builder.add_tensor_index(tensor_index_map_[item.second]);
result.push_back(tensor_map_builder.Finish());
}
return result;
}
Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
Translator::CreateSignatureDefs(
const std::vector<SignatureDefData>& signature_defs) {
std::vector<BufferOffset<tflite::SignatureDef>> signature_defs_buffer;
for (const auto& signature_def_data : signature_defs) {
auto inputs = GetList(signature_def_data.inputs);
auto outputs = GetList(signature_def_data.outputs);
auto inputs_buf = builder_.CreateVector(inputs);
auto outputs_buf = builder_.CreateVector(outputs);
auto method_name_buf =
builder_.CreateString(signature_def_data.method_name);
auto signature_def_key_buf =
builder_.CreateString(signature_def_data.signature_def_key);
tflite::SignatureDefBuilder sig_def_builder(builder_);
sig_def_builder.add_inputs(inputs_buf);
sig_def_builder.add_outputs(outputs_buf);
sig_def_builder.add_method_name(method_name_buf);
sig_def_builder.add_key(signature_def_key_buf);
signature_defs_buffer.push_back(sig_def_builder.Finish());
}
return builder_.CreateVector(signature_defs_buffer);
}
bool UpdateEntryFunction(ModuleOp module) { bool UpdateEntryFunction(ModuleOp module) {
if (module.lookupSymbol<FuncOp>("main") != nullptr) { if (module.lookupSymbol<FuncOp>("main") != nullptr) {
// We already have an entry function. // We already have an entry function.
@ -1312,11 +1490,12 @@ bool UpdateEntryFunction(ModuleOp module) {
Optional<std::string> Translator::Translate( Optional<std::string> Translator::Translate(
ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper) { bool emit_custom_ops, const std::unordered_set<std::string>& tags,
OpOrArgNameMapper* op_or_arg_name_mapper) {
if (!UpdateEntryFunction(module)) return llvm::None; if (!UpdateEntryFunction(module)) return llvm::None;
if (!IsValidTFLiteMlirModule(module)) return llvm::None; if (!IsValidTFLiteMlirModule(module)) return llvm::None;
Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops, Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
emit_custom_ops, op_or_arg_name_mapper); emit_custom_ops, tags, op_or_arg_name_mapper);
return translator.TranslateInternal(); return translator.TranslateInternal();
} }
@ -1392,10 +1571,17 @@ Optional<std::string> Translator::TranslateInternal() {
auto metadata = CreateMetadataVector(); auto metadata = CreateMetadataVector();
if (!metadata) return llvm::None; if (!metadata) return llvm::None;
auto model = tflite::CreateModel( // Build SignatureDef
builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(opcodes_), // We only have 1 entry point 'main' function, so build only 1 signature def.
builder_.CreateVector(subgraphs), description, auto main_fn_signature_def = BuildSignaturedef(
builder_.CreateVector(buffers_), metadata_buffer, *metadata); main_fn, saved_model_tags_.empty() ? "" : *saved_model_tags_.begin());
auto signature_defs = CreateSignatureDefs(main_fn_signature_def);
auto model = tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
builder_.CreateVector(opcodes_),
builder_.CreateVector(subgraphs),
description, builder_.CreateVector(buffers_),
metadata_buffer, *metadata, *signature_defs);
tflite::FinishModelBuffer(builder_, model); tflite::FinishModelBuffer(builder_, model);
tflite::UpdateOpVersion(builder_.GetBufferPointer()); tflite::UpdateOpVersion(builder_.GetBufferPointer());
tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer()); tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer());
@ -1519,12 +1705,10 @@ bool tflite::MlirToFlatBufferTranslateFunction(
ModuleOp module, std::string* serialized_flatbuffer, ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
OpOrArgNameMapper* op_or_arg_name_mapper) { OpOrArgNameMapper* op_or_arg_name_mapper) {
auto maybe_translated = return MlirToFlatBufferTranslateFunction(
Translator::Translate(module, emit_builtin_tflite_ops, emit_select_tf_ops, module, serialized_flatbuffer, emit_builtin_tflite_ops,
emit_custom_ops, op_or_arg_name_mapper); emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{},
if (!maybe_translated) return true; op_or_arg_name_mapper);
*serialized_flatbuffer = std::move(*maybe_translated);
return false;
} }
bool tflite::MlirToFlatBufferTranslateFunction( bool tflite::MlirToFlatBufferTranslateFunction(
@ -1534,5 +1718,30 @@ bool tflite::MlirToFlatBufferTranslateFunction(
OpOrArgLocNameMapper op_or_arg_name_mapper; OpOrArgLocNameMapper op_or_arg_name_mapper;
return MlirToFlatBufferTranslateFunction( return MlirToFlatBufferTranslateFunction(
module, serialized_flatbuffer, emit_builtin_tflite_ops, module, serialized_flatbuffer, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, &op_or_arg_name_mapper); emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{},
&op_or_arg_name_mapper);
}
bool tflite::MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& saved_model_tags) {
OpOrArgLocNameMapper op_or_arg_name_mapper;
return MlirToFlatBufferTranslateFunction(
module, serialized_flatbuffer, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, saved_model_tags,
&op_or_arg_name_mapper);
}
bool tflite::MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& saved_model_tags,
OpOrArgNameMapper* op_or_arg_name_mapper) {
auto maybe_translated = Translator::Translate(
module, emit_builtin_tflite_ops, emit_select_tf_ops, emit_custom_ops,
saved_model_tags, op_or_arg_name_mapper);
if (!maybe_translated) return true;
*serialized_flatbuffer = std::move(*maybe_translated);
return false;
} }

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_
#include <string> #include <string>
#include <unordered_set>
#include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
@ -33,11 +34,24 @@ bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
bool emit_select_tf_ops, bool emit_select_tf_ops,
bool emit_custom_ops); bool emit_custom_ops);
// Same as above but takes SavedModel tags of the model.
bool MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& saved_model_tags);
// Same as the above but with a custom op name mapper. // Same as the above but with a custom op name mapper.
bool MlirToFlatBufferTranslateFunction( bool MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer, mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper); tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
// Same as above but takes SavedModel tags of the model.
bool MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& saved_model_tags,
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
} // namespace tflite } // namespace tflite
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ #endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_

View File

@ -876,6 +876,30 @@ def TFL_CosOp: TFL_Op<"cos", [
let hasFolder = 1; let hasFolder = 1;
} }
def TFL_CumsumOp: TFL_Op<"cumsum", [
NoSideEffect,
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
NoQuantizableResult,
TFL_OperandHasRank<1, 0>]> {
let summary = "Cumsum operator";
let description = [{
Compute the cumulative sum of the tensor x along axis.
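
    For example, cumsum([a, b, c]) = [a, a + b, a + b + c]. With
    exclusive = true the result is [0, a, a + b], and with reverse = true the
    sum is computed from the end of the axis instead.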
}];
let arguments = (
ins TFL_TensorOf<[F32, I32, I64]>:$input,
TFL_I32Tensor:$axis,
DefaultValuedAttr<BoolAttr, "false">:$exclusive,
DefaultValuedAttr<BoolAttr, "false">:$reverse
);
let results = (outs TFL_TensorOf<[F32, I32, I64]>:$output);
let hasOptions = 1;
}
def TFL_DepthwiseConv2DOp : def TFL_DepthwiseConv2DOp :
TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> { TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
let arguments = ( let arguments = (

View File

@ -90,9 +90,10 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
pass_config.lower_tensor_list_ops = true; pass_config.lower_tensor_list_ops = true;
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module), return internal::ConvertMLIRToTFLiteFlatBuffer(
pass_config, result, toco_flags, std::move(module), pass_config, /*saved_model_tags=*/{},
/*session=*/llvm::None); result,
/*session=*/llvm::None);
} }
} // namespace tensorflow } // namespace tensorflow

View File

@ -177,7 +177,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
// TODO(b/153507667): Pass the session object when importing logic is removed. // TODO(b/153507667): Pass the session object when importing logic is removed.
auto status = internal::ConvertMLIRToTFLiteFlatBuffer( auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
toco_flags, std::move(module), pass_config, result, toco_flags, std::move(module), pass_config, tags, result,
/*session=*/llvm::None); /*session=*/llvm::None);
return status; return status;
} }

View File

@ -273,7 +273,8 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
Status ConvertMLIRToTFLiteFlatBuffer( Status ConvertMLIRToTFLiteFlatBuffer(
const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module, const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module,
const mlir::TFL::PassConfig& pass_config, string* result, const mlir::TFL::PassConfig& pass_config,
const std::unordered_set<std::string>& saved_model_tags, string* result,
llvm::Optional<tensorflow::Session*> session) { llvm::Optional<tensorflow::Session*> session) {
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops(); bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops(); bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
@ -297,8 +298,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(
auto status = ConvertTFExecutorToTFLOrFlatbuffer( auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops, module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs, result, emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs,
&pm); saved_model_tags, result, &pm);
if (toco_flags.has_dump_graphviz_dir()) { if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile( TF_RETURN_IF_ERROR(DumpOpGraphToFile(
// rename once we enable the new converter feature flag. // rename once we enable the new converter feature flag.

View File

@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_TF_TFL_FLATBUFFER_HELPERS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_TF_TFL_FLATBUFFER_HELPERS_H_
#include <ostream> #include <ostream>
#include <unordered_set>
#include <utility> #include <utility>
#include "llvm/ADT/Optional.h" #include "llvm/ADT/Optional.h"
@ -48,7 +49,8 @@ Status PopulateQuantizationSpecs(
// This will also run relevant passes as well. // This will also run relevant passes as well.
Status ConvertMLIRToTFLiteFlatBuffer( Status ConvertMLIRToTFLiteFlatBuffer(
const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module, const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module,
-const mlir::TFL::PassConfig& pass_config, string* result,
+const mlir::TFL::PassConfig& pass_config,
+const std::unordered_set<std::string>& saved_model_tags, string* result,
llvm::Optional<tensorflow::Session*> session); llvm::Optional<tensorflow::Session*> session);
// Give a warning for any unused flags that have been specified. // Give a warning for any unused flags that have been specified.
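
Taken together, these hunks thread a set of SavedModel tags through ConvertMLIRToTFLiteFlatBuffer down to ConvertTFExecutorToTFLOrFlatbuffer: the GraphDef path passes an empty set, while the SavedModel path forwards the tags it loaded with. A hedged sketch of a call site under the new signature (the wrapper function, its name, and the "serve" tag are illustrative assumptions, not part of the change):

#include <string>
#include <unordered_set>
#include <utility>

#include "llvm/ADT/Optional.h"
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"

// Hypothetical wrapper showing the new argument order: saved_model_tags now
// sits between pass_config and result. Everything else is forwarded as-is.
tensorflow::Status ConvertWithServeTag(const toco::TocoFlags& toco_flags,
                                       mlir::OwningModuleRef module,
                                       const mlir::TFL::PassConfig& pass_config,
                                       std::string* result) {
  const std::unordered_set<std::string> saved_model_tags = {"serve"};  // assumed tag
  return tensorflow::internal::ConvertMLIRToTFLiteFlatBuffer(
      toco_flags, std::move(module), pass_config, saved_model_tags, result,
      /*session=*/llvm::None);
}
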

View File

@ -96,5 +96,6 @@ versions {
# CHECK-NEXT: metadata: [ { # CHECK-NEXT: metadata: [ {
# CHECK-NEXT: name: "min_runtime_version", # CHECK-NEXT: name: "min_runtime_version",
# CHECK-NEXT: buffer: 4 # CHECK-NEXT: buffer: 4
# CHECK-NEXT: } ] # CHECK-NEXT: } ],
# CHECK-NEXT: signature_defs: [ ]
# CHECK-NEXT: } # CHECK-NEXT: }
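
The serialized TFLite model now carries a signature_defs field, so the textual dumps checked by this and the following tests gain an (empty) `signature_defs: [ ]` entry. A small sketch of reading that field through the generated flatbuffer API, assuming the generated accessor follows the schema field name:

#include <cstdio>

#include "tensorflow/lite/schema/schema_generated.h"

// Print how many SignatureDefs a serialized .tflite buffer carries; for the
// models in these tests the field is present but empty.
void PrintSignatureDefCount(const void* flatbuffer_data) {
  const tflite::Model* model = tflite::GetModel(flatbuffer_data);
  const auto* signature_defs = model->signature_defs();  // assumed accessor name
  std::printf("signature_defs: %u\n",
              signature_defs ? signature_defs->size() : 0u);
}
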

View File

@ -54,6 +54,7 @@ tf_native_cc_binary(
deps = [ deps = [
"//tensorflow/lite:framework", "//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/schema:schema_utils",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
], ],
@ -70,6 +71,7 @@ tf_native_cc_binary(
deps = [ deps = [
"//tensorflow/lite:framework", "//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/schema:schema_utils",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
], ],

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "tensorflow/lite/model.h" #include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
using llvm::Optional; using llvm::Optional;
using llvm::cl::opt; using llvm::cl::opt;
@ -95,7 +96,8 @@ Optional<std::unique_ptr<tflite::ModelT>> RemoveConstantOpInReshape(
// Find the reshape ops and make it single operand. // Find the reshape ops and make it single operand.
for (auto& sub_graph : model->subgraphs) { for (auto& sub_graph : model->subgraphs) {
for (auto& op : sub_graph->operators) { for (auto& op : sub_graph->operators) {
-if (model->operator_codes[op->opcode_index]->builtin_code ==
+if (tflite::GetBuiltinCode(
+        model->operator_codes[op->opcode_index].get()) ==
tflite::BuiltinOperator_RESHAPE) { tflite::BuiltinOperator_RESHAPE) {
auto& output_tensor = sub_graph->tensors[op->outputs[0]]; auto& output_tensor = sub_graph->tensors[op->outputs[0]];
auto shape = output_tensor->shape; auto shape = output_tensor->shape;
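
These tools now resolve builtin operator codes through tflite::GetBuiltinCode() from schema_utils (hence the new BUILD dependencies above) instead of reading builtin_code directly, presumably so that both the legacy and the extended builtin-code representations are handled. A sketch in the same style as the updated loops, using only the types visible in this hunk:

#include <memory>

#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"

// Count occurrences of one builtin op in an unpacked model, looking the code
// up through tflite::GetBuiltinCode() as the updated tools above do.
int CountBuiltinOps(const tflite::ModelT& model,
                    tflite::BuiltinOperator code) {
  int count = 0;
  for (const auto& sub_graph : model.subgraphs) {
    for (const auto& op : sub_graph->operators) {
      if (tflite::GetBuiltinCode(
              model.operator_codes[op->opcode_index].get()) == code) {
        ++count;
      }
    }
  }
  return count;
}
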

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "tensorflow/lite/model.h" #include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
using llvm::Optional; using llvm::Optional;
using llvm::cl::opt; using llvm::cl::opt;
@ -114,7 +115,8 @@ Optional<std::unique_ptr<tflite::ModelT>> InjectStatsToFullyConnected(
// Find the tensors and inject the min and max to the input and output // Find the tensors and inject the min and max to the input and output
for (auto& sub_graph : model->subgraphs) { for (auto& sub_graph : model->subgraphs) {
for (auto& op : sub_graph->operators) { for (auto& op : sub_graph->operators) {
-if (model->operator_codes[op->opcode_index]->builtin_code ==
+if (tflite::GetBuiltinCode(
+        model->operator_codes[op->opcode_index].get()) ==
tflite::BuiltinOperator_FULLY_CONNECTED) { tflite::BuiltinOperator_FULLY_CONNECTED) {
// inject min/max to the input and output tensors // inject min/max to the input and output tensors
auto& input_tensor = sub_graph->tensors[op->inputs[0]]; auto& input_tensor = sub_graph->tensors[op->inputs[0]];

View File

@ -3442,8 +3442,8 @@ func @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "va
%0 = "tf.Const"() {value = dense<[[1902835825], [-1475704015], [473120514], [1254202069], [1558833093], [1756181982], [1906603252], [-1034142694], [542842690], [535515822]]> : tensor<10x1xi64>} : () -> tensor<10x1xi64> %0 = "tf.Const"() {value = dense<[[1902835825], [-1475704015], [473120514], [1254202069], [1558833093], [1756181982], [1906603252], [-1034142694], [542842690], [535515822]]> : tensor<10x1xi64>} : () -> tensor<10x1xi64>
%1 = "tf.StringToHashBucketFast"(%arg0) {device = "", num_buckets = 2147483647 : i64} : (tensor<?x!tf.string>) -> tensor<?xi64> %1 = "tf.StringToHashBucketFast"(%arg0) {device = "", num_buckets = 2147483647 : i64} : (tensor<?x!tf.string>) -> tensor<?xi64>
%2 = "tf.Sgnn"(%1, %0) {device = ""} : (tensor<?xi64>, tensor<10x1xi64>) -> tensor<10x?xf64> %2 = "tf.Sgnn"(%1, %0) {device = ""} : (tensor<?xi64>, tensor<10x1xi64>) -> tensor<10x?xf64>
%3 = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> %3 = "tf.Const"() {value = dense<[-1, 10]> : tensor<2xi64>} : () -> tensor<2xi64>
%4 = "tf.Reshape"(%2, %3) : (tensor<10x?xf64>, tensor<1xi64>) -> tensor<?x10xf64> %4 = "tf.Reshape"(%2, %3) : (tensor<10x?xf64>, tensor<2xi64>) -> tensor<?x10xf64>
return %4 : tensor<?x10xf64> return %4 : tensor<?x10xf64>
} }
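
The corrected test reshapes the 10x? projection output to ?x10 with a [-1, 10] shape constant; the -1 entry is inferred from the remaining element count. A stand-alone sketch of that inference rule (generic reshape semantics, not TensorFlow code):

#include <cstdint>
#include <vector>

// Resolve a single -1 ("infer me") entry in a requested reshape target, given
// the total number of elements of the input tensor. The caller is expected to
// pass at most one -1; all other entries are taken as-is.
std::vector<int64_t> ResolveReshape(int64_t num_elements,
                                    std::vector<int64_t> target) {
  int64_t known = 1;
  int inferred = -1;
  for (int i = 0; i < static_cast<int>(target.size()); ++i) {
    if (target[i] == -1) {
      inferred = i;
    } else {
      known *= target[i];
    }
  }
  if (inferred >= 0) target[inferred] = num_elements / known;
  return target;
}
// A batch of 7 hashed strings: 10 * 7 elements with target {-1, 10} -> {7, 10}.
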

View File

@ -1361,7 +1361,8 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, %
// CHECK-LABEL: conv2d_backprop_input // CHECK-LABEL: conv2d_backprop_input
// CHECK: %[[CST:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32> // CHECK: %[[CST:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32>
-// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
+// CHECK: %[[CAST:.*]] = "tfl.cast"(%[[CST]]) : (tensor<4xi32>) -> tensor<4xi32>
+// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CAST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
// CHECK: %[[CST_0:.*]] = constant unit // CHECK: %[[CST_0:.*]] = constant unit
// CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32> // CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
// CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32> // CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
@ -1587,10 +1588,31 @@ func @tranpose_int64_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
// CHECK: "tfl.transpose" // CHECK: "tfl.transpose"
} }
func @tranpose_arg(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> tensor<3x2xf32> { func @tranpose_arg32(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> tensor<3x2xf32> {
%0 = "tf.Transpose"(%arg0, %arg1): (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> %0 = "tf.Transpose"(%arg0, %arg1): (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
return %0 : tensor<3x2xf32> return %0 : tensor<3x2xf32>
// CHECK-LABEL: tranpose_arg // CHECK-LABEL: tranpose_arg32
// CHECK: "tfl.transpose" // CHECK: "tfl.transpose"
} }
func @tranpose_arg64(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi64>) -> tensor<3x2xf32> {
%0 = "tf.Transpose"(%arg0, %arg1): (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32>
return %0 : tensor<3x2xf32>
// CHECK-LABEL: tranpose_arg64
// CHECK: "tfl.transpose"
}
func @cumsum(%arg0: tensor<3x3xf32>, %arg1: tensor<i32>) -> tensor<3x3xf32> {
%0 = "tf.Cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i32>) -> tensor<3x3xf32>
return %0 : tensor<3x3xf32>
// CHECK-LABEL: cumsum
// CHECK: "tfl.cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i32>) -> tensor<3x3xf32>
}
func @cumsum_invalid(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<3x3xf32> {
%0 = "tf.Cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i64>) -> tensor<3x3xf32>
return %0 : tensor<3x3xf32>
// CHECK-LABEL: cumsum_invalid
// CHECK-NOT: "tfl.cumsum"
}

View File

@ -116,6 +116,7 @@ func @main(tensor<1x384xf32>, tensor<1x96xf32>, tensor<384x480xf32>, tensor<384x
// CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 10 // CHECK-NEXT: buffer: 10
// CHECK-NEXT: } ] // CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:} // CHECK-NEXT:}
^bb0(%arg0: tensor<1x384xf32>, %arg1: tensor<1x96xf32>, %arg2: tensor<384x480xf32>, %arg3: tensor<384xf32>, %arg4: tensor<1x96xf32>): ^bb0(%arg0: tensor<1x384xf32>, %arg1: tensor<1x96xf32>, %arg2: tensor<384x480xf32>, %arg3: tensor<384xf32>, %arg4: tensor<1x96xf32>):

View File

@ -100,6 +100,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6 // CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ] // CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:} // CHECK-NEXT:}
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")

View File

@ -91,6 +91,7 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
// CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6 // CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ] // CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:} // CHECK-NEXT:}
%0 = "tfl.pseudo_const" () {value = dense<-1.23697901> : tensor<32xf32>} : () -> tensor<32xf32> loc("Const") %0 = "tfl.pseudo_const" () {value = dense<-1.23697901> : tensor<32xf32>} : () -> tensor<32xf32> loc("Const")

View File

@ -93,6 +93,7 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
// CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6 // CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ] // CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:} // CHECK-NEXT:}
%0 = "tfl.pseudo_const" () {value = dense<-1.23697901> : tensor<32xf32>} : () -> tensor<32xf32> loc("Const") %0 = "tfl.pseudo_const" () {value = dense<-1.23697901> : tensor<32xf32>} : () -> tensor<32xf32> loc("Const")

View File

@ -97,6 +97,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6 // CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ] // CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:} // CHECK-NEXT:}
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")

View File

@ -54,6 +54,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 3 // CHECK-NEXT: buffer: 3
// CHECK-NEXT: } ] // CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: } // CHECK-NEXT: }
// IMPORT: "tfl.fake_quant"(%arg0) {max = 1.400000e+00 : f32, min = 3.000000e-01 : f32, narrow_range = false, num_bits = 6 : i32} // IMPORT: "tfl.fake_quant"(%arg0) {max = 1.400000e+00 : f32, min = 3.000000e-01 : f32, narrow_range = false, num_bits = 6 : i32}

View File

@ -47,6 +47,7 @@ func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
// CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 3 // CHECK-NEXT: buffer: 3
// CHECK-NEXT: } ] // CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: } // CHECK-NEXT: }
%0 = "tf.AddV2"(%arg0, %arg0) : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> %0 = "tf.AddV2"(%arg0, %arg0) : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>

View File

@ -60,6 +60,7 @@ func @main(tensor<4xcomplex<f64>>, tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f
// CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 4 // CHECK-NEXT: buffer: 4
// CHECK-NEXT: } ] // CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:} // CHECK-NEXT:}
%0 = "tf.Add"(%arg0, %arg1) : (tensor<4xcomplex<f64>>, tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f64>> loc("add") %0 = "tf.Add"(%arg0, %arg1) : (tensor<4xcomplex<f64>>, tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f64>> loc("add")

View File

@ -60,6 +60,7 @@ func @main(tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> {
// CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 4 // CHECK-NEXT: buffer: 4
// CHECK-NEXT: } ] // CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:} // CHECK-NEXT:}
%0 = "tf.Add"(%arg0, %arg1) : (tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> loc("add") %0 = "tf.Add"(%arg0, %arg1) : (tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> loc("add")

View File

@ -99,6 +99,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6 // CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ] // CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:} // CHECK-NEXT:}
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")

View File

@ -69,6 +69,7 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {
// CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 5 // CHECK-NEXT: buffer: 5
// CHECK-NEXT: } ] // CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:} // CHECK-NEXT:}
%cst = constant unit %cst = constant unit

View File

@ -69,6 +69,7 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {
// CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 5 // CHECK-NEXT: buffer: 5
// CHECK-NEXT: } ] // CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:} // CHECK-NEXT:}
%cst = constant unit %cst = constant unit

View File

@ -166,6 +166,7 @@
// CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 11 // CHECK-NEXT: buffer: 11
// CHECK-NEXT: } ] // CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: } // CHECK-NEXT: }
func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {

View File

@ -87,6 +87,7 @@ func @main(tensor<4xi1>) -> tensor<4xi1> {
// CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6 // CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ] // CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK-EMPTY: // CHECK-EMPTY:

View File

@ -258,6 +258,7 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
// CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 26 // CHECK-NEXT: buffer: 26
// CHECK-NEXT: } ] // CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK-EMPTY: // CHECK-EMPTY:

View File

@ -320,5 +320,6 @@ func @main(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>
// CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 23 // CHECK-NEXT: buffer: 23
// CHECK-NEXT: } ] // CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: } // CHECK-NEXT: }
} }

View File

@ -140,6 +140,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 8 // CHECK-NEXT: buffer: 8
// CHECK-NEXT: } ] // CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: } // CHECK-NEXT: }
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")

Some files were not shown because too many files have changed in this diff.