Merge branch 'master' into FloorMod-pr2

This commit is contained in:
ddavis-2015 2020-12-18 17:21:02 -08:00
commit d04c4c0504
1603 changed files with 50520 additions and 23888 deletions


@@ -145,10 +145,6 @@ build:monolithic --define framework_shared_object=false
# opts in to modular op registration support by default.
build --define framework_shared_object=true
# Flags for open source build, always set to be true.
build --define open_source_build=true
test --define open_source_build=true
# For workaround https://github.com/bazelbuild/bazel/issues/8772 with Bazel >= 0.29.1
build --java_toolchain=//third_party/toolchains/java:tf_java_toolchain
build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain


@@ -3,6 +3,8 @@
/tensorflow/c/eager @qqfish @kkimdev
/tensorflow/core/common_runtime/eager @qqfish @kkimdev
/tensorflow/core/debug @caisq
/tensorflow/core/kernels/mkl/ @penpornk
/tensorflow/core/kernels/sparse/ @penpornk
/tensorflow/core/nccl/ @azaks2 @chsigg
/tensorflow/core/platform/windows/ @mihaimaruseac
/tensorflow/lite/experimental/micro @petewarden @advaitjain


@@ -132,8 +132,8 @@ Build Type
**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/)
**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/)
**Linux aarch64 CPU** Nightly (Linaro)<br> Python 3.8 | [![Build Status](https://ci.linaro.org/jenkins/buildStatus/icon?job=ldcg-hpc-tensorflow)](https://ci.linaro.org/jenkins/job/ldcg-hpc-tensorflow/) | [Nightly](http://snapshots.linaro.org/hpc/python/tensorflow/latest/)
**Linux aarch64 CPU** Stable Release (Linaro) | [![Build Status](https://ci.linaro.org/jenkins/buildStatus/icon?job=ldcg-hpc-tensorflow)](https://ci.linaro.org/jenkins/job/ldcg-hpc-tensorflow/) | Release [1.x & 2.x](http://snapshots.linaro.org/hpc/python/tensorflow/latest/)
**Linux aarch64 CPU** Nightly (Linaro) | [![Build Status](https://ci.linaro.org/jenkins/buildStatus/icon?job=ldcg-python-tensorflow-nightly)](https://ci.linaro.org/jenkins/job/ldcg-python-tensorflow-nightly/) | [Nightly](http://snapshots.linaro.org/ldcg/python/tensorflow-nightly/latest/)
**Linux aarch64 CPU** Stable Release (Linaro) | [![Build Status](https://ci.linaro.org/jenkins/buildStatus/icon?job=ldcg-python-tensorflow)](https://ci.linaro.org/jenkins/job/ldcg-python-tensorflow/) | Release [1.x & 2.x](http://snapshots.linaro.org/ldcg/python/tensorflow/latest/)
**Linux aarch64 CPU** Nightly (OpenLab)<br> Python 3.6 | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow)](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) | [Nightly](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master)
**Linux aarch64 CPU** Stable Release (OpenLab) | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) | Release [1.15](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) / [2.x](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)
**Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)


@@ -32,6 +32,8 @@
* `tf.keras`:
* Improvements to Keras preprocessing layers:
* Discretization combiner implemented, with additional arg `epsilon`.
* Improvements to model saving/loading:
* `model.load_weights` now accepts paths to saved models.
* `tf.data`:
* Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used
@@ -58,9 +60,11 @@
directly.
* 16 bits quantization
* Added int16x8 support for ABS, REDUCE_MAX and REDUCE_MIN operators.
* Added support for saved model's session initializer through
`TFLiteConverter.from_saved_model`.
* Added DEPTH_TO_SPACE support in Post training quantization.
* Added dynamic range quantization support for the BatchMatMul op.
* Both symmetric and asymmetric quantized input tensor are supported.
* Add `RFFT2D` as builtin op. (`RFFT2D` also supports `RFFTD`.) Currently
only supports float32 input.
* TFLite supports SignatureDef:
@@ -95,6 +99,11 @@
value of `is_dynamic_op` is not True. We didn't use the value for
`max_batch_size` for building TensorRT engines.
* Issue a warning when function get_tensorrt_rewriter_config is used.
* Other:
* Add new enum value `MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED` to
`tf.config.experimental.mlir_bridge_rollout` to enable a "safe" mode.
This runs the MLIR bridge only when an analysis of the graph
determines that it is safe to run.
## Thanks to our Contributors
@@ -104,459 +113,448 @@ This release contains contributions from many people at Google, as well as:
# Release 2.4.0
<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
## Major Features and Improvements
* `tf.distribute` introduces experimental support for asynchronous training of
models via the [`tf.distribute.experimental.ParameterServerStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/ParameterServerStrategy)
API. Please see the [tutorial](https://www.tensorflow.org/tutorials/distribute/parameter_server_training)
to learn more.
* [`MultiWorkerMirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/MultiWorkerMirroredStrategy)
is now a stable API and is no longer considered experimental. Some of the
major improvements involve handling peer failure and many bug fixes. Please
check out the detailed tutorial on [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras).
* Introduces experimental support for a new module named [`tf.experimental.numpy`](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) which is a
NumPy-compatible API for writing TF programs. See the [detailed guide](https://www.tensorflow.org/guide/tf_numpy) to learn more. Additional details below.
* Adds support for
[TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/)
on Ampere based GPUs. TensorFloat-32, or TF32 for short, is a math mode for
NVIDIA Ampere based GPUs and is enabled by default.
* A major refactoring of the internals of the Keras Functional API has been
completed, that should improve the reliability, stability, and performance of
constructing Functional models.
* Keras mixed precision API [`tf.keras.mixed_precision`](https://www.tensorflow.org/api_docs/python/tf/keras/mixed_precision?version=nightly)
is no longer experimental and allows the use of 16-bit floating point formats
during training, improving performance by up to 3x on GPUs and 60% on TPUs.
Please see below for additional details.
* TensorFlow Profiler now supports profiling `MultiWorkerMirroredStrategy` and
tracing multiple workers using the [sampling mode API](https://www.tensorflow.org/guide/profiler#profiling_apis).
* TFLite Profiler for Android is available. See the detailed [guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android)
to learn more.
* TensorFlow pip packages are now built with CUDA 11 and cuDNN 8.0.2.
## Breaking Changes
* <DOCUMENT BREAKING CHANGES HERE>
* <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>
* Certain float32 ops run in lower precision on Ampere based GPUs, including
matmuls and convolutions, due to the use of
[TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/).
* TF Core:
* Certain float32 ops run in lower precision on Ampere based GPUs, including
matmuls and convolutions, due to the use of [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/).
Specifically, inputs to such ops are rounded from 23 bits of precision to 10
bits of precision. This is unlikely to cause issues in practice for deep
learning models. In some cases, TensorFloat-32 is also used for complex64 ops.
TensorFloat-32 can be disabled by running
`tf.config.experimental.enable_tensor_float_32_execution(False)`. The "Major
Features and Improvements" section has more details.
* The byte layout for string tensors across the C-API has been updated to match
TF Core/C++; i.e., a contiguous array of `tensorflow::tstring`/`TF_TString`s.
* C-API functions `TF_StringDecode`, `TF_StringEncode`, and
`TF_StringEncodedSize` are no longer relevant and have been removed; see
core/platform/ctstring.h for string access/modification in C.
* Removed `tf.distribute.Strategy.experimental_run_v2` method, which was deprecated in TF 2.2.
* `tensorflow.python`, `tensorflow.core` and `tensorflow.compiler` modules are
now hidden. These modules are not part of TensorFlow public API.
* A major refactoring of the internals of the Keras Functional API may affect code that is relying on certain internal details:
* Code that uses `isinstance(x, tf.Tensor)` instead of `tf.is_tensor` when checking Keras symbolic inputs/outputs should switch to using `tf.is_tensor`.
* Code that is overly dependent on the exact names attached to symbolic tensors (e.g. assumes there will be ":0" at the end of the inputs, treats names as unique identifiers instead of using `tensor.ref()`, etc.)
* Code that uses `get_concrete_function` to trace Keras symbolic inputs directly should switch to building matching `tf.TensorSpec`s directly and tracing the `TensorSpec` objects.
* Code that relies on the exact number and names of the op layers that TensorFlow operations were converted into. These may have changed.
* Code that uses `tf.map_fn`/`tf.cond`/`tf.while_loop`/control flow as op layers and happens to work before TF 2.4. These will explicitly be unsupported now. Converting these ops to Functional API op layers was unreliable before TF 2.4, and prone to erroring incomprehensibly or being silently buggy.
* Code that directly asserts on a Keras symbolic value in cases where ops like `tf.rank` used to return a static or symbolic value depending on if the input had a fully static shape or not. Now these ops always return symbolic values.
* Code already susceptible to leaking tensors outside of graphs becomes slightly more likely to do so now.
* Code that tries directly getting gradients with respect to symbolic Keras inputs/outputs. Use GradientTape on the actual Tensors passed to the already-constructed model instead.
* Code that requires very tricky shape manipulation via converted op layers in order to work, where the Keras symbolic shape inference proves insufficient.
* Code that tries manually walking a `tf.keras.Model` layer by layer and assumes layers only ever have one positional argument. This assumption doesn't hold true before TF 2.4 either, but is more likely to cause issues now.
* Code that manually enters `keras.backend.get_graph()` before building a functional model. This is no longer needed.
* C-API functions `TF_StringDecode`, `TF_StringEncode`, and `TF_StringEncodedSize`
are no longer relevant and have been removed; see `core/platform/ctstring.h` for
string access/modification in C.
* `tensorflow.python`, `tensorflow.core` and `tensorflow.compiler` modules are
now hidden. These modules are not part of TensorFlow public API.
* `tf.raw_ops.Max` and `tf.raw_ops.Min` no longer accept inputs of type
`tf.complex64` or `tf.complex128`, because the behavior of these ops is not
well defined for complex types.
* XLA:CPU and XLA:GPU devices are no longer registered by default. Use
`TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them, but this
flag will eventually be removed in subsequent releases.
* `tf.keras`:
* The `steps_per_execution` argument in `model.compile()` is no longer experimental;
if you were passing `experimental_steps_per_execution`, rename it to
`steps_per_execution` in your code. This argument controls the number of batches
to run during each `tf.function` call when calling `model.fit()`. Running multiple
batches inside a single `tf.function` call can greatly improve performance on
TPUs or small models with a large Python overhead.
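As a rough sketch of the renamed argument (toy model and data, for illustration only):
```
import tensorflow as tf

# Run 50 batches per tf.function invocation during fit().
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="sgd", loss="mse", steps_per_execution=50)

x = tf.random.normal((1000, 4))
y = tf.random.normal((1000, 1))
model.fit(x, y, batch_size=10, epochs=1)
```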
* A **major refactoring** of the internals of the Keras Functional API may affect code that
is relying on certain internal details:
* Code that uses `isinstance(x, tf.Tensor)` instead of `tf.is_tensor` when
checking Keras symbolic inputs/outputs should switch to using `tf.is_tensor`.
* Code that is overly dependent on the exact names attached to symbolic tensors
(e.g. assumes there will be ":0" at the end of the inputs, treats names as
unique identifiers instead of using `tensor.ref()`, etc.) may break.
* Code that uses full path for `get_concrete_function` to trace Keras symbolic
inputs directly should switch to building matching `tf.TensorSpec`s directly and
tracing the `TensorSpec` objects.
* Code that relies on the exact number and names of the op layers that TensorFlow
operations were converted into may have changed.
* Code that uses `tf.map_fn`/`tf.cond`/`tf.while_loop`/control flow as op layers
and happens to work before TF 2.4. These will explicitly be unsupported now.
Converting these ops to Functional API op layers was unreliable before TF 2.4,
and prone to erroring incomprehensibly or being silently buggy.
* Code that directly asserts on a Keras symbolic value in cases where ops
like `tf.rank` used to return a static or symbolic value depending on if the
input had a fully static shape or not. Now these ops always return symbolic values.
* Code already susceptible to leaking tensors outside of graphs becomes slightly
more likely to do so now.
* Code that tries directly getting gradients with respect to symbolic Keras
inputs/outputs. Use `GradientTape` on the actual Tensors passed to the already-constructed
model instead.
* Code that requires very tricky shape manipulation via converted op layers
in order to work, where the Keras symbolic shape inference proves insufficient.
* Code that tries manually walking a `tf.keras.Model` layer by layer and assumes
layers only ever have one positional argument. This assumption doesn't hold
true before TF 2.4 either, but is more likely to cause issues now.
* Code that manually enters `keras.backend.get_graph()` before building a
functional model is no longer needed.
* Start enforcing input shape assumptions when calling Functional API Keras
models. This may potentially break some users, in case there is a mismatch
between the shape used when creating `Input` objects in a Functional model,
and the shape of the data passed to that model. You can fix this mismatch by
either calling the model with correctly-shaped data, or by relaxing `Input`
shape assumptions (note that you can pass shapes with `None` entries for axes
that are meant to be dynamic). You can also disable the input checking
entirely by setting `model.input_spec = None`.
* TF pip packages now use CUDA 11 and cuDNN 8.0.2.
* XLA:CPU and XLA:GPU devices are no longer registered by default. Use
`TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be
removed).
* `tf.raw_ops.Max` and `tf.raw_ops.Min` no longer accept inputs of type
`tf.complex64` or `tf.complex128`, because the behavior of these ops is not
well defined for complex types.
* Several changes have been made to `tf.keras.mixed_precision.experimental`.
Note that it is now recommended to use the non-experimental
`tf.keras.mixed_precision` API.
* `AutoCastVariable.dtype` now refers to the actual variable dtype, not the
dtype it will be casted to.
* When mixed precision is enabled, `tf.keras.layers.Embedding` now outputs a
float16 or bfloat16 tensor instead of a float32 tensor.
* The property `tf.keras.mixed_precision.experimental.LossScaleOptimizer.loss_scale`
is now a tensor, not a `LossScale` object. This means to get a loss scale
of a `LossScaleOptimizer` as a tensor, you must now call `opt.loss_scale` instead of `opt.loss_scale()`.
* The property `should_cast_variables` has been removed from `tf.keras.mixed_precision.experimental.Policy`
* When passing a `tf.mixed_precision.experimental.DynamicLossScale` to `tf.keras.mixed_precision.experimental.LossScaleOptimizer`,
the `DynamicLossScale`'s multiplier must be 2.
* When passing a `tf.mixed_precision.experimental.DynamicLossScale` to
`tf.keras.mixed_precision.experimental.LossScaleOptimizer`, the weights of
the `DynamicLossScale` are copied into the `LossScaleOptimizer` instead of being reused.
This means modifying the weights of the `DynamicLossScale` will no longer affect the weights of the LossScaleOptimizer, and vice versa.
* The global policy can no longer be set to a non-floating point policy in `tf.keras.mixed_precision.experimental.set_policy`
* In `Layer.call`, `AutoCastVariable`s will no longer be casted within
`MirroredStrategy.run` or `ReplicaContext.merge_call`. This is because a thread local
variable is used to determine whether `AutoCastVariable`s are casted, and those
two functions run with a different thread. Note this only applies if one of
these two functions is called within `Layer.call`; if one of those two functions calls `Layer.call`, `AutoCastVariable`s will still be casted.
* `tf.data`:
* `tf.data.experimental.service.DispatchServer` now takes a config tuple
instead of individual arguments. Usages should be updated to
`tf.data.experimental.service.DispatchServer(dispatcher_config)`.
* `tf.data.experimental.service.WorkerServer` now takes a config tuple
instead of individual arguments. Usages should be updated to
`tf.data.experimental.service.WorkerServer(worker_config)`.
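A minimal sketch of the config-based constructors, assuming the companion `DispatcherConfig`/`WorkerConfig` classes in `tf.data.experimental.service`; port and address values here are arbitrary:
```
import tensorflow as tf

# Start an in-process dispatcher and worker from config objects rather than
# individual constructor arguments.
dispatcher_config = tf.data.experimental.service.DispatcherConfig(port=0)
dispatcher = tf.data.experimental.service.DispatchServer(dispatcher_config)

worker_config = tf.data.experimental.service.WorkerConfig(
    dispatcher_address=dispatcher.target.split("://")[1])
worker = tf.data.experimental.service.WorkerServer(worker_config)
```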
* `tf.quantization.quantize_and_dequantize_v2` has been introduced, which
updates the gradient definition for quantization which is outside the range
to be 0. To simulate the V1 behavior of
`tf.quantization.quantize_and_dequantize(...)`, use
`tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...)`.
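A sketch of recovering the V1-style straight-through gradient (the quantization range below is arbitrary):
```
import tensorflow as tf

x = tf.constant([-1.5, 0.3, 2.7])
# Wrap the v2 op so gradients pass through unchanged, as in the V1 behavior.
qdq_v1_like = tf.grad_pass_through(
    lambda t: tf.quantization.quantize_and_dequantize_v2(
        t, input_min=-1.0, input_max=1.0, range_given=True))

with tf.GradientTape() as tape:
  tape.watch(x)
  y = qdq_v1_like(x)
print(tape.gradient(y, x))  # ones everywhere, even outside [-1, 1]
```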
* `tf.distribute.Strategy.experimental_make_numpy_dataset` is removed. Please
use `tf.data.Dataset.from_tensor_slices` instead.
* `experimental_hints` in `tf.distribute.StrategyExtended.reduce_to`,
`tf.distribute.StrategyExtended.batch_reduce_to`,
`tf.distribute.ReplicaContext.all_reduce` are renamed to `options`.
`tf.distribute.experimental.CollectiveHints` is renamed
`tf.distribute.experimental.CommunicationOptions`.
`tf.distribute.experimental.CollectiveCommunication` is renamed
`tf.distribute.experimental.CommunicationImplementation`.
* `tf.keras.mixed_precision.experimental`:
* `AutoCastVariable.dtype` now refers to the actual variable dtype, not the
dtype it will be casted to.
* When mixed precision is enabled, `tf.keras.layers.Embedding` now outputs a
float16 or bfloat16 tensor instead of a float32 tensor.
* The property
`tf.keras.mixed_precision.experimental.LossScaleOptimizer.loss_scale` is now
a tensor, not a `LossScale` object. This means to get a loss scale of a
`LossScaleOptimizer` as a tensor, you must now call `opt.loss_scale` instead
of `opt.loss_scale()`.
* The property `should_cast_variables` has been removed from
`tf.keras.mixed_precision.experimental.Policy`
* When passing a `tf.mixed_precision.experimental.DynamicLossScale` to
`tf.keras.mixed_precision.experimental.LossScaleOptimizer`, the
`DynamicLossScale`'s multiplier must be 2.
* When passing a `tf.mixed_precision.experimental.DynamicLossScale` to
`tf.keras.mixed_precision.experimental.LossScaleOptimizer`, the weights of
the `DynamicLossScale` are copied into the `LossScaleOptimizer` instead of
being reused. This means modifying the weights of the `DynamicLossScale`
will no longer affect the weights of the LossScaleOptimizer, and vice versa.
* The global policy can no longer be set to a non-floating point policy in
`tf.keras.mixed_precision.experimental.set_policy`
* In `Layer.call`, `AutoCastVariable`s will no longer be casted within
`MirroredStrategy.run` or `ReplicaContext.merge_call`. This is because a
thread local variable is used to determine whether `AutoCastVariable`s are
casted, and those two functions run with a different thread. Note this only
applies if one of these two functions is called within `Layer.call`; if one
of those two functions calls `Layer.call`, `AutoCastVariable`s will still be
casted.
## Known Caveats
* <CAVEATS REGARDING THE RELEASE (BUT NOT BREAKING CHANGES). E.G. ADDING A NEW DEPENDENCY, BUMPING A DEPENDENCY NUMBER, LACK OF SUPPORT ON SOME PLATFORM, ETC>
## Major Features and Improvements
* <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
* A new module named `tf.experimental.numpy` is added, which is a NumPy-compatible API for writing TF programs. This module provides class `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are provided. Their inter-operation with TF facilities is seamless in most cases. See [tensorflow/python/ops/numpy_ops/README.md](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/numpy_ops/README.md) for details of what operations are supported and what are the differences from NumPy.
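A small sketch of the interop described above:
```
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

a = tnp.asarray([[1.0, 2.0], [3.0, 4.0]])  # tnp.ndarray wrapping a tf.Tensor
b = tnp.add(a, 1.0)                        # NumPy-style API
c = tf.reduce_sum(b)                       # hands off seamlessly to TF ops
print(b, c)
```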
* A major refactoring of the internals of the Keras Functional API has been completed, that should improve the reliability, stability, and performance of constructing Functional models.
* Support for
[TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/)
on Ampere based GPUs has been added. TensorFloat-32, or TF32 for short, is a
math mode for NVIDIA Ampere GPUs which causes certain float32 ops, such as
matrix multiplications and convolutions, to run much faster on Ampere GPUs but
with reduced precision. This reduced precision has not been found to affect
convergence quality of deep learning models in practice. TensorFloat-32 is
enabled by default, but can be disabled with
`tf.config.experimental.enable_tensor_float_32_execution`.
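For example, to opt back into full float32 precision (a sketch; the setting has no effect on hardware without TF32):
```
import tensorflow as tf

# Disable TensorFloat-32 so float32 matmuls/convolutions keep 23-bit mantissas.
tf.config.experimental.enable_tensor_float_32_execution(False)
print(tf.config.experimental.tensor_float_32_execution_enabled())  # False
```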
* `tf.data.experimental.service.WorkerServer` now takes a config tuple instead
of individual arguments. Usages should be updated to `tf.data.experimental.service.WorkerServer(worker_config)`.
* `tf.distribute`:
* `MultiWorkerMirroredStrategy` is graduated out of experimental.
* Peer failure will no longer cause the cluster to hang.
* Major issues with saving are fixed.
* See [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for a tutorial.
* Deprecated `experimental_distribute_datasets_from_function` method and renamed it to `distribute_datasets_from_function` as it is no longer experimental.
* The `tf.keras.mixed_precision` API has been made non-experimental. The major
changes to the new non-experimental API are:
* `tf.keras.mixed_precision.Policy` no longer takes in a
`tf.mixed_precision.experimental.LossScale` in the constructor, and no
longer has a `LossScale` associated with it. Instead, `Model.compile` will
automatically wrap the optimizer with a `LossScaleOptimizer` using dynamic
loss scaling if `Policy.name` is "mixed_float16".
* `tf.keras.mixed_precision.LossScaleOptimizer`'s constructor takes in
different arguments. In particular, it no longer takes in a `LossScale`, and
there is no longer a `LossScale` associated with the `LossScaleOptimizer`.
Instead, `LossScaleOptimizer` directly implements fixed or dynamic loss
scaling. See the documentation of
`tf.keras.mixed_precision.experimental.LossScaleOptimizer` for details on
the differences between the experimental `LossScaleOptimizer` and the new
non-experimental `LossScaleOptimizer`.
* `tf.mixed_precision.experimental.LossScale` and its subclasses are
deprecated, as all of its functionality now exists within
`tf.keras.mixed_precision.LossScaleOptimizer`
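A minimal sketch of the non-experimental workflow described above (layer sizes are arbitrary):
```
import tensorflow as tf

# Set the global dtype policy; compile() then wraps the optimizer in a
# LossScaleOptimizer automatically because the policy is "mixed_float16".
tf.keras.mixed_precision.set_global_policy("mixed_float16")
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(16,)),
    tf.keras.layers.Dense(10, dtype="float32"),  # keep outputs in float32
])
model.compile(optimizer="adam", loss="mse")
print(isinstance(model.optimizer,
                 tf.keras.mixed_precision.LossScaleOptimizer))  # True
```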
* Removes `tf.distribute.Strategy.experimental_make_numpy_dataset`. Please use
`tf.data.Dataset.from_tensor_slices` instead.
* Renames `experimental_hints` in `tf.distribute.StrategyExtended.reduce_to`,
`tf.distribute.StrategyExtended.batch_reduce_to`, `tf.distribute.ReplicaContext.all_reduce`
to `options`.
* Renames `tf.distribute.experimental.CollectiveHints` to `tf.distribute.experimental.CommunicationOptions`.
* Renames `tf.distribute.experimental.CollectiveCommunication` to `tf.distribute.experimental.CommunicationImplementation`.
* Renames `tf.distribute.Strategy.experimental_distribute_datasets_from_function` to `distribute_datasets_from_function` as it is no longer experimental.
* Removes `tf.distribute.Strategy.experimental_run_v2` method, which was deprecated in TF 2.2.
* `tf.lite`:
* `tf.quantization.quantize_and_dequantize_v2` has been introduced, which updates the gradient definition for quantization which is outside the range
to be 0. To simulate the V1 behavior of `tf.quantization.quantize_and_dequantize(...)`, use
`tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...)`.
* Building TensorFlow:
* Windows platform builds: TensorFlow on Windows under MSVC is now built with
`--copt=/experimental:preprocessor --host_copt=/experimental:preprocessor`
(see `.bazelrc` for more details). Builds including TensorFlow may fail with
unexpected syntax errors if these flags are absent. See also
[this thread on SIG Build](https://groups.google.com/a/tensorflow.org/g/build/c/LbAw8RILvTg/m/ttnuhYU2BgAJ).
## Known Caveats
* `tf.keras.mixed_precision`
* When using mixed precision, calling `RMSprop.apply_gradients` or
`Nadam.apply_gradients` outside a `tf.function` does not work and will raise
the AttributeError "Tensor.op is meaningless when eager execution is enabled".
See this [issue](https://github.com/tensorflow/tensorflow/issues/45536) for details and a workaround.
## Bug Fixes and Other Changes
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* Security:
* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
* Fixes three vulnerabilities in conversion to DLPack format
([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
[CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
[CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
[CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
* Fixes several vulnerabilities in `RaggedCountSparseOutput` and
`SparseCountSparseOutput` operations
([CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196),
[CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197),
[CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198),
[CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199),
[CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200),
[CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201))
* Fixes an integer truncation vulnerability in code using the work sharder
API
([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
* Fixes a format string vulnerability in `tf.strings.as_string`
([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
* Fixes segfault raised by calling session-only ops in eager mode
([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
* Fixes data leak and potential ASLR violation from
`tf.raw_ops.StringNGrams`
([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
* Fixes segfaults caused by incomplete `SavedModel` validation
([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
* Fixes a data corruption due to a bug in negative indexing support in
TFLite
([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
* Fixes a data corruption due to dimension mismatch in TFLite
([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
* Fixes several vulnerabilities in TFLite saved model format
([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
[CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
[CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
* Fixes several vulnerabilities in TFLite implementation of segment sum
([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212),
[CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
[CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
* Fixes a segfault in `tf.quantization.quantize_and_dequantize`
([CVE-2020-15265](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15265))
* Fixes an undefined behavior float cast causing a crash
([CVE-2020-15266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15266))
* TF Core:
* `tf.types.experimental.TensorLike` is a new `Union` type that can be
used as type annotation for variables representing a Tensor or a value
that can be converted to Tensor by `tf.convert_to_tensor`.
* Calling ops with python constants or numpy values is now consistent
with `tf.convert_to_tensor` behavior. This avoids operations like
`tf.reshape` truncating inputs such as from int64 to int32.
* Added `tf.sparse.map_values` to apply a function to the `.value`s of
`SparseTensor` arguments.
* The Python bitwise operators for `Tensor` (`__and__`, `__or__`,
`__xor__` and `__invert__`) now support non-`bool` arguments and apply
the corresponding bitwise ops. `bool` arguments continue to be supported
and dispatch to logical ops. This brings them more in line with Python
and NumPy behavior.
* Added `tf.SparseTensor.with_values`. This returns a new SparseTensor
with the same sparsity pattern, but with new provided values. It is
similar to the `with_values` function of `RaggedTensor`.
* Added `StatelessCase` op, and uses it if none of the case branches has
stateful ops.
* Added `tf.config.experimental.get_memory_usage` to return total memory
usage of the device.
* Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`.
* Improve shape inference of nested function calls by supporting constant folding across Arg nodes which makes more static values available to shape inference functions.
* `tf.data`:
* tf.data service:
* Added new `tf.data.experimental.service.register_dataset` and
`tf.data.experimental.service.from_dataset_id` APIs to enable one
process to register a dataset with the tf.data service, and another
process to consume data from the dataset.
* Added support for dispatcher fault tolerance. To enable fault tolerance,
configure a `work_dir` when running your dispatcher server and set
`dispatcher_fault_tolerance=True`. The dispatcher will store its state
to `work_dir`, so that on restart it can continue from its previous state.
* Added support for sharing dataset graphs via shared filesystem instead
of over RPC. This reduces load on the dispatcher, improving performance
of distributing datasets. For this to work, the dispatcher's `work_dir`
must be accessible from workers. If the worker fails to read from the
`work_dir`, it falls back to using RPC for dataset graph transfer.
* Added support for a new "distributed_epoch" processing mode. This
processing mode distributes a dataset across all tf.data workers,
instead of having each worker process the full dataset. See
[the tf.data service docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#understand_processing_mode)
to learn more.
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
the complement of `select_cols`; at most one of these should be
specified.
* We have implemented an optimization which reorders data-discarding
transformations such as `take` and `shard` to happen earlier in the
dataset when it is safe to do so. The optimization can be disabled via
the `experimental_optimization.reorder_data_discarding_ops` dataset
option.
* `tf.data.Options` were previously immutable and can now be overridden.
* `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors
with a new `output_signature` argument, which allows `from_generator` to
produce any type describable by a `tf.TypeSpec`.
* `tf.data.experimental.AUTOTUNE` is now available in the core API as
`tf.data.AUTOTUNE`.
* `tf.image`:
* Added deterministic `tf.image.stateless_random_*` functions for each
`tf.image.random_*` function. Added a new op
`stateless_sample_distorted_bounding_box` which is a deterministic
version of `sample_distorted_bounding_box` op. Given the same seed,
these stateless functions/ops produce the same results independent of
how many times the function is called, and independent of global seed
settings.
* `tf.distribute`:
* (Experimental) Parameter server training:
* Replaced the existing
`tf.distribute.experimental.ParameterServerStrategy` symbol with
a new class that is for parameter server training in TF2. Usage with
the old symbol, usually with Estimator, should be replaced with
`tf.compat.v1.distribute.experimental.ParameterServerStrategy`.
* Added `tf.distribute.experimental.coordinator.*` namespace,
including the main API `ClusterCoordinator` for coordinating the
training cluster, the related data structures `RemoteValue`
and `PerWorkerValue`.
* `tf.keras`:
* Improvements from the functional API refactoring:
* Functional model construction does not need to maintain a global
workspace graph, removing memory leaks especially when building many
models or very large models.
* Functional model construction should be ~8-10% faster on average.
* Functional models can now contain non-symbolic values in their call
inputs inside of the first positional argument.
* Several classes of TF ops that were not reliably converted to Keras
layers during functional API construction should now work, e.g.
`tf.image.ssim_multiscale`
* Error messages when Functional API construction goes wrong (and when
ops cannot be converted to Keras layers automatically) should be
clearer and easier to understand.
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
as an alternative to accepting a `callable` loss.
* Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
to match FTRL paper
(https://research.google.com/pubs/archive/41159.pdf).
* Added `mobilenet_v3` to Keras applications.
* `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
customization of how gradients are aggregated across devices, as well as
`gradient_transformers` to allow for custom gradient transformations
(such as gradient clipping).
* The `steps_per_execution` argument in `compile()` is no longer
experimental; if you were passing `experimental_steps_per_execution`,
rename it to `steps_per_execution` in your code. This argument controls
the number of batches to run during each `tf.function` call when calling
`fit()`. Running multiple batches inside a single `tf.function` call can
greatly improve performance on TPUs or small models with a large Python
overhead.
* Improvements to Keras preprocessing layers:
* TextVectorization can now accept a vocabulary list or file as an
init arg.
* TextVectorization, StringLookup, and IntegerLookup can now accept a
vocabulary file via the `set_vocab_from_file` method.
* Normalization can now accept mean and variance values as init args.
* In `Attention` and `AdditiveAttention` layers, the `call()` method now
accepts a `return_attention_scores` argument. When set to
True, the layer returns the attention scores as an additional output
argument.
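A brief sketch (shapes are arbitrary):
```
import tensorflow as tf

query = tf.random.normal((2, 4, 8))  # [batch, query_len, dim]
value = tf.random.normal((2, 6, 8))  # [batch, value_len, dim]
layer = tf.keras.layers.Attention()
output, scores = layer([query, value], return_attention_scores=True)
print(output.shape, scores.shape)    # (2, 4, 8) and (2, 4, 6)
```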
* Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints
with the same implementation as their `tf.losses` equivalent.
* For Keras model, the individual call of `Model.evaluate` uses no cached
data for evaluation, while `Model.fit` uses cached data when
`validation_data` arg is provided for better performance.
* Added a `save_traces` argument to `model.save`/
`tf.keras.models.save_model` which determines whether the SavedModel
format stores the Keras model/layer call functions. The traced functions
allow Keras to revive custom models and layers without the original
class definition, but if this isn't required the tracing can be
disabled with the added option.
* `tf.function` / AutoGraph:
* Added `experimental_follow_type_hints` argument for `tf.function`. When
True, the function may use type annotations to optimize the tracing
performance.
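For instance (a sketch): with the hint followed, plain Python scalars passed for an annotated argument are converted to tensors before tracing, so repeated calls with different scalars can reuse one trace.
```
import tensorflow as tf

@tf.function(experimental_follow_type_hints=True)
def add_one(x: tf.Tensor):
  return x + 1

add_one(1)
add_one(2)  # reuses the first trace instead of retracing for a new Python int
```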
* Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
* AutoGraph now allows creating new symbols inside a TensorFlow loop, if
the values of these symbols at an iteration do not depend on the
previous iteration. These types of loops must run at least one
iteration, and will raise a runtime error otherwise.
* Variables contained in `tf.Module`s that are set as attributes of
custom Keras `Layer`s and `Model`s are now tracked in
the properties `layer.trainable_variables` and
`layer.non_trainable_variables`.
### TF Core:
* Introduces experimental support for a new module named [`tf.experimental.numpy`](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy), which is a
NumPy-compatible API for writing TF programs. This module provides class
`ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable
`tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are
provided. Their inter-operation with TF facilities is seamless in most cases.
See [tensorflow/python/ops/numpy_ops/README.md](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/numpy_ops/README.md)
for details of what operations are supported and what are the differences
from NumPy.
* `tf.types.experimental.TensorLike` is a new `Union` type that can be used as
type annotation for variables representing a Tensor or a value
that can be converted to Tensor by `tf.convert_to_tensor`.
* Calling ops with python constants or numpy values is now consistent with
`tf.convert_to_tensor` behavior. This avoids operations like
`tf.reshape` truncating inputs such as from int64 to int32.
* Adds `tf.sparse.map_values` to apply a function to the `.value`s of
`SparseTensor` arguments.
* The Python bitwise operators for `Tensor` (`__and__`, `__or__`, `__xor__` and `__invert__`) now support non-`bool`
arguments and apply the corresponding bitwise ops. `bool` arguments continue
to be supported and dispatch to logical ops. This brings them more in line with
Python and NumPy behavior.
* Adds `tf.SparseTensor.with_values`. This returns a new SparseTensor with the same sparsity pattern, but with new provided values. It is
similar to the `with_values` function of `RaggedTensor`.
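A short sketch of both helpers:
```
import tensorflow as tf

st = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2]],
                            values=[1.0, 2.0],
                            dense_shape=[2, 3])
squared = tf.sparse.map_values(tf.square, st)       # values -> [1.0, 4.0]
replaced = st.with_values(tf.constant([5.0, 6.0]))  # same indices, new values
print(squared.values.numpy(), replaced.values.numpy())
```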
* Adds `StatelessCase` op, and uses it if none of the case branches has stateful ops.
* Adds `tf.config.experimental.get_memory_usage` to return total memory usage of the device.
* Adds gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`.
* Improve shape inference of nested function calls by supporting constant
folding across Arg nodes which makes more static values available to shape
inference functions.
* `tf.debugging`:
* `tf.debugging.assert_shapes()` now works on `SparseTensor`s (Fixes [#36268](https://github.com/tensorflow/tensorflow/issues/36268)).
* GPU
* Adds support for [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/)
on Ampere based GPUs. TensorFloat-32, or TF32 for short, is a math mode for
NVIDIA Ampere based GPUs which causes certain float32 ops, such as matrix
multiplications and convolutions, to run much faster on Ampere GPUs but with
reduced precision. This reduced precision has not been found to affect
convergence quality of deep learning models in practice. TensorFloat-32 is
enabled by default, but can be disabled with `tf.config.experimental.enable_tensor_float_32_execution`.
* `tf.math`:
* Adds `tf.math.erfcinv`, the inverse to `tf.math.erfc`.
* `tf.nn`:
* `tf.nn.max_pool2d` now supports explicit padding.
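A sketch of the explicit-padding form (NHWC layout assumed; pad amounts are arbitrary):
```
import tensorflow as tf

x = tf.random.normal((1, 8, 8, 3))
# Pad one row on top and one column on the left before pooling, instead of
# using the "SAME"/"VALID" strings.
y = tf.nn.max_pool2d(x, ksize=2, strides=2,
                     padding=[[0, 0], [1, 0], [1, 0], [0, 0]])
print(y.shape)
```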
* `tf.image`:
* Adds deterministic `tf.image.stateless_random_*` functions for each
`tf.image.random_*` function. Added a new op `stateless_sample_distorted_bounding_box`
which is a deterministic version of `sample_distorted_bounding_box` op.
Given the same seed, these stateless functions/ops produce the same results
independent of how many times the function is called, and independent of global seed settings.
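For example (seed values are arbitrary):
```
import tensorflow as tf

image = tf.random.normal((64, 64, 3))
seed = (1, 2)  # a 2-element seed; the same seed always gives the same result
flipped = tf.image.stateless_random_flip_left_right(image, seed=seed)
brightened = tf.image.stateless_random_brightness(image, max_delta=0.2,
                                                  seed=seed)
```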
* Adds deterministic `tf.image.resize` backprop CUDA kernels for
`method=ResizeMethod.BILINEAR` (the default method). Enable by setting the environment
variable `TF_DETERMINISTIC_OPS` to `"true"` or `"1"`.
* `tf.print`:
* Bug fix in `tf.print()` with `OrderedDict` where if an `OrderedDict`
didn't have the keys sorted, the keys and values were not being printed
in accordance with their correct mapping.
* `tf.train.Checkpoint`:
* Now accepts a `root` argument in the initialization, which generates a
checkpoint with a root object. This allows users to create a `Checkpoint`
object that is compatible with Keras `model.save_weights()` and
`model.load_weights`. The checkpoint is also compatible with the checkpoint
saved in the `variables/` folder in the SavedModel.
* When restoring, `save_path` can be a path to a SavedModel. The function will
automatically find the checkpoint in the SavedModel.
Example:
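Below is a minimal sketch, assuming a small Keras model and hypothetical paths; the `root` argument and the SavedModel-path restore are the behaviors described above.
```
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
tf.saved_model.save(model, "/tmp/my_saved_model")  # hypothetical path

# `root` makes `restored` the root object of the checkpoint, matching Keras
# weight checkpoints and the checkpoint under the SavedModel's variables/ dir.
restored = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
ckpt = tf.train.Checkpoint(root=restored)
ckpt.restore("/tmp/my_saved_model").expect_partial()
```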
### `tf.data`:
* Adds new `tf.data.experimental.service.register_dataset` and
`tf.data.experimental.service.from_dataset_id` APIs to enable one process to
register a dataset with the tf.data service, and another process to consume
data from the dataset.
* Adds support for dispatcher fault tolerance. To enable fault tolerance,
configure a `work_dir` when running your dispatcher server and set
`dispatcher_fault_tolerance=True`. The dispatcher will store its state to
`work_dir`, so that on restart it can continue from its previous state.
* Adds support for sharing dataset graphs via shared filesystem instead of
over RPC. This reduces load on the dispatcher, improving performance
of distributing datasets. For this to work, the dispatcher's `work_dir`
must be accessible from workers. If the worker fails to read from the `work_dir`,
it falls back to using RPC for dataset graph transfer.
* Adds support for a new "distributed_epoch" processing mode.
This processing mode distributes a dataset across all tf.data workers,
instead of having each worker process the full dataset. See
[the tf.data service docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#understand_processing_mode)
to learn more.
* Adds optional `exclude_cols` parameter to CsvDataset. This parameter is the
complement of `select_cols`; at most one of these should be specified.
* We have implemented an optimization which reorders data-discarding
transformations such as `take` and `shard` to happen earlier in the dataset
when it is safe to do so. The optimization can be disabled via the
`experimental_optimization.reorder_data_discarding_ops` dataset option.
* `tf.data.Options` were previously immutable and can now be overridden.
* `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors with
a new `output_signature` argument, which allows `from_generator` to produce any
type describable by a `tf.TypeSpec`.
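A small sketch yielding ragged elements:
```
import tensorflow as tf

def gen():
  yield tf.ragged.constant([[1, 2], [3]])
  yield tf.ragged.constant([[4], [5, 6, 7]])

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature=tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32))
for rt in ds:
  print(rt)
```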
* `tf.data.experimental.AUTOTUNE` is now available in the core API as `tf.data.AUTOTUNE`.
```
for batch in data:
outputs = train_step(batch)
tf.print('final outputs', outputs)
```
### `tf.distribute`:
* Introduces experimental support for asynchronous training of models via
`tf.distribute.experimental.ParameterServerStrategy`:
* Replaces the existing `tf.distribute.experimental.ParameterServerStrategy`
symbol with a new class that is for parameter server training in TF2. Usage of
the old symbol, usually with Estimator API, should be **replaced** with
`tf.compat.v1.distribute.experimental.ParameterServerStrategy`.
* Added `tf.distribute.experimental.coordinator.*` namespace, including the
main API `ClusterCoordinator` for coordinating the training cluster, the
related data structures `RemoteValue` and `PerWorkerValue`.
* [`MultiWorkerMirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/MultiWorkerMirroredStrategy)
is now a stable API and is no longer considered experimental. Some of the major
improvements involve handling peer failure and many bug fixes. Please check out
the detailed tutorial on [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras).
* Adds `tf.distribute.Strategy.gather` and `tf.distribute.ReplicaContext.all_gather`
APIs to support gathering dense distributed values.
* Fixes various issues with saving a distributed model.
See tensorflow/python/autograph/g3doc/reference/limitations.md for more
info.
### `tf.keras`:
* Improvements from the Functional API refactoring:
* Functional model construction does not need to maintain a global workspace
graph, removing memory leaks especially when building many models or very large models.
* Functional model construction should be ~8-10% faster on average.
* Functional models can now contain non-symbolic values in their call inputs
inside of the first positional argument.
* Several classes of TF ops that were not reliably converted to Keras layers
during functional API construction should now work, e.g. `tf.image.ssim_multiscale`.
* Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be
clearer and easier to understand.
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
as an alternative to accepting a `callable` loss.
* Adds `beta` hyperparameter to [FTRL](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl)
optimizer classes (Keras and others) to match [FTRL paper](https://research.google.com/pubs/archive/41159.pdf).
* `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for customization
of how gradients are aggregated across devices, as well as `gradient_transformers`
to allow for custom gradient transformations (such as gradient clipping).
* Improvements to Keras preprocessing layers:
* TextVectorization can now accept a vocabulary list or file as an init arg.
* Normalization can now accept mean and variance values as init args.
* In `Attention` and `AdditiveAttention` layers, the `call()` method now accepts a `return_attention_scores` argument. When set to
True, the layer returns the attention scores as an additional output argument.
* Adds `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints with the
same implementation as their `tf.losses` equivalent.
* For Keras model, the individual call of `Model.evaluate` uses no cached data
for evaluation, while `Model.fit` uses cached data when `validation_data` arg
is provided for better performance.
* Adds a `save_traces` argument to `model.save`/ `tf.keras.models.save_model`
which determines whether the SavedModel format stores the Keras model/layer call
functions. The traced functions allow Keras to revive custom models and layers
without the original class definition, but if this isn't required the tracing
can be disabled with the added option.
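For instance (the path is hypothetical):
```
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])
# Skip serializing traced call functions; reloading then relies on the layer
# class definitions being available (they are, for built-in layers).
model.save("/tmp/my_model", save_traces=False)
reloaded = tf.keras.models.load_model("/tmp/my_model")
```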
* The `tf.keras.mixed_precision` API is now non-experimental.
The non-experimental API differs from the experimental API in several ways.
* `tf.keras.mixed_precision.Policy` no longer takes in a
`tf.mixed_precision.experimental.LossScale` in the constructor, and no longer has a `LossScale`
associated with it. Instead, `Model.compile` will automatically wrap the optimizer
with a `LossScaleOptimizer` using dynamic loss scaling if `Policy.name`
is "mixed_float16".
* `tf.keras.mixed_precision.LossScaleOptimizer`'s constructor takes in different
arguments. In particular, it no longer takes in a `LossScale`, and there is
no longer a `LossScale` associated with the `LossScaleOptimizer`. Instead,
`LossScaleOptimizer` directly implements fixed or dynamic loss scaling. See the
documentation of [`tf.keras.mixed_precision.experimental.LossScaleOptimizer`](https://www.tensorflow.org/api_docs/python/tf/keras/mixed_precision/experimental/LossScaleOptimizer?version=nightly)
for details on the differences between the experimental `LossScaleOptimizer`
and the new non-experimental `LossScaleOptimizer`.
* `tf.mixed_precision.experimental.LossScale` and its subclasses are
deprecated, as all of its functionality now exists within `tf.keras.mixed_precision.LossScaleOptimizer`
### `tf.lite`:
* `TFLiteConverter`:
* Support optional flags `inference_input_type` and `inference_output_type`
for full integer quantized models. This allows users to modify the model input
and output type to integer types (`tf.int8`, `tf.uint8`) instead of defaulting
to float type (`tf.float32`).
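A hedged sketch of a full-integer conversion with integer input/output; the SavedModel path is hypothetical and the representative dataset is a stand-in for real calibration data:
```
import tensorflow as tf

def representative_dataset():
  for _ in range(100):
    yield [tf.random.normal((1, 16))]

converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/my_saved_model")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8   # integer inputs instead of float32
converter.inference_output_type = tf.int8  # integer outputs instead of float32
tflite_model = converter.convert()
```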
* NNAPI
* Adds NNAPI Delegation support for requantization use cases by converting
the operation into a dequantize-quantize pair.
* Removes deprecated `Interpreter.setUseNNAPI(boolean)` Java API. Use
`Interpreter.Options.setUseNNAPI` instead.
* Deprecates `Interpreter::UseNNAPI(bool)` C++ API. Use `NnApiDelegate()`
and related delegate configuration methods directly.
* Deprecates `Interpreter::SetAllowFp16PrecisionForFp32(bool)` C++ API.
Prefer controlling this via delegate options, e.g. `tflite::StatefulNnApiDelegate::Options::allow_fp16`
or `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
* GPU
* GPU acceleration now supports quantized models by default
* `DynamicBuffer::AddJoinedString()` will now add a separator if the first string to be joined is empty.
* Adds support for cumulative sum (cumsum), both as builtin op and MLIR conversion.
* <ADD RELEASE NOTES HERE>
### `TensorRT`
* Issues a warning when the `session_config` parameter for the TF1 converter
is used or the `rewrite_config_template` field in the TF2 converter parameter
object is used.
* `tf.random`:
### TPU Enhancements:
* Adds support for the `beta` parameter of the FTRL optimizer for TPU
embeddings. Users of other TensorFlow platforms can implement equivalent
behavior by adjusting the `l2` parameter.
* <ADD RELEASE NOTES HERE>
### XLA Support:
* xla.experimental.compile is deprecated, use `tf.function(experimental_compile=True)` instead.
* Adds `tf.function.experimental_get_compiler_ir` which returns compiler IR
(currently 'hlo' and 'optimized_hlo') for given input for given function.
* Math and Linear Algebra:
### Security:
* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`,
([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
* Fixes three vulnerabilities in conversion to DLPack format
* [CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
* [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
* [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193)
* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
* [CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
* [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195)
* Fixes several vulnerabilities in `RaggedCountSparseOutput` and `SparseCountSparseOutput` operations
* [CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196),
* [CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197),
* [CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198),
* [CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199),
* [CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200),
* [CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201)
* Fixes an integer truncation vulnerability in code using the work sharder API,
([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
* Fixes a format string vulnerability in `tf.strings.as_string`,
([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
* Fixes segfault raised by calling session-only ops in eager mode,
([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams`,
([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
* Fixes segfaults caused by incomplete `SavedModel` validation,
([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
* Fixes a data corruption due to a bug in negative indexing support in TFLite,
([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
* Fixes a data corruption due to dimension mismatch in TFLite,
([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
* Fixes several vulnerabilities in TFLite saved model format
* [CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
* [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
* [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211)
* Fixes several vulnerabilities in TFLite implementation of segment sum
* [CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212),
* [CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
* [CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214)
* Fixes a segfault in `tf.quantization.quantize_and_dequantize`,
([CVE-2020-15265](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15265))
* Fixes an undefined behavior float cast causing a crash,
([CVE-2020-15266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15266))
* Fixes a lack of validation in `tf.raw_ops.DataFormatVecPermute` and
`tf.raw_ops.DataFormatDimMap` which can cause uninitialized memory access,
read outside bounds of arrays, data corruption and segmentation faults
([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
* Fixes a crash caused by writing to a read-only memory region
([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
* Fixes a heap out-of-bounds access in the filesystem globbing implementation
([CVE-2020-26269](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26269))
* Adds `tf.math.erfcinv`, the inverse of `tf.math.erfc`.
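A quick round-trip check of the new op (numerical precision aside):

```python
import tensorflow as tf

x = tf.constant([-1.0, 0.0, 1.5])
# erfcinv is the inverse of erfc, so the round trip recovers x.
print(tf.math.erfcinv(tf.math.erfc(x)))  # approximately [-1.0, 0.0, 1.5]
```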
* Tracing and Debugging:
* <ADD RELEASE NOTES HERE>
* `tf.train.Checkpoint`:
* Now accepts a `root` argument in the initialization, which generates a
checkpoint with a root object. This allows users to create a
`Checkpoint` object that is compatible with Keras `model.save_weights()`
and `model.load_weights`. The checkpoint is also compatible with the
checkpoint saved in the `variables/` folder in the SavedModel.
* When restoring, `save_path` can be a path to a SavedModel. The function
will automatically find the checkpoint in the SavedModel.
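A hedged sketch of the `root` argument and the SavedModel-aware restore described above; the model and paths are illustrative only:

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model(tf.zeros([1, 4]))  # create the variables

# With `root`, the checkpoint layout matches what Keras
# model.save_weights()/model.load_weights() expect.
ckpt = tf.train.Checkpoint(root=model)
ckpt.save('/tmp/ckpt_demo/ckpt')            # hypothetical path

tf.saved_model.save(model, '/tmp/sm_demo')  # hypothetical path
# restore() can take a SavedModel directory and will locate the
# checkpoint stored under its variables/ folder.
ckpt.restore('/tmp/sm_demo')
```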
* `tf.nn`:
* `tf.nn.max_pool2d` now supports explicit padding.
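A minimal sketch of explicit padding; the `[before, after]` pairs are assumed to follow the NHWC layout of the input:

```python
import tensorflow as tf

x = tf.random.normal([1, 4, 4, 3])  # NHWC
y = tf.nn.max_pool2d(
    x,
    ksize=2,
    strides=2,
    # Explicit [before, after] padding per dimension (batch, height,
    # width, channels) instead of 'SAME'/'VALID'.
    padding=[[0, 0], [1, 1], [1, 1], [0, 0]])
print(y.shape)  # (1, 3, 3, 3)
```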
* `tf.debugging`:
* `tf.debugging.assert_shapes()` now works on `SparseTensor`s (#36268).
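A short sketch of the `SparseTensor` support; the symbolic dimension names are arbitrary:

```python
import tensorflow as tf

sp = tf.sparse.SparseTensor(
    indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=[3, 4])
dense = tf.zeros([3, 4])

# assert_shapes now accepts SparseTensors alongside dense tensors.
tf.debugging.assert_shapes([
    (sp, ('N', 'D')),
    (dense, ('N', 'D')),
])
```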
* `tf.print`:
* Fixed a bug in `tf.print()` with `OrderedDict`: when the keys of the
  `OrderedDict` were not sorted, the keys and values could be printed with a
  mismatched mapping.
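A small example of the corrected behavior:

```python
import collections
import tensorflow as tf

d = collections.OrderedDict(b=tf.constant(2), a=tf.constant(1))
# Keys and values now stay correctly paired even though the insertion
# order ('b' before 'a') is not sorted.
tf.print(d)
```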
### Other:
* We have replaced uses of "whitelist" and "blacklist" with "allowlist" and
"denylist" where possible. Please see [this list](https://developers.google.com/style/word-list#blacklist) for more context.
* Adds `tf.config.experimental.mlir_bridge_rollout` which will help us roll out the new MLIR TPU bridge.
* Adds `tf.experimental.register_filesystem_plugin` to load modular filesystem plugins from Python (see the sketch below).
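A hedged sketch of loading a modular filesystem plugin from Python; the plugin path below is hypothetical and must point to a filesystem plugin built as a shared object:

```python
import tensorflow as tf

# Hypothetical path to a modular filesystem plugin (.so) built separately.
plugin_path = '/opt/tf_plugins/libtensorflow_demo_fs.so'
tf.experimental.register_filesystem_plugin(plugin_path)

# After registration, the scheme implemented by the plugin is available
# through the usual tf.io.gfile APIs.
```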
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as the following external contributors:
8bitmp3, aaa.jq, Abhineet Choudhary, Abolfazl Shahbazi, acxz, Adam Hillier, Adrian Garcia Badaracco, Ag Ramesh, ahmedsabie, Alan Anderson, Alexander Grund, Alexandre Lissy, Alexey Ivanov, Amedeo Cavallo, anencore94, Aniket Kumar Singh, Anthony Platanios, Ashwin Phadke, Balint Cristian, Basit Ayantunde, bbbboom, Ben Barsdell, Benjamin Chetioui, Benjamin Peterson, bhack, Bhanu Prakash Bandaru Venkata, Biagio Montaruli, Brent M. Spell, bubblebooy, bzhao, cfRod, Cheng Chen, Cheng(Kit) Chen, Chris Tessum, Christian, chuanqiw, codeadmin_peritiae, COTASPAR, CuiYifeng, danielknobe, danielyou0230, dannyfriar, daria, DarrenZhang01, Denisa Roberts, dependabot[bot], Deven Desai, Dmitry Volodin, Dmitry Zakharov, drebain, Duncan Riach, Eduard Feicho, Ehsan Toosi, Elena Zhelezina, emlaprise2358, Eugene Kuznetsov, Evaderan-Lab, Evgeniy Polyakov, Fausto Morales, Felix Johnny, fo40225, Frederic Bastien, Fredrik Knutsson, fsx950223, Gaurav Singh, Gauri1 Deshpande, George Grzegorz Pawelczak, gerbauz, Gianluca Baratti, Giorgio Arena, Gmc2, Guozhong Zhuang, Hannes Achleitner, Harirai, HarisWang, Harsh188, hedgehog91, Hemal Mamtora, Hideto Ueno, Hugh Ku, Ian Beauregard, Ilya Persky, jacco, Jakub Beránek, Jan Jongboom, Javier Montalt Tordera, Jens Elofsson, Jerry Shih, jerryyin, jgehw, Jinjing Zhou, jma, jmsmdy, Johan Nordström, John Poole, Jonah Kohn, Jonathan Dekhtiar, jpodivin, Jung Daun, Kai Katsumata, Kaixi Hou, Kamil Rakoczy, Kaustubh Maske Patil, Kazuaki Ishizaki, Kedar Sovani, Koan-Sin Tan, Koki Ibukuro, Krzysztof Laskowski, Kushagra Sharma, Kushan Ahmadian, Lakshay Tokas, Leicong Li, levinxo, Lukas Geiger, Maderator, Mahmoud Abuzaina, Mao Yunfei, Marius Brehler, markf, Martin Hwasser, Martin Kubovčík, Matt Conley, Matthias, mazharul, mdfaijul, Michael137, MichelBr, Mikhail Startsev, Milan Straka, Ml-0, Myung-Hyun Kim, Måns Nilsson, Nathan Luehr, ngc92, nikochiko, Niranjan Hasabnis, nyagato_00, Oceania2018, Oleg Guba, Ongun Kanat, OscarVanL, Patrik Laurell, Paul Tanger, Peter Sobot, Phil Pearl, PlusPlusUltra, Poedator, Prasad Nikam, Rahul-Kamat, Rajeshwar Reddy T, redwrasse, Rickard, Robert Szczepanski, Rohan Lekhwani, Sam Holt, Sami Kama, Samuel Holt, Sandeep Giri, sboshin, Sean Settle, settle, Sharada Shiddibhavi, Shawn Presser, ShengYang1, Shi,Guangyong, Shuxiang Gao, Sicong Li, Sidong-Wei, Srihari Humbarwadi, Srinivasan Narayanamoorthy, Steenu Johnson, Steven Clarkson, stjohnso98, Tamas Bela Feher, Tamas Nyiri, Tarandeep Singh, Teng Lu, Thibaut Goetghebuer-Planchon, Tim Bradley, Tomasz Strejczek, Tongzhou Wang, Torsten Rudolf, Trent Lo, Ty Mick, Tzu-Wei Sung, Varghese, Jojimon, Vignesh Kothapalli, Vishakha Agrawal, Vividha, Vladimir Menshakov, Vladimir Silyaev, VoVAllen, Võ Văn Nghĩa, wondertx, xiaohong1031, Xiaoming (Jason) Cui, Xinan Jiang, Yair Ehrenwald, Yasir Modak, Yasuhiro Matsumoto, Yimei Sun, Yiwen Li, Yixing, Yoav Ramon, Yong Tang, Yong Wu, yuanbopeng, Yunmo Koo, Zhangqiang, Zhou Peng, ZhuBaohe, zilinzhu, zmx
# Release 2.3.1

View File

@ -3,7 +3,7 @@
# learning applications.
load("@bazel_skylib//lib:selects.bzl", "selects")
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "bool_setting")
load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
load(
"//tensorflow/core/platform:build_config.bzl",
@ -401,13 +401,20 @@ config_setting(
define_values = {"using_cuda_clang": "true"},
)
# Flag to indicate open source build, .bazelrc always has it set to be true
# Config setting to use in select()s to distinguish open source build from
# google internal build on configurable attributes.
config_setting(
name = "oss",
define_values = {"open_source_build": "true"},
flag_values = {":oss_setting": "True"},
visibility = ["//visibility:public"],
)
# Fixed setting to indicate open source build.
bool_setting(
name = "oss_setting",
build_setting_default = True,
)
config_setting(
name = "using_cuda_clang_with_dynamic_build",
define_values = {
@ -416,12 +423,12 @@ config_setting(
},
)
config_setting(
selects.config_setting_group(
name = "build_oss_using_cuda_clang",
define_values = {
"using_cuda_clang": "true",
"open_source_build": "true",
},
match_all = [
":using_cuda_clang",
":oss",
],
)
# Setting to use when loading kernels dynamically
@ -447,12 +454,12 @@ config_setting(
},
)
config_setting(
selects.config_setting_group(
name = "build_oss_using_cuda_nvcc",
define_values = {
"using_cuda_nvcc": "true",
"open_source_build": "true",
},
match_all = [
":using_cuda_nvcc",
":oss",
],
)
config_setting(

View File

@ -116,7 +116,8 @@ from tensorflow.python.lib.io import file_io as _fi
# Get sitepackages directories for the python installation.
_site_packages_dirs = []
_site_packages_dirs += [] if _site.USER_SITE is None else [_site.USER_SITE]
if _site.ENABLE_USER_SITE and _site.USER_SITE is not None:
_site_packages_dirs += [_site.USER_SITE]
_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p]
if 'getsitepackages' in dir(_site):
_site_packages_dirs += _site.getsitepackages()
@ -145,6 +146,8 @@ if _running_from_pip_package():
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _os.path.exists(_plugin_dir):
_ll.load_library(_plugin_dir)
# Load Pluggable Device Library
_ll.load_pluggable_device_library(_plugin_dir)
# Add module aliases
if hasattr(_current_module, 'keras'):

View File

@ -155,6 +155,8 @@ if _running_from_pip_package():
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _os.path.exists(_plugin_dir):
_ll.load_library(_plugin_dir)
# Load Pluggable Device Library
_ll.load_pluggable_device_library(_plugin_dir)
# Delete modules that should be hidden from dir().
# Don't fail if these modules are not available.

View File

@ -78,7 +78,7 @@ cc_library(
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
"//tensorflow/python:__subpackages__",
],
)
@ -684,7 +684,10 @@ tf_cc_test(
name = "c_api_experimental_test",
size = "medium",
srcs = ["c_api_experimental_test.cc"],
data = ["testdata/tf_record"],
data = [
"testdata/tf_record",
"//tensorflow/c/experimental/stream_executor/test:test_pluggable_device.so",
],
linkopts = select({
"//tensorflow:macos": ["-headerpad_max_install_names"],
"//conditions:default": [],
@ -704,6 +707,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:resource_loader",
"@com_google_absl//absl/types:optional",
],
)

View File

@ -37,7 +37,9 @@ limitations under the License.
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/strcat.h"
@ -630,6 +632,9 @@ void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
// Helpers for loading a TensorFlow PluggableDevice plugin (a .so file).
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
} // namespace tensorflow
void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
@ -743,3 +748,45 @@ void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
TF_ImportGraphDefOptions* opts, unsigned char enable) {
opts->opts.validate_colocation_constraints = enable;
}
// Load a Pluggable Device library.
// On success, returns a handle to the library and sets OK in `status`.
// Otherwise, returns nullptr and sets an error Status.
//
// If `library_filename` has already been loaded, we return a cached handle.
// Device and Kernels/Ops are registered as globals when a library is loaded
// for the first time.
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
status->status = tensorflow::errors::Unimplemented(
"PluggableDevice plugin functionality is not supported on mobile");
return nullptr;
#else
TF_Library* lib_handle = new TF_Library;
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
static std::unordered_map<std::string, void*>* loaded_libs =
new std::unordered_map<std::string, void*>();
tensorflow::Env* env = tensorflow::Env::Default();
{
tensorflow::mutex_lock lock(mu);
auto it = loaded_libs->find(library_filename);
if (it != loaded_libs->end()) {
lib_handle->lib_handle = it->second;
} else {
status->status =
env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
if (!status->status.ok()) {
delete lib_handle;
return nullptr;
}
}
return lib_handle;
}
#endif
}
void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
delete lib_handle;
}

View File

@ -304,6 +304,27 @@ TF_CAPI_EXPORT extern void
TF_ImportGraphDefOptionsSetValidateColocationConstraints(
TF_ImportGraphDefOptions* opts, unsigned char enable);
// Load the library specified by library_filename and register the pluggable
// device and related kernels present in that library. This function is not
// supported on mobile and embedded platforms and will fail if called.
//
// Pass "library_filename" to a platform-specific mechanism for dynamically
// loading a library. The rules for determining the exact location of the
// library are platform-specific and are not documented here.
//
// On success, returns the newly created library handle and places OK in status.
// The caller owns the library handle.
//
// On failure, returns nullptr and places an error status in status.
TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary(
const char* library_filename, TF_Status* status);
// Frees the memory associated with the library handle.
// Does NOT unload the library.
TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle(
TF_Library* lib_handle);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
@ -234,5 +235,22 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
TF_DeleteTensor(tensor_1X6);
}
TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) {
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
// Load the library.
TF_Status* status = TF_NewStatus();
string lib_path =
tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath(
"tensorflow", "c", "experimental", "stream_executor", "test",
"test_pluggable_device.so"));
TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status);
TF_Code code = TF_GetCode(status);
string status_msg(TF_Message(status));
TF_DeleteStatus(status);
ASSERT_EQ(TF_OK, code) << status_msg;
TF_DeletePluggableDeviceLibraryHandle(lib);
#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
}
} // namespace
} // namespace tensorflow

View File

@ -213,7 +213,11 @@ TF_Function* TF_GraphToFunctionWithControlOutputs(
TF_DeleteFunction(tf_function);
return nullptr;
}
tf_function->graph_with_debug_info = &fn_body->graph;
for (const Node* n : fn_body->graph.nodes()) {
tf_function->stack_traces[n->name()] = n->GetStackTrace();
}
return tf_function;
}

View File

@ -157,9 +157,7 @@ struct TF_DeviceList {
struct TF_Function {
tensorflow::FunctionDef fdef;
// Graph with nodes with debug stack traces.
const tensorflow::Graph* graph_with_debug_info = nullptr;
tensorflow::StackTracesMap stack_traces;
};
struct TF_ApiDefMap {

View File

@ -248,6 +248,7 @@ cc_library(
":c_api_unified_internal",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c:tf_tensor",
"//tensorflow/core:framework",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
@ -388,6 +389,7 @@ cc_library(
cc_library(
name = "gradient_checker",
testonly = 1,
srcs = [
"gradient_checker.cc",
],
@ -399,27 +401,11 @@ cc_library(
],
deps = [
":abstract_tensor_handle",
":c_api_experimental",
":c_api_unified_internal",
":gradients_internal",
":gradients_util",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"//tensorflow/c:c_api",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/gradients:nn_grad",
"//tensorflow/c/experimental/ops:array_ops",
":unified_api_testutil",
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/cc/profiler",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti",
] + if_libtpu(
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
if_true = [],
),
"@com_google_absl//absl/types:span",
],
)
tf_cuda_cc_test(
@ -430,33 +416,19 @@ tf_cuda_cc_test(
],
args = ["--heap_check=local"],
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
tags = tf_cuda_tests_tags() + [
"no_cuda_asan", # b/175330074
],
deps = [
":abstract_tensor_handle",
":c_api_experimental",
":c_api_test_util",
":c_api_unified_internal",
":gradient_checker",
":gradients_internal",
":gradients_util",
":mnist_gradients_testutil",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
":unified_api_testutil",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/gradients:nn_grad",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/experimental/ops",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:tensor_float_32_utils",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
@ -503,6 +475,7 @@ tf_cuda_cc_test(
cc_library(
name = "abstract_tensor_handle",
srcs = ["abstract_tensor_handle.cc"],
hdrs = ["abstract_tensor_handle.h"],
visibility = [
"//tensorflow:internal",

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/abstract_tensor_handle.h"
namespace tensorflow {
std::string AbstractTensorHandle::DebugString() const {
PartialTensorShape shape;
Status s = Shape(&shape);
std::string shape_string;
if (!s.ok()) {
shape_string = "<error computing shape>";
} else {
shape_string = shape.DebugString();
}
return absl::StrCat("TensorHandle(shape=", shape_string,
", dtype=", DataType_Name(DataType()), ")");
}
} // namespace tensorflow

View File

@ -27,7 +27,7 @@ namespace tensorflow {
// execution mode.
class AbstractTensorHandle : public core::RefCounted {
protected:
enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt };
enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt, kCustomDevice };
explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}
virtual ~AbstractTensorHandle() {}
@ -38,6 +38,10 @@ class AbstractTensorHandle : public core::RefCounted {
virtual tensorflow::Status Shape(
tensorflow::PartialTensorShape* shape) const = 0;
// The default debug string includes a shape and dtype. Implementations are
// free to override it with something more informative.
virtual std::string DebugString() const;
AbstractTensorHandleKind getKind() const { return kind_; }
private:

View File

@ -43,6 +43,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
@ -76,6 +77,15 @@ string DeviceName(const tensorflow::Device* d) {
return (d == nullptr) ? "cpu:0" : d->name();
}
// Annotate eager runtime construction context to the given `function_def` as
// an attribute.
void AnnotateEagerRuntimeConstructionContext(
tensorflow::FunctionDef& function_def) {
tensorflow::AttrValue value;
SetAttrValue("kEagerRuntime", &value);
(*function_def.mutable_attr())["_construction_context"] = value;
}
} // namespace
extern "C" {
@ -744,13 +754,16 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
return;
}
AnnotateEagerRuntimeConstructionContext(function_def);
status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def);
}
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) {
status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithDebugInfo(
function->fdef, function->graph_with_debug_info);
AnnotateEagerRuntimeConstructionContext(function->fdef);
status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithStackTraces(
function->fdef, function->stack_traces);
}
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,

View File

@ -20,10 +20,11 @@ namespace {
void TestRemoteExecuteSilentCopiesFunc(bool async, bool remote,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs = false) {
bool remote_func_outputs = false,
bool has_packed_input = false) {
return TestRemoteExecuteSilentCopies(async, remote, /*func=*/true,
heavy_load_on_streaming_rpc,
remote_func_outputs);
remote_func_outputs, has_packed_input);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
@ -60,5 +61,14 @@ TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesRemoteAsyncPackedInputFuncOrdering) {
// A remote input (packed) may not be ready when we start running a function.
// Test that the function execution should wait until the remote input is
// ready.
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/true,
/*remote_func_outputs*/ true,
/*has_packed_input=*/true);
}
} // namespace

View File

@ -68,7 +68,9 @@ string MatMulFunction(const string& matmul_device) {
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs) {
bool remote_func_outputs,
bool has_packed_input) {
CHECK(!has_packed_input || func);
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
@ -123,6 +125,15 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* packed_handle = nullptr;
if (has_packed_input) {
int num_replicas = 1;
std::vector<TFE_TensorHandle*> packed_handles = {h1_task2};
packed_handle = TFE_CreatePackedTensorHandle(ctx, packed_handles.data(),
&num_replicas, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
TFE_Op* matmul = nullptr;
if (func) {
const string matmul_device = remote_func_outputs ? task2_name : "";
@ -135,7 +146,7 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h0_task0, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h1_task2, status);
TFE_OpAddInput(matmul, has_packed_input ? packed_handle : h1_task2, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else {
// Handles are on task0 (local), and task2, but op is on task1.
@ -194,6 +205,9 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
if (packed_handle) {
TFE_DeleteTensorHandle(packed_handle);
}
TFE_DeleteTensorHandle(h1_task2);
TFE_DeleteTensorHandle(retvals[0]);
for (auto* h : handles_task0) {

View File

@ -21,6 +21,7 @@ limitations under the License.
// is not ready when we start running an op or a function.
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs = false);
bool remote_func_outputs = false,
bool has_packed_input = false);
#endif // TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_

View File

@ -18,18 +18,8 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace gradients {
@ -45,16 +35,6 @@ void Range(vector<int>* data, int start, int end, int step = 1) {
}
}
// Returns AbstractTensorHandlePtr containing [0, ..., n-1].
AbstractTensorHandlePtr GetRangeTensorHandleUtil(AbstractContext* ctx, int n) {
vector<int> vals(n);
int64_t vals_shape[] = {n};
Range(&vals, 0, n);
AbstractTensorHandlePtr r =
GetTensorHandleUtilInt(ctx, vals.data(), vals_shape, 1);
return r;
}
// Fills out_dims with the dimensions of the given tensor.
void GetDims(const TF_Tensor* t, int64_t* out_dims) {
int num_dims = TF_NumDims(t);
@ -69,39 +49,41 @@ Status RunAndMaybeSum(AbstractContext* ctx, Model forward,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
bool use_function) {
GradientRegistry registry;
std::vector<AbstractTensorHandle*> model_outputs(1);
// Run the model.
TF_RETURN_IF_ERROR(RunModel(forward, ctx, inputs,
absl::MakeSpan(model_outputs), use_function,
registry));
AbstractTensorHandle* model_out = model_outputs[0];
absl::MakeSpan(model_outputs), use_function));
AbstractTensorHandlePtr model_out(model_outputs[0]);
TF_Tensor* model_out_tensor;
TF_RETURN_IF_ERROR(GetValue(model_out, &model_out_tensor));
TF_RETURN_IF_ERROR(GetValue(model_out.get(), &model_out_tensor));
int num_dims_out = TF_NumDims(model_out_tensor);
TF_DeleteTensor(model_out_tensor);
// If the output is a scalar, then return the scalar output
if (num_dims_out == 0) {
outputs[0] = model_out;
outputs[0] = model_out.release();
return Status::OK();
}
// Else, reduce sum the output to get a scalar
// Will sum all dimensions, so get a Tensor containing [0,...,num_dims_out-1].
AbstractTensorHandlePtr sum_dims =
GetRangeTensorHandleUtil(ctx, num_dims_out);
AbstractTensorHandlePtr sum_dims;
{
vector<int> vals(num_dims_out);
int64_t vals_shape[] = {num_dims_out};
Range(&vals, 0, num_dims_out);
AbstractTensorHandle* sum_dims_raw = nullptr;
TF_RETURN_IF_ERROR(TestTensorHandleWithDimsInt(ctx, vals.data(), vals_shape,
1, &sum_dims_raw));
sum_dims.reset(sum_dims_raw);
}
// Reduce sum the output on all dimensions.
std::vector<AbstractTensorHandle*> sum_inputs(2);
sum_inputs[0] = model_out;
sum_inputs[1] = sum_dims.get();
TF_RETURN_IF_ERROR(
ops::Sum(ctx, sum_inputs, absl::MakeSpan(model_outputs), "sum_output"));
outputs[0] = model_outputs[0];
ops::Sum(ctx, {model_out.get(), sum_dims.get()}, outputs, "sum_output"));
return Status::OK();
}
// ========================= End Helper Functions==============================
@ -144,61 +126,77 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
// Numerical Grad Check
for (int i = 0; i < num_elems; i++) {
// Get relative epsilon value
float epsilon =
std::abs(theta_data[i] * 1e-4 + 1e-4); // add 1e-4 to prevent div by 0
AbstractTensorHandlePtr two_eps =
GetScalarTensorHandleUtil(ctx, 2 * epsilon);
float epsilon = theta_data[i] == 0 ? 1e-4 : std::abs(theta_data[i] * 1e-4);
AbstractTensorHandlePtr two_eps;
{
AbstractTensorHandle* two_eps_raw = nullptr;
TF_RETURN_IF_ERROR(
TestScalarTensorHandle(ctx, 2 * epsilon, &two_eps_raw));
two_eps.reset(two_eps_raw);
}
// Initialize theta[i] + epsilon.
memcpy(thetaPlus_data.data(), TF_TensorData(theta_tensor),
TF_TensorByteSize(theta_tensor));
thetaPlus_data[i] += epsilon;
AbstractTensorHandlePtr thetaPlus = GetTensorHandleUtilFloat(
ctx, thetaPlus_data.data(), theta_dims.data(), num_dims);
AbstractTensorHandlePtr thetaPlus;
{
AbstractTensorHandle* thetaPlus_raw = nullptr;
TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
ctx, thetaPlus_data.data(), theta_dims.data(), num_dims,
&thetaPlus_raw));
thetaPlus.reset(thetaPlus_raw);
}
// Initialize theta[i] - epsilon.
memcpy(&thetaMinus_data[0], TF_TensorData(theta_tensor),
TF_TensorByteSize(theta_tensor));
thetaMinus_data[i] -= epsilon;
AbstractTensorHandlePtr thetaMinus = GetTensorHandleUtilFloat(
ctx, thetaMinus_data.data(), theta_dims.data(), num_dims);
AbstractTensorHandlePtr thetaMinus;
{
AbstractTensorHandle* thetaMinus_raw = nullptr;
TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
ctx, thetaMinus_data.data(), theta_dims.data(), num_dims,
&thetaMinus_raw));
thetaMinus.reset(thetaMinus_raw);
}
// Get f(theta + eps):
theta_inputs[input_index] = thetaPlus.get();
TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, theta_inputs,
absl::MakeSpan(f_outputs), use_function));
AbstractTensorHandle* fPlus = f_outputs[0];
AbstractTensorHandlePtr fPlus(f_outputs[0]);
// Get f(theta - eps):
theta_inputs[input_index] = thetaMinus.get();
TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, theta_inputs,
absl::MakeSpan(f_outputs), use_function));
AbstractTensorHandle* fMinus = f_outputs[0];
AbstractTensorHandlePtr fMinus(f_outputs[0]);
// Take Difference of both estimates: (f(theta + eps) - f(theta - eps)).
TF_RETURN_IF_ERROR(
ops::Sub(ctx, {fPlus, fMinus}, absl::MakeSpan(f_outputs), "sub_top"));
AbstractTensorHandle* fDiff = f_outputs[0];
TF_RETURN_IF_ERROR(ops::Sub(ctx, {fPlus.get(), fMinus.get()},
absl::MakeSpan(f_outputs), "sub_top"));
AbstractTensorHandlePtr fDiff(f_outputs[0]);
// Calculate using the difference quotient definition:
// (f(theta + eps) - f(theta - eps)) / (2 * eps).
TF_RETURN_IF_ERROR(ops::DivNoNan(ctx, {fDiff, two_eps.get()},
absl::MakeSpan(f_outputs),
"diff_quotient"));
AbstractTensorHandle* diff_quotient = f_outputs[0];
TF_RETURN_IF_ERROR(ops::Div(ctx, {fDiff.get(), two_eps.get()},
absl::MakeSpan(f_outputs), "diff_quotient"));
AbstractTensorHandlePtr diff_quotient(f_outputs[0]);
TF_Tensor* grad_tensor;
TF_RETURN_IF_ERROR(GetValue(diff_quotient, &grad_tensor));
TF_RETURN_IF_ERROR(GetValue(diff_quotient.get(), &grad_tensor));
float grad_data[1];
memcpy(&grad_data[0], TF_TensorData(grad_tensor),
TF_TensorByteSize(grad_tensor));
TF_DeleteTensor(grad_tensor);
dtheta_approx[i] = grad_data[0];
}
// Populate *numerical_grad with the data from dtheta_approx.
TF_RETURN_IF_ERROR(TensorHandleWithDimsFloat(
TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad));
TF_DeleteTensor(theta_tensor);
return Status::OK();
}

View File

@ -12,23 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_GRADIENT_CHECKER_H_
#define TENSORFLOW_C_EAGER_GRADIENT_CHECKER_H_
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/c/eager/unified_api_testutil.h"
namespace tensorflow {
namespace gradients {
@ -51,3 +42,5 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_GRADIENT_CHECKER_H_

View File

@ -15,21 +15,11 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/eager/unified_api_testutil.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -37,6 +27,59 @@ namespace gradients {
namespace internal {
namespace {
using tensorflow::TF_StatusPtr;
void CompareNumericalAndManualGradients(
Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs, int input_index,
float* expected_grad, int num_grad, bool use_function,
double abs_error = 1e-2) {
Status s;
AbstractTensorHandlePtr numerical_grad;
{
AbstractTensorHandle* numerical_grad_raw;
s = CalcNumericalGrad(ctx, model, inputs, input_index, use_function,
&numerical_grad_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
numerical_grad.reset(numerical_grad_raw);
}
TF_Tensor* numerical_tensor;
s = GetValue(numerical_grad.get(), &numerical_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto num_elem_numerical = TF_TensorElementCount(numerical_tensor);
ASSERT_EQ(num_elem_numerical, num_grad);
float* dnumerical = new float[num_elem_numerical]{0};
memcpy(&dnumerical[0], TF_TensorData(numerical_tensor),
TF_TensorByteSize(numerical_tensor));
for (int j = 0; j < num_grad; j++) {
ASSERT_NEAR(dnumerical[j], expected_grad[j], abs_error);
}
delete[] dnumerical;
TF_DeleteTensor(numerical_tensor);
}
Status MatMulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
return ops::MatMul(ctx, inputs, outputs, "MatMul",
/*transpose_a=*/false,
/*transpose_b=*/false);
}
Status MulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
return ops::Mul(ctx, inputs, outputs, "Mul");
}
// TODO(vnvo2409): Add more tests from `python/ops/gradient_checker_v2_test.py`.
// These tests should not be confused with `[*]_grad_test` which compare the
// result of `gradient_checker` and `[*]_grad`. The tests here test the
// functionality of `gradient_checker` by comparing the result with expected
// manual user-provided gradients.
class GradientCheckerTest
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
@ -45,84 +88,56 @@ class GradientCheckerTest
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx_.reset(ctx_raw);
}
}
AbstractContextPtr ctx_;
public:
bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
bool UseFunction() const { return std::get<2>(GetParam()); }
};
Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
TF_RETURN_IF_ERROR(
registry->Register("SparseSoftmaxCrossEntropyWithLogits",
SparseSoftmaxCrossEntropyWithLogitsRegisterer));
return Status::OK();
}
TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
// Computing numerical gradients with TensorFloat-32 is numerically unstable
enable_tensor_float_32_execution(false);
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
TEST_P(GradientCheckerTest, TestMatMul) {
float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t A_dims[] = {2, 2};
AbstractTensorHandlePtr A;
{
AbstractTensorHandle* A_raw;
Status s =
TestTensorHandleWithDimsFloat(ctx_.get(), A_vals, A_dims, 2, &A_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
A.reset(A_raw);
}
float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f};
int64_t B_dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr A =
GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
AbstractTensorHandlePtr B =
GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);
std::vector<AbstractTensorHandle*> inputs;
inputs.push_back(A.get());
inputs.push_back(B.get());
AbstractTensorHandle* grad_approx;
Status s = CalcNumericalGrad(
ctx.get(), MatMulModel, absl::MakeSpan(inputs), /*input_index=*/0,
/*use_function=*/!std::get<2>(GetParam()), &grad_approx);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* gt;
s = GetValue(grad_approx, &gt);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt));
AbstractTensorHandlePtr B;
{
AbstractTensorHandle* B_raw;
Status s =
TestTensorHandleWithDimsFloat(ctx_.get(), B_vals, B_dims, 2, &B_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
B.reset(B_raw);
}
float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f};
float tolerance = 1e-2;
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(expected_dA[j], result_data[j], tolerance);
}
TF_DeleteTensor(gt);
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndManualGradients(
MatMulModel, ctx_.get(), {A.get(), B.get()}, 0, expected_dA, 4,
UseFunction()));
}
TEST_P(GradientCheckerTest, TestGradCheckMul) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
TEST_P(GradientCheckerTest, TestMul) {
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = ScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
@ -130,124 +145,15 @@ TEST_P(GradientCheckerTest, TestGradCheckMul) {
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
Status s = ScalarTensorHandle(ctx.get(), 7.0f, &y_raw);
Status s = TestScalarTensorHandle(ctx_.get(), 7.0f, &y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
y.reset(y_raw);
}
// Will perform z = x*y.
// dz/dx = y
std::vector<AbstractTensorHandle*> inputs;
inputs.push_back(x.get());
inputs.push_back(y.get());
AbstractTensorHandle* g;
Status s = CalcNumericalGrad(ctx.get(), MulModel, absl::MakeSpan(inputs),
/*input_index=*/0,
/*use_function=*/!std::get<2>(GetParam()), &g);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* gt;
s = GetValue(g, &gt);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[1] = {0};
memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt));
ASSERT_NEAR(result_data[0], 7.0f, /*abs_error=*/1e-2);
TF_DeleteTensor(gt);
}
TEST_P(GradientCheckerTest, TestGradCheckSoftmax) {
bool use_function = !std::get<2>(GetParam());
if (use_function) {
// TODO(b/168850692): Enable this.
GTEST_SKIP() << "Can't take gradient of "
"SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
}
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
/** Test to show how to use this API with analytical gradients:
*
* We have `SoftmaxLossGradModel`, which is a wrapper for the
* Softmax analytical gradient found in c/experimental/nn_grads.
*
* We will use the GradientChecker by applying finite differences
* to the forward pass wrapped in `SoftmaxModel` and verify that
* both the analytical and numerical gradients are relatively
* close.
*
*/
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = scores
float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, 1.0f};
int64_t X_dims[] = {3, 3};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
// y = labels
int y_vals[] = {1, 0, 1};
int64_t y_dims[] = {3};
num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
AbstractTensorHandlePtr y =
GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
std::vector<AbstractTensorHandle*> inputs;
inputs.push_back(X.get());
inputs.push_back(y.get());
// Run analytical gradient and get its data.
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(SoftmaxLossGradModel, ctx.get(), absl::MakeSpan(inputs),
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* dX_tensor;
s = GetValue(outputs[0], &dX_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float danalytical[9] = {0}; // Contains data from analytical gradient.
memcpy(&danalytical[0], TF_TensorData(dX_tensor),
TF_TensorByteSize(dX_tensor));
// Run numerical gradient approximation using the GradientChecker API.
AbstractTensorHandle* g; // Will contain numerical approximation data.
s = CalcNumericalGrad(ctx.get(), SoftmaxModel, absl::MakeSpan(inputs),
/*input_index=*/0,
/*use_function=*/!std::get<2>(GetParam()), &g);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* gt;
s = GetValue(g, &gt);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float dnumerical[9] = {0};
memcpy(&dnumerical[0], TF_TensorData(gt), TF_TensorByteSize(gt));
// Now compare the two implementations:
for (int j = 0; j < 9; j++) {
ASSERT_NEAR(dnumerical[j], danalytical[j], /*abs_error=*/1e-2);
}
// Only Unref() first output as 2nd is nullptr grad for labels
outputs[0]->Unref();
TF_DeleteTensor(dX_tensor);
TF_DeleteTensor(gt);
float expected_dx[1] = {7.0f};
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndManualGradients(
MulModel, ctx_.get(), {x.get(), y.get()}, 0, expected_dx, 1,
UseFunction()));
}
#ifdef PLATFORM_GOOGLE
@ -255,13 +161,13 @@ INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, GradientCheckerTest,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
/*use_function*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, GradientCheckerTest,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
/*use_function*/ ::testing::Values(true, false)));
#endif
} // namespace
} // namespace internal

View File

@ -111,11 +111,11 @@ class ImmediateExecutionContext : public AbstractContext {
// already exists.
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
// Same as `AddFunctionDef`, and additionally saves a pointer to the Graph
// which has nodes containing stack traces for the nodes in `fdef`. Assumes
// `graph` is alive while the function is alive.
virtual Status AddFunctionDefWithDebugInfo(const FunctionDef& fdef,
const Graph* graph) = 0;
// Same as `AddFunctionDef`, but additionally saves the `stack_traces` under
// the key of the function definition name (to be retrieved during function
// instantiation).
virtual Status AddFunctionDefWithStackTraces(
const FunctionDef& fdef, const StackTracesMap& stack_traces) = 0;
// Find and return an added function by its name.
virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;

View File

@ -395,80 +395,6 @@ TEST_P(CppGradients, TestReluGrad) {
TF_DeleteTensor(dX_tensor);
}
TEST_P(CppGradients, TestSoftmaxLossGrad) {
bool use_function = !std::get<2>(GetParam());
if (use_function) {
// TODO(b/168850692): Enable this.
GTEST_SKIP() << "Can't take gradient of "
"SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
}
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = scores
float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f};
int64_t X_dims[] = {3, 3};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
// y = labels
int y_vals[] = {1, 0, 1};
int64_t y_dims[] = {3};
num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
AbstractTensorHandlePtr y =
GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
/* Pseudo-code:
*
* tape.watch(X)
* tape.watch(labels)
* loss = SoftmaxLoss(X, labels)
* outputs = tape.gradient(loss, [X, labels])
*
*
*/
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(SoftmaxLossGradModel, ctx.get(), {X.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* dX_tensor;
s = GetValue(outputs[0], &dX_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[9] = {0};
memcpy(&result_data[0], TF_TensorData(dX_tensor),
TF_TensorByteSize(dX_tensor));
float expected_dX[9] = {0.090f, -0.7553f, 0.6652f, -0.9099f, 0.2447f,
0.6652f, 0.8437f, -0.8858f, 0.0420f};
float tolerance = 1e-3;
for (int j = 0; j < 9; j++) {
ASSERT_NEAR(result_data[j], expected_dX[j], tolerance);
}
// Only Unref() first output as 2nd is nullptr grad for labels
outputs[0]->Unref();
TF_DeleteTensor(dX_tensor);
}
TEST_P(CppGradients, TestMNISTGrad) {
bool use_function = !std::get<2>(GetParam());
if (use_function) {

View File

@ -31,7 +31,6 @@ limitations under the License.
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
namespace tensorflow {
namespace gradients {
namespace internal {
@ -184,27 +183,6 @@ Status ReluGradModel(AbstractContext* ctx,
return Status::OK();
}
Status SoftmaxLossGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
auto tape = new Tape(/*persistent=*/false);
tape->Watch(inputs[0]); // Watch scores.
tape->Watch(inputs[1]); // Watch labels.
vector<AbstractTensorHandle*> sm_outputs(2);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
tape_ctx.get(), inputs, absl::MakeSpan(sm_outputs), "softmax0"));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx,
/*targets=*/sm_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
delete tape;
return Status::OK();
}
Status MNISTGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
@ -283,14 +261,6 @@ Status MulModel(AbstractContext* ctx,
"mul0"); // Compute x*y
}
Status SoftmaxModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
return ops::SparseSoftmaxCrossEntropyWithLogits(ctx, inputs, outputs,
"sm_loss");
}
// ============================= End Models ================================
} // namespace internal

View File

@ -29,7 +29,6 @@ limitations under the License.
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
namespace gradients {
namespace internal {
@ -68,12 +67,6 @@ Status ReluGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Test Model to verify SoftmaxGrad functionality
Status SoftmaxLossGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Test Model to verify Multi-grad functionality for MNIST
Status MNISTGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
@ -96,11 +89,6 @@ Status MulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
Status SoftmaxModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@ -119,7 +119,7 @@ TEST_P(UnifiedAPI, TestTensorShape2x4) {
{
AbstractTensorHandle* x_raw = nullptr;
float data[] = {0., 0., 0., 0., 0., 0., 0., 0};
int64 dim_sizes[] = {2, 4};
int64_t dim_sizes[] = {2, 4};
Status s =
TestTensorHandleWithDimsFloat(ctx.get(), data, dim_sizes, 2, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();

View File

@ -144,18 +144,43 @@ Status TestScalarTensorHandle(AbstractContext* ctx, float value,
}
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
int64* dims, int num_dims,
int64_t* dims, int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager = TestTensorHandleWithDimsFloat(
eager_ctx, data, reinterpret_cast<int64_t*>(dims), num_dims);
TFE_TensorHandle* input_eager =
TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int* data,
int64_t* dims, int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager =
TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
return StatusFromTF_Status(status.get());
}
} // namespace tensorflow

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
@ -54,8 +55,16 @@ Status TestScalarTensorHandle(AbstractContext* ctx, float value,
// Get a Matrix TensorHandle with given float values and dimensions.
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
int64* dims, int num_dims,
int64_t* dims, int num_dims,
AbstractTensorHandle** tensor);
// Get a TensorHandle with given int values and dimensions
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int* data,
int64_t* dims, int num_dims,
AbstractTensorHandle** tensor);
// Places data from `t` into *result_tensor.
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_

View File

@ -81,7 +81,7 @@ void ParseGCSPath(const std::string& fname, bool object_empty_ok,
return;
}
size_t bucket_end = fname.find("/", scheme_end + 1);
size_t bucket_end = fname.find('/', scheme_end + 1);
if (bucket_end == std::string::npos) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't contain a bucket name.");

View File

@ -38,7 +38,7 @@ void ParseHadoopPath(const std::string& fname, std::string* scheme,
size_t scheme_end = fname.find("://") + 2;
// We don't want `://` in scheme.
*scheme = fname.substr(0, scheme_end - 2);
size_t nn_end = fname.find("/", scheme_end + 1);
size_t nn_end = fname.find('/', scheme_end + 1);
if (nn_end == std::string::npos) {
*namenode = fname.substr(scheme_end + 1);
*path = "";

View File

@ -4,6 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
# buildifier: disable=same-origin-load
load(
"//tensorflow:tensorflow.bzl",
"if_libtpu",
"tf_cuda_cc_test",
)
load(
@ -165,7 +166,7 @@ cc_library(
],
deps = [
"//tensorflow/c/eager:gradient_checker",
"//tensorflow/c/eager:gradients_util",
"//tensorflow/c/eager:unified_api_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
@ -183,9 +184,14 @@ tf_cuda_cc_test(
deps = [
":grad_test_helper",
":nn_grad",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:c_api_test_util",
"//tensorflow/c/experimental/gradients/tape:tape_context",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
] + if_libtpu(
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
if_true = [],
),
)

View File

@ -86,16 +86,6 @@ Status ExpWithPassThroughGrad(AbstractContext* ctx,
return Status::OK();
}
Status getValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
return Status::OK();
}
TEST_P(CustomGradientTest, ExpWithPassThroughGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -128,7 +118,7 @@ TEST_P(CustomGradientTest, ExpWithPassThroughGrad) {
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
s = GetValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);

View File

@ -24,24 +24,28 @@ namespace internal {
void CompareNumericalAndAutodiffGradients(
Model model, Model grad_model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs, bool use_function,
const GradientRegistry& registry, double abs_error) {
double abs_error) {
auto num_inputs = inputs.size();
std::vector<AbstractTensorHandle*> outputs(num_inputs);
auto s = RunModel(grad_model, ctx, inputs, absl::MakeSpan(outputs),
/*use_function=*/use_function, registry);
/*use_function=*/use_function);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
for (int i = 0; i < num_inputs; ++i) {
if (!outputs[i]) continue;
AbstractTensorHandle* g; // Will contain numerical approximation data.
s = CalcNumericalGrad(ctx, model, inputs,
/*input_index=*/i,
/*use_function=*/use_function, &g);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
AbstractTensorHandlePtr numerical_grad;
{
AbstractTensorHandle* numerical_grad_raw;
s = CalcNumericalGrad(ctx, model, inputs,
/*input_index=*/i, use_function,
&numerical_grad_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
numerical_grad.reset(numerical_grad_raw);
}
TF_Tensor* numerical_tensor;
s = GetValue(g, &numerical_tensor);
s = GetValue(numerical_grad.get(), &numerical_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto num_elem_numerical = TF_TensorElementCount(numerical_tensor);

View File

@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/eager/unified_api_testutil.h"
namespace tensorflow {
namespace gradients {
@ -24,7 +24,7 @@ namespace internal {
void CompareNumericalAndAutodiffGradients(
Model model, Model grad_model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs, bool use_function,
const GradientRegistry& registry, double abs_error = 1e-2);
double abs_error = 1e-2);
} // namespace internal
} // namespace gradients

View File

@ -15,8 +15,11 @@ limitations under the License.
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/unified_api_testutil.h"
#include "tensorflow/c/experimental/gradients/grad_test_helper.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -26,17 +29,60 @@ namespace {
using tensorflow::TF_StatusPtr;
Status SparseSoftmaxCrossEntropyWithLogitsModel(
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
std::vector<AbstractTensorHandle*> temp_outputs(2);
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
ctx, inputs, absl::MakeSpan(temp_outputs),
"SparseSoftmaxCrossEntropyWithLogits"));
// `gradient_checker` only works with models that return a single tensor.
// Although `ops::SparseSoftmaxCrossEntropyWithLogits` returns 2 tensors, the
// second tensor isn't needed for computing the gradient, so we can safely
// drop it.
outputs[0] = temp_outputs[0];
temp_outputs[1]->Unref();
return Status::OK();
}
Status SparseSoftmaxCrossEntropyWithLogitsGradModel(
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(
registry.Register("SparseSoftmaxCrossEntropyWithLogits",
SparseSoftmaxCrossEntropyWithLogitsRegisterer));
Tape tape(/*persistent=*/false);
tape.Watch(inputs[0]); // Watch score.
tape.Watch(inputs[1]); // Watch label.
std::vector<AbstractTensorHandle*> temp_outputs(2);
AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs),
"SparseSoftmaxCrossEntropyWithLogitsGrad"));
TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto temp_output : temp_outputs) {
temp_output->Unref();
}
return Status::OK();
}
Status BiasAddModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
absl::Span<AbstractTensorHandle*> outputs) {
return ops::BiasAdd(ctx, inputs, outputs, "BiasAdd");
}
Status BiasAddGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(registry.Register("BiasAdd", BiasAddRegisterer));
Tape tape(/*persistent=*/false);
tape.Watch(inputs[0]); // Watch A.
tape.Watch(inputs[1]); // Watch Bias.
@ -54,11 +100,6 @@ Status BiasAddGradModel(AbstractContext* ctx,
return Status::OK();
}
Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("BiasAdd", BiasAddRegisterer));
return Status::OK();
}
class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
@ -66,7 +107,7 @@ class CppGradients
TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
{
AbstractContext* ctx_raw = nullptr;
@ -75,12 +116,8 @@ class CppGradients
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx_.reset(ctx_raw);
}
s = RegisterGradients(&registry_);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}
GradientRegistry registry_;
AbstractContextPtr ctx_;
public:
@ -88,6 +125,43 @@ class CppGradients
bool UseFunction() const { return std::get<2>(GetParam()); }
};
TEST_P(CppGradients, TestSparseSoftmaxCrossEntropyWithLogitsGrad) {
if (UseFunction()) {
// TODO(b/168850692): Enable this.
GTEST_SKIP() << "Can't take gradient of "
"SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
}
// Score
float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f};
int64_t X_dims[] = {3, 3};
AbstractTensorHandlePtr X;
{
AbstractTensorHandle* X_raw;
Status s =
TestTensorHandleWithDimsFloat(ctx_.get(), X_vals, X_dims, 2, &X_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
X.reset(X_raw);
}
// Label
int Y_vals[] = {1, 0, 1};
int64_t Y_dims[] = {3};
AbstractTensorHandlePtr Y;
{
AbstractTensorHandle* Y_raw;
Status s =
TestTensorHandleWithDimsInt(ctx_.get(), Y_vals, Y_dims, 1, &Y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
Y.reset(Y_raw);
}
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
SparseSoftmaxCrossEntropyWithLogitsModel,
SparseSoftmaxCrossEntropyWithLogitsGradModel, ctx_.get(),
{X.get(), Y.get()},
/*use_function=*/UseFunction()));
}
TEST_P(CppGradients, TestBiasAddGrad) {
if (UseFunction() && UseMlir()) {
GTEST_SKIP() << "SetAttrString has not been implemented yet.\n";
@ -96,19 +170,29 @@ TEST_P(CppGradients, TestBiasAddGrad) {
// A
float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t A_dims[] = {2, 2};
AbstractTensorHandlePtr A =
GetTensorHandleUtilFloat(ctx_.get(), A_vals, A_dims, 2);
AbstractTensorHandlePtr A;
{
AbstractTensorHandle* A_raw;
Status s =
TestTensorHandleWithDimsFloat(ctx_.get(), A_vals, A_dims, 2, &A_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
A.reset(A_raw);
}
// Bias
float Bias_vals[] = {2.0f, 3.0f};
int64_t Bias_dims[] = {2};
AbstractTensorHandlePtr Bias =
GetTensorHandleUtilFloat(ctx_.get(), Bias_vals, Bias_dims, 1);
std::vector<AbstractTensorHandle*> inputs{A.get(), Bias.get()};
AbstractTensorHandlePtr Bias;
{
AbstractTensorHandle* Bias_raw;
Status s = TestTensorHandleWithDimsFloat(ctx_.get(), Bias_vals, Bias_dims,
1, &Bias_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
Bias.reset(Bias_raw);
}
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
BiasAddModel, BiasAddGradModel, ctx_.get(), {A.get(), Bias.get()},
/*use_function=*/UseFunction(), registry_));
/*use_function=*/UseFunction()));
}
#ifdef PLATFORM_GOOGLE

View File

@ -0,0 +1,17 @@
# Description:
# Test for stream_executor.
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_shared_object",
)
package(
licenses = ["notice"], # Apache 2.0
)
tf_cc_shared_object(
name = "test_pluggable_device.so",
srcs = ["test_pluggable_device.cc"],
visibility = ["//tensorflow/c:__subpackages__"],
deps = ["//tensorflow/c/experimental/stream_executor:stream_executor_hdrs"],
)

View File

@ -0,0 +1,23 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
TF_Status* const status) {
params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
params->platform->name = "GPU";
params->platform->type = "XGPU";
}

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/stream_executor/stream.h"
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
using tensorflow::errors::InvalidArgument;
// This file forms the basis of a stable ABI for third-party kernel
// implementations. It is crucial that changes to this file are made cautiously
// and with a focus on maintaining both source and binary compatibility.
@ -87,9 +88,25 @@ void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name,
TF_SetStatus(status, TF_OK, "");
}
#undef CASE
} // namespace
} // namespace tensorflow
namespace {
const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx,
const char* attr_name,
TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
const tensorflow::AttrValue* attr =
::tensorflow::AttrSlice(cc_ctx->def()).Find(attr_name);
if (attr == nullptr) {
status->status = InvalidArgument("Operation '", cc_ctx->def().name(),
"' has no attr named '", attr_name, "'.");
}
return attr;
}
} // namespace
void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder,
const char* attr_name,
const TF_DataType type,
@ -257,7 +274,81 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
cc_ctx->CtxFailure(s);
}
#define DEFINE_TF_GETATTR(func, c_type, cc_type) \
void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx,
const char* attr_name,
int32_t* list_size,
int32_t* total_size,
TF_Status* status) {
const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);
if (!status->status.ok()) {
*list_size = -1;
*total_size = -1;
return;
}
switch (attr->value_case()) {
#define SINGLE_CASE(kK, attr_type, size_expr) \
case tensorflow::AttrValue::kK: \
*list_size = -1; \
*total_size = size_expr; \
break;
SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
SINGLE_CASE(kI, TF_ATTR_INT, -1);
SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
SINGLE_CASE(kShape, TF_ATTR_SHAPE,
attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
#undef SINGLE_CASE
case tensorflow::AttrValue::kList:
*list_size = 0;
*total_size = -1;
#define LIST_CASE(field, attr_type, ...) \
if (attr->list().field##_size() > 0) { \
*list_size = attr->list().field##_size(); \
__VA_ARGS__; \
break; \
}
LIST_CASE(
s, TF_ATTR_STRING, *total_size = 0;
for (int i = 0; i < attr->list().s_size();
++i) { *total_size += attr->list().s(i).size(); });
LIST_CASE(i, TF_ATTR_INT);
LIST_CASE(f, TF_ATTR_FLOAT);
LIST_CASE(b, TF_ATTR_BOOL);
LIST_CASE(type, TF_ATTR_TYPE);
LIST_CASE(
shape, TF_ATTR_SHAPE, *total_size = 0;
for (int i = 0; i < attr->list().shape_size(); ++i) {
const auto& s = attr->list().shape(i);
*total_size += s.unknown_rank() ? 0 : s.dim_size();
});
LIST_CASE(tensor, TF_ATTR_TENSOR);
LIST_CASE(tensor, TF_ATTR_FUNC);
#undef LIST_CASE
break;
case tensorflow::AttrValue::kPlaceholder:
*list_size = -1;
*total_size = -1;
break;
case tensorflow::AttrValue::kFunc:
*list_size = -1;
*total_size = -1;
break;
case tensorflow::AttrValue::VALUE_NOT_SET:
status->status =
InvalidArgument("Attribute '", attr_name, "' has no value set");
break;
}
}
#define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field) \
void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx, \
const char* attr_name, \
c_type* val, TF_Status* status) { \
@ -269,10 +360,84 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
if (s.ok()) { \
*val = static_cast<c_type>(v); \
} \
} \
void TF_OpKernelConstruction_GetAttr##func##List( \
TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals, \
int max_vals, TF_Status* status) { \
TF_SetStatus(status, TF_OK, ""); \
const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status); \
if (!status->status.ok()) return; \
if (attr->value_case() != tensorflow::AttrValue::kList) { \
status->status = \
InvalidArgument("Value for '", attr_name, "' is not a list."); \
return; \
} \
status->status = \
tensorflow::AttrValueHasType(*attr, "list(" attr_type ")"); \
if (!status->status.ok()) return; \
const auto len = std::min(max_vals, attr->list().list_field##_size()); \
for (int i = 0; i < len; ++i) { \
vals[i] = static_cast<c_type>(attr->list().list_field(i)); \
} \
}
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
DEFINE_TF_GETATTR(Int32, tensorflow::int32, int32_t)
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type", type)
DEFINE_TF_GETATTR(Int32, int32_t, tensorflow::int32, "int", i)
DEFINE_TF_GETATTR(Int64, int64_t, tensorflow::int64, "int", i)
DEFINE_TF_GETATTR(Float, float, float, "float", f)
DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool", b)
void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx,
const char* attr_name, char* value,
size_t max_length,
TF_Status* status) {
std::string v;
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
::tensorflow::Set_TF_Status_from_Status(status, s);
if (!status->status.ok()) return;
if (max_length <= 0) {
return;
}
std::memcpy(value, v.data(), std::min<size_t>(v.length(), max_length));
}
void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx,
const char* attr_name,
char** values, size_t* lengths,
int max_values, void* storage,
size_t storage_size,
TF_Status* status) {
std::vector<std::string> v;
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
::tensorflow::Set_TF_Status_from_Status(status, s);
if (!status->status.ok()) return;
const auto len = std::min(max_values, static_cast<int>(v.size()));
char* p = static_cast<char*>(storage);
for (int i = 0; i < len; ++i) {
const std::string& s = v[i];
values[i] = p;
lengths[i] = s.size();
if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
status->status = InvalidArgument(
"Not enough storage to hold the requested list of strings");
return;
}
memcpy(values[i], s.data(), s.size());
p += s.size();
}
}
bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx,
const char* attr_name, TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
return cc_ctx->HasAttr(attr_name);
}
TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) {
auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx);

View File

@ -184,6 +184,24 @@ TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType(
// Returns the step ID of the given context.
TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx);
// Get the list_size and total_size of the attribute `attr_name` of the kernel
// construction `ctx`.
// list_size - the length of the list, or -1 if the attribute is not a list.
// total_size - total size of the list.
// (1) If attr_type == TF_ATTR_STRING and the attribute is not a list,
//     then total_size is the byte size of the string.
// (2) If attr_type == TF_ATTR_STRING and the attribute is a list,
//     then total_size is the cumulative byte size
//     of all the strings in the list.
// (3) If attr_type == TF_ATTR_SHAPE and the attribute is not a list,
//     then total_size is the number of dimensions
//     of the shape valued attribute, or -1
//     if its rank is unknown.
// (4) If attr_type == TF_ATTR_SHAPE and the attribute is a list,
//     then total_size is the cumulative number
//     of dimensions of all shapes in the list.
// (5) Otherwise, total_size is undefined.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrSize(
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* list_size,
int32_t* total_size, TF_Status* status);
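For orientation only, the following sketch (not part of this change) shows how a kernel's create function might combine this call with one of the typed list getters declared below; the attribute name `values`, the surrounding function, and the include paths are assumptions.
```c++
#include <cstdint>
#include <vector>

#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_status.h"

// Sketch: assumes an op registered with an attribute `values: list(int)`.
void* MyKernelCreate(TF_OpKernelConstruction* ctx) {
  TF_Status* status = TF_NewStatus();
  int32_t list_size = 0;
  int32_t total_size = 0;
  // For a list-valued attribute, list_size is the number of elements.
  TF_OpKernelConstruction_GetAttrSize(ctx, "values", &list_size, &total_size,
                                      status);
  auto* values = new std::vector<int64_t>();
  if (TF_GetCode(status) == TF_OK && list_size > 0) {
    values->resize(list_size);
    TF_OpKernelConstruction_GetAttrInt64List(ctx, "values", values->data(),
                                             list_size, status);
  }
  TF_DeleteStatus(status);
  return values;  // Released in the matching delete function.
}
```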
// Interprets the named kernel construction attribute as a TF_DataType and
// places it into *val. *status is set to TF_OK.
//
@ -202,6 +220,112 @@ TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32(
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val,
TF_Status* status);
// Interprets the named kernel construction attribute as int64_t and
// places it into *val. *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// int64, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64(
TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* val,
TF_Status* status);
// Interprets the named kernel construction attribute as float and
// places it into *val. *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// float, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloat(
TF_OpKernelConstruction* ctx, const char* attr_name, float* val,
TF_Status* status);
// Interprets the named kernel construction attribute as bool and
// places it into *val. *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// bool, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBool(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* val,
TF_Status* status);
// Interprets the named kernel construction attribute as string and
// places it into *val. `val` must
// point to an array of length at least `max_length` (ideally set to
// total_size from TF_OpKernelConstruction_GetAttrSize(ctx,
// attr_name, list_size, total_size)). *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// string, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrString(
TF_OpKernelConstruction* ctx, const char* attr_name, char* val,
size_t max_length, TF_Status* status);
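As a hedged illustration (again not from the diff), a string attribute is usually read by first asking TF_OpKernelConstruction_GetAttrSize for total_size and sizing the buffer from it; the attribute name `label` and the helper function are hypothetical.
```c++
#include <cstdint>
#include <string>

#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_status.h"

// Sketch: reads a hypothetical string attribute `label` into a std::string.
std::string ReadLabelAttr(TF_OpKernelConstruction* ctx, TF_Status* status) {
  int32_t list_size = 0;
  int32_t total_size = 0;
  TF_OpKernelConstruction_GetAttrSize(ctx, "label", &list_size, &total_size,
                                      status);
  if (TF_GetCode(status) != TF_OK || total_size <= 0) return "";
  // For a non-list string attribute, total_size is its byte length.
  std::string value(total_size, '\0');
  TF_OpKernelConstruction_GetAttrString(ctx, "label", &value[0], total_size,
                                        status);
  return value;
}
```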
// Interprets the named kernel construction attribute as a TF_DataType array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTypeList(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as int32_t array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32List(
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as int64_t array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64List(
TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as float array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloatList(
TF_OpKernelConstruction* ctx, const char* attr_name, float* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as bool array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBoolList(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as a string array and
// fills in `vals` and `lengths`, each of which must point to an array of
// length at least `max_values`. *status is set to TF_OK. The elements of
// `vals` will point to addresses in `storage`, which must be at least
// `storage_size` bytes in length. Ideally, `max_values` would be set to
// list_size and `storage` would be at least total_size, obtained from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrStringList(
TF_OpKernelConstruction* ctx, const char* attr_name, char** vals,
size_t* lengths, int max_values, void* storage, size_t storage_size,
TF_Status* status);
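A minimal sketch of how `storage`, `vals`, and `lengths` fit together (the kernel attr test later in this diff exercises the same call); the attribute name `names` and the helper are assumptions.
```c++
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_status.h"

// Sketch: reads a hypothetical `names: list(string)` attribute.
std::vector<std::string> ReadNamesAttr(TF_OpKernelConstruction* ctx,
                                       TF_Status* status) {
  int32_t list_size = 0;
  int32_t total_size = 0;
  TF_OpKernelConstruction_GetAttrSize(ctx, "names", &list_size, &total_size,
                                      status);
  std::vector<std::string> result;
  if (TF_GetCode(status) != TF_OK || list_size <= 0) return result;
  std::vector<char*> vals(list_size);
  std::vector<size_t> lengths(list_size);
  std::vector<char> storage(total_size);  // Backs every returned pointer.
  TF_OpKernelConstruction_GetAttrStringList(ctx, "names", vals.data(),
                                            lengths.data(), list_size,
                                            storage.data(), storage.size(),
                                            status);
  if (TF_GetCode(status) != TF_OK) return result;
  for (int i = 0; i < list_size; ++i) {
    result.emplace_back(vals[i], lengths[i]);  // Pointers index into `storage`.
  }
  return result;
}
```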
// Returns true if the kernel construction has an attribute named `attr_name`.
TF_CAPI_EXPORT extern bool TF_OpKernelConstruction_HasAttr(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Status* status);
// Returns the unique operation name for this OpKernel.
TF_CAPI_EXPORT extern TF_StringView TF_OpKernelConstruction_GetName(
TF_OpKernelConstruction* ctx);

View File

@ -161,6 +161,336 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
ASSERT_TRUE(delete_called);
}
// REGISTER_OP for TF_OpKernelConstruction_GetAttr* test cases.
// Registers two ops, each with a single attribute called 'Attr'.
// The attribute in one op will have the type `type`; in the other
// it will have the type `list(type)`.
#define ATTR_TEST_REGISTER_OP(name, type) \
REGISTER_OP("TestKernelAttr" #name) \
.Attr("Attr: " #type) \
.SetShapeFn(tensorflow::shape_inference::UnknownShape); \
REGISTER_OP("TestKernelAttr" #name "List") \
.Attr("Attr: list(" #type ")") \
.SetShapeFn(tensorflow::shape_inference::UnknownShape)
ATTR_TEST_REGISTER_OP(String, string);
ATTR_TEST_REGISTER_OP(Int, int);
ATTR_TEST_REGISTER_OP(Float, float);
ATTR_TEST_REGISTER_OP(Bool, bool);
ATTR_TEST_REGISTER_OP(Type, type);
#undef ATTR_TEST_REGISTER_OP
// Helper macros for the TF_OpKernelConstruction_GetAttr* tests.
#define EXPECT_TF_SIZE(attr_name, expected_list_size, expected_total_size) \
do { \
int32_t list_size, total_size; \
TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size, \
&total_size, status); \
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); \
EXPECT_EQ(expected_list_size, list_size); \
EXPECT_EQ(expected_total_size, total_size); \
} while (0)
typedef void* (*MyCreateFuncWithAttr)(TF_OpKernelConstruction*);
class TestKernelAttr : public ::testing::Test {
public:
TestKernelAttr() {}
~TestKernelAttr() override {}
std::unique_ptr<OpKernel> GetFakeKernelWithAttr(const char* op_name,
AttrValue v, Status* status) {
NodeDef def;
def.set_op(op_name);
def.set_name("FakeNode");
def.set_device("FakeDevice");
(*def.mutable_attr())["Attr"] = v;
return CreateOpKernel(DeviceType("FakeDevice"), nullptr, nullptr, def, 1,
status);
}
void CreateAndCallKernelWithAttr(MyCreateFuncWithAttr MyCreateFuncAttr,
const char* op_name, AttrValue& v) {
TF_KernelBuilder* builder = TF_NewKernelBuilder(
op_name, "FakeDevice", MyCreateFuncAttr, &MyComputeFunc, &MyDeleteFunc);
{
TF_Status* status = TF_NewStatus();
TF_RegisterKernelBuilder("FakeNode", builder, status);
EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_DeleteStatus(status);
}
Status status;
std::unique_ptr<OpKernel> kernel =
GetFakeKernelWithAttr(op_name, v, &status);
TF_EXPECT_OK(status);
ASSERT_NE(nullptr, kernel.get());
kernel->Compute(nullptr);
ASSERT_TRUE(delete_called);
}
};
TEST_F(TestKernelAttr, String) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
std::unique_ptr<char[]> val(new char[5]);
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ 5);
TF_OpKernelConstruction_GetAttrString(ctx, "Attr", val.get(),
/*max_length*/ 5, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ("bunny", string(static_cast<const char*>(val.get()), 5));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_s("bunny");
CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrString", v);
}
TEST_F(TestKernelAttr, StringList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
std::vector<string> list = {"bugs", "bunny", "duck"};
int list_total_size = 0;
for (const auto& s : list) {
list_total_size += s.size();
}
TF_Status* status = TF_NewStatus();
std::unique_ptr<char*[]> values(new char*[list.size()]);
std::unique_ptr<size_t[]> lens(new size_t[list.size()]);
std::unique_ptr<char[]> storage(new char[list_total_size]);
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list.size(),
/*expected_total_size*/ list_total_size);
TF_OpKernelConstruction_GetAttrStringList(
ctx, "Attr", values.get(), lens.get(), list.size(), storage.get(),
list_total_size, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
for (size_t i = 0; i < list.size(); ++i) {
EXPECT_EQ(list[i].size(), lens[i]) << i;
EXPECT_EQ(list[i], string(static_cast<const char*>(values[i]), lens[i]))
<< i;
}
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
std::string attr_in[] = {"bugs", "bunny", "duck"};
SetAttrValue(gtl::ArraySlice<std::string>(attr_in, 3), &v);
CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrStringList", v);
}
TEST_F(TestKernelAttr, Int) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
int64_t val;
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrInt64(ctx, "Attr", &val, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(1234, val);
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_i(1234);
CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrInt", v);
}
TEST_F(TestKernelAttr, IntList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
const int64_t list[] = {1, 2, 3, 4};
const size_t list_size = TF_ARRAYSIZE(list);
int64_t values[list_size];
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrInt64List(ctx, "Attr", values, list_size,
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_TRUE(
std::equal(std::begin(list), std::end(list), std::begin(values)));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
int64 attr_in[] = {1, 2, 3, 4};
SetAttrValue(gtl::ArraySlice<int64>(attr_in, 4), &v);
CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrIntList", v);
}
TEST_F(TestKernelAttr, Float) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
float val;
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrFloat(ctx, "Attr", &val, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_FLOAT_EQ(2.718, val);
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_f(2.718);
CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrFloat", v);
}
TEST_F(TestKernelAttr, FloatList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
const float list[] = {1.414, 2.718, 3.1415};
const size_t list_size = TF_ARRAYSIZE(list);
float values[list_size];
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrFloatList(ctx, "Attr", values, list_size,
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_TRUE(
std::equal(std::begin(list), std::end(list), std::begin(values)));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
float attr_in[] = {1.414, 2.718, 3.1415};
SetAttrValue(gtl::ArraySlice<float>(attr_in, 3), &v);
CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrFloatList", v);
}
TEST_F(TestKernelAttr, Bool) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
unsigned char val;
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrBool(ctx, "Attr", &val, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(1, val);
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_b(true);
CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrBool", v);
}
TEST_F(TestKernelAttr, BoolList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
const unsigned char list[] = {1, 0, 1, 0};
const size_t list_size = TF_ARRAYSIZE(list);
unsigned char values[list_size];
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrBoolList(ctx, "Attr", values, list_size,
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_TRUE(
std::equal(std::begin(list), std::end(list), std::begin(values)));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
bool attr_in[] = {true, false, true, false};
SetAttrValue(gtl::ArraySlice<bool>(attr_in, 4), &v);
CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrBoolList", v);
}
TEST_F(TestKernelAttr, Type) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
TF_DataType val;
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrType(ctx, "Attr", &val, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(TF_FLOAT, val);
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_type(DT_FLOAT);
CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrType", v);
}
TEST_F(TestKernelAttr, TypeList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
const TF_DataType list[] = {TF_FLOAT, TF_DOUBLE, TF_HALF, TF_COMPLEX128};
const size_t list_size = TF_ARRAYSIZE(list);
TF_DataType values[list_size];
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrTypeList(ctx, "Attr", values, list_size,
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_TRUE(
std::equal(std::begin(list), std::end(list), std::begin(values)));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
DataType attr_in[] = {DT_FLOAT, DT_DOUBLE, DT_HALF, DT_COMPLEX128};
SetAttrValue(gtl::ArraySlice<DataType>(attr_in, 4), &v);
CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrTypeList", v);
}
#undef EXPECT_TF_SIZE
class DummyDevice : public DeviceBase {
public:
explicit DummyDevice(Env* env) : DeviceBase(env) {}

View File

@ -60,7 +60,7 @@ string GetPath(const string& dot_h_fname) {
if (result.size() > sizeof("external/") &&
result.compare(0, sizeof("external/") - 1, "external/") == 0) {
result = result.substr(sizeof("external/") - 1);
pos = result.find("/");
pos = result.find('/');
if (pos != string::npos) {
result = result.substr(pos + 1);
}

View File

@ -184,6 +184,7 @@ XLA_DEVICE_DEPS = [
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla:xla_op_registry",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:client_library",

View File

@ -151,10 +151,11 @@ RecursiveCompilabilityChecker::FindUncompilableNodes(
// not considered uncompilable.
if (node_stack_trace != nullptr) {
for (const auto& frame : *node_stack_trace) {
stack_trace.emplace_back(StackFrameView{frame.name, frame.function_name});
stack_trace.emplace_back(
StackFrameView{frame.name, frame.function_name, frame.n});
}
}
stack_trace.emplace_back(StackFrameView{node.name(), ""});
stack_trace.emplace_back(StackFrameView{node.name(), "", &node});
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
IsCompilableNode(node, lib_runtime, &stack_trace,
@ -173,10 +174,11 @@ RecursiveCompilabilityChecker::FindUncompilableNodes(
std::vector<StackFrameView> stack_trace;
if (node_stack_trace != nullptr) {
for (const auto& frame : *node_stack_trace) {
stack_trace.emplace_back(StackFrameView{frame.name, frame.function_name});
stack_trace.emplace_back(
StackFrameView{frame.name, frame.function_name, frame.n});
}
}
stack_trace.emplace_back(StackFrameView{call_def.name(), ""});
stack_trace.emplace_back(StackFrameView{call_def.name(), "", nullptr});
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
IsCompilableCall(call_def, lib_runtime, &stack_trace,
@ -194,12 +196,11 @@ bool RecursiveCompilabilityChecker::HasXLAKernel(
"SymbolicGradient should be handled by IsCompilableCall().";
return false;
}
if (node.type_string() == "Const") {
// Skip Const op with type DT_STRING, since XLA doesn't support it, but the
// registered Const KernelDef says that it does, to support no-op Assert for
// tfcompile.
const AttrValue* attr = node.attrs().Find("dtype");
if (attr != nullptr && attr->type() == DT_STRING) {
if (!op_filter_.allow_string_consts && attr != nullptr &&
attr->type() == DT_STRING) {
*uncompilable_reason =
"Const op with type DT_STRING is not supported by XLA.";
return false;
@ -359,7 +360,8 @@ bool RecursiveCompilabilityChecker::IsCompilableCall(
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
bool is_compilable = true;
for (const Node* node : fbody->graph->op_nodes()) {
stack_trace->emplace_back(StackFrameView{node->name(), function.name()});
stack_trace->emplace_back(
StackFrameView{node->name(), function.name(), node});
is_compilable &= IsCompilableNode(*node, lib_runtime, stack_trace,
&function, uncompilable_nodes);
stack_trace->pop_back();
@ -583,7 +585,8 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
[](const StackFrameView& stack_element) {
return StackFrame{
std::string(stack_element.name),
std::string(stack_element.function_name)};
std::string(stack_element.function_name),
stack_element.n};
});
node_info.name = std::string(stack_trace.back().name);

View File

@ -62,6 +62,7 @@ class RecursiveCompilabilityChecker {
struct StackFrame {
std::string name;
std::string function_name;
const Node* n = nullptr;
};
// Contains information about uncompilable node inside a function body.
@ -128,6 +129,9 @@ class RecursiveCompilabilityChecker {
// Require the function to be always compilable, regardless whether some
// control flow branches might be dead for a given input.
bool require_always_compilable = false;
// Whether string constants are compilable.
bool allow_string_consts = true;
};
RecursiveCompilabilityChecker(OperationFilter op_filter,
@ -193,6 +197,7 @@ class RecursiveCompilabilityChecker {
struct StackFrameView {
absl::string_view name;
absl::string_view function_name;
const Node* n = nullptr;
};
bool IsCompilableNode(

View File

@ -177,6 +177,7 @@ void AllocateAndParseFlags() {
// bridge, on a per-graph basis).
bool enable_mlir_bridge = false;
bool enable_mlir_bridge_is_explicit = false;
bool mlir_bridge_safe_mode = false;
auto setter_for_jitter_tensor_names = [](string sequence) {
jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
@ -227,7 +228,13 @@ void AllocateAndParseFlags() {
Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge,
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.",
&enable_mlir_bridge_is_explicit)});
&enable_mlir_bridge_is_explicit),
Flag(
"tf_mlir_bridge_safe_mode", &mlir_bridge_safe_mode,
"When tf_mlir_enable_mlir_bridge is true, this field can enable "
"the MLIR bridge's safe mode. When the MLIR bridge is in safe mode, "
"it only runs for graphs that use features MLIR bridge currently "
"supports.")});
AppendMarkForCompilationPassFlagsInternal(flag_list);
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
@ -238,7 +245,9 @@ void AllocateAndParseFlags() {
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
} else if (enable_mlir_bridge) {
mlir_flags->tf_mlir_enable_mlir_bridge =
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
(mlir_bridge_safe_mode)
? ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED
: ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
} else {
mlir_flags->tf_mlir_enable_mlir_bridge =
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;

View File

@ -1199,6 +1199,7 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
RecursiveCompilabilityChecker::OperationFilter filter =
CreateOperationFilter(*registration);
filter.require_always_compilable = true;
filter.allow_string_consts = false;
RecursiveCompilabilityChecker checker(
filter, DeviceType{registration->compilation_device_name});
@ -1207,6 +1208,15 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
continue;
}
if (node->type_string() == "Const") {
// Skip Const op with type DT_STRING, since XLA autoclustering doesn't
// support it.
const AttrValue* attr = node->attrs().Find("dtype");
if (attr != nullptr && attr->type() == DT_STRING) {
continue;
}
}
if (!allowlist.empty() && !allowlist.contains(node->def().op())) {
VLOG(1) << "Rejecting TF operation " << node->def().op()
<< " as it is not listed in --tf_xla_ops_to_cluster.";
@ -2035,6 +2045,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"TensorScatterUpdate",
"TridiagonalSolve",
"TruncatedNormal",
"Unique",
"UpperBound",
"UnsortedSegmentMax",
"UnsortedSegmentMin",

View File

@ -38,7 +38,7 @@ class XlaCpuDeviceFactory : public DeviceFactory {
Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}

View File

@ -96,7 +96,7 @@ Status XlaGpuDeviceFactory::CreateDevices(
std::vector<std::unique_ptr<Device>>* devices) {
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}

View File

@ -115,17 +115,23 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
uncompilable_node_info.emplace_back(info);
}
}
string message = absl::StrCat(
std::string message = absl::StrCat(
"Function invoked by the following node is not compilable: ",
SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n");
absl::StrAppend(&message, "Uncompilable nodes:");
absl::StrAppend(&message, "Uncompilable operations:");
for (const auto& node_info : uncompilable_node_info) {
string node_message = absl::StrCat("\n", node_info.name, ": ",
node_info.uncompilable_reason, "\n",
"\tStacktrace:\n");
for (const auto& stack_frame : node_info.stack_trace) {
absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n",
stack_frame.name, stack_frame.function_name);
std::string node_message = absl::StrCat(
"\n", node_info.name, ": ", node_info.uncompilable_reason, "\n",
"The op is created at:\n");
const Node* n = node_info.stack_trace.back().n;
if (n && n->GetStackTrace()) {
AbstractStackTrace::TracePrintingOptions opts;
opts.show_line_contents = true;
opts.filter_common_prefix = true;
opts.drop_internal_frames = true;
absl::StrAppend(&node_message, n->GetStackTrace()->ToString(opts));
} else {
absl::StrAppend(&node_message, "<Unavailable>\n");
}
absl::StrAppend(&message, node_message);
}

View File

@ -79,6 +79,7 @@ cc_library(
"//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
"//tensorflow/core:lib",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
@ -112,8 +113,8 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes",
"//tensorflow/compiler/mlir/tosa:tf_tosa_passes",
"//tensorflow/compiler/mlir/tosa:tfl_tosa_passes",
"//tensorflow/compiler/mlir/tosa:tf_passes",
"//tensorflow/compiler/mlir/tosa:tfl_passes",
],
)

View File

@ -1,7 +1,273 @@
<!-- Autogenerated by mlir-tblgen; don't manually edit -->
### `-tf-device-constant-sinking`: Sinks constants implicitly captured in a tf_device.cluster region.
This pass sinks constants (`tf.Const` ops) that are implicitly captured by a
`tf_device.cluster` region into that region. Performing this prior to outlining
will reduce the number of arguments of the outlined function.
For example, the following:
```mlir
func @cluster() -> tensor<i32> {
%const = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%cluster = "tf_device.cluster"() ( {
%identity = "tf.Identity"(%const) : (tensor<i32>) -> tensor<i32>
tf_device.return %identity : tensor<i32>
}) : () -> (tensor<i32>)
return %cluster : tensor<i32>
}
```
will be transformed into:
```mlir
func @cluster() -> tensor<i32> {
%cluster = "tf_device.cluster"() ( {
%const = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%identity = "tf.Identity"(%const) : (tensor<i32>) -> tensor<i32>
tf_device.return %identity : tensor<i32>
}) : () -> (tensor<i32>)
return %cluster : tensor<i32>
}
```
### `-tf-executor-graph-pruning`: Prunes unreachable ops in a tf_executor.graph
This pass removes ops from a `tf_executor.graph` that are not transitively, via
data or control dependencies, connected to the associated `tf_executor.fetch`
op. The order of ops will be preserved. Functions named `main` with no
`tf.entry_function` attribute will not be pruned, as such graphs/functions may
have been imported from a V1 TensorFlow graph, where feeds/fetches/targets are
not provided at certain stages of IR transformation (e.g. pre-placement).
For example, the following:
```mlir
func @graph(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%graph = tf_executor.graph {
%transitive_reachable_data:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
%reachable_data:2 = tf_executor.island wraps "tf.Identity"(%transitive_reachable_data#0) : (tensor<i32>) -> tensor<i32>
%unreachable_data:2 = tf_executor.island wraps "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%transitive_reachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
%reachable_control = tf_executor.island(%transitive_reachable_control) wraps "tf.NoOp"() : () -> ()
%unreachable_control = tf_executor.island wraps "tf.NoOp"() : () -> tensor<i32>
tf_executor.fetch %reachable_data#0, %reachable_control : tensor<i32>, !tf_executor.control
}
return %graph : tensor<i32>
}
```
will be transformed into:
```mlir
func @graph(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%graph = tf_executor.graph {
%transitive_reachable_data:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
%reachable_data:2 = tf_executor.island wraps "tf.Identity"(%transitive_reachable_data#0) : (tensor<i32>) -> tensor<i32>
%transitive_reachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
%reachable_control = tf_executor.island(%transitive_reachable_control) wraps "tf.NoOp"() : () -> ()
tf_executor.fetch %reachable_data#0, %reachable_control : tensor<i32>, !tf_executor.control
}
return %graph : tensor<i32>
}
```
### `-tf-executor-to-functional-conversion`: Lifts tf_executor.island inner ops from a tf_executor.graph
This pass converts tf_executor.graphs consisting of only tf_executor.islands and
a tf_executor.fetch into a sea of nodes consisting of TensorFlow Dialect ops by
lifting such ops out of a tf_executor.graph's tf_executor.islands. If V1 control
flow ops are present in a tf_executor.graph, an error will be returned.
For example, the following:
```mlir
func @my_fn(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
%graph_results:2 = tf_executor.graph {
%island_0_result, %island_0_control = tf_executor.island {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
tf_executor.yield %identity : tensor<i32>
}
%island_1_result, %island_1_control = tf_executor.island {
%identity_n:2 = "tf.IdentityN"(%arg1, %island_0_result) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
tf_executor.yield %identity_n#0
}
tf_executor.fetch %island_0_result, %island_1_result : tensor<i32>, tensor<i32>
}
return %graph_results#0, %graph_results#1 : tensor<i32>, tensor<i32>
}
```
will be transformed into:
```mlir
func @my_fn(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
%identity_n:2 = "tf.IdentityN"(%arg1, %identity) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
return %identity, %identity_n#0 : tensor<i32>, tensor<i32>
}
```
### `-tf-mark-ops-for-outside-compilation`: Marks ops in device cluster for outside compilation if they are unsupported on device.
This pass marks unsupported ops in a device cluster with the
`_xla_outside_compilation` attribute so that those operations will run on the
host instead of the device. Unsupported ops are ops that cannot be code
generated to run on the device for the cluster, including:
1. String operations on TPUs.
2. Operations that don't have a kernel defined for the device.
This pass is conservative in that it marks for outside compilation every op
that cannot be compiled for the device. Exceptions are made for ops that will
be rewritten or decomposed before compiling on device.
For example, tf_device.cluster op with an unsupported op, tf.UnsupportedOp:
```mlir
func @unsupported_op() -> tensor<i32> {
%0 = "tf_device.cluster"() ( {
%1 = "tf.UnsupportedOp"() : () -> tensor<i32>
%2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
tf_device.return %2 : tensor<i32>
}) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<i32>
return %0 : tensor<i32>
}
```
will mark tf.UnsupportedOp with `_xla_outside_compilation` attribute:
```mlir
func @unsupported_op() -> tensor<i32> {
%0 = "tf_device.cluster"() ( {
%1 = "tf.UnsupportedOp"() {_xla_outside_compilation = "auto0"} : () -> tensor<i32>
%2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
tf_device.return %2 : tensor<i32>
}) {allow_soft_placement = true, device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor<i32>
return %0 : tensor<i32>
}
```
### `-tf-shape-inference`: Simple Shape Inference on TensorFlow Dialect
#### Options
```
-max-iterations : Maximum shape inference iterations
```
### `-tf-tpu-cluster-formation`: Forms clusters from operations assigned to the same TPU computation
TPU computations from the frontend are composed of a `tf.TPUReplicateMetadata`
op, a subgraph of ops (TensorFlow Dialect) each with a matching `_tpu_replicate`
attribute relative to the associated `tf.TPUReplicateMetadata` op, and
optionally `tf.TPUReplicatedInput` and `tf.TPUReplicatedOutput` ops feeding in
inputs and outputs to and from a replicated TPU computation. The number of times
a TPU computation is replicated is defined in the `tf.TPUReplicateMetadata` op
(`num_replicas` attribute) and operand and result sizes of
`tf.TPUReplicatedInput` and `tf.TPUReplicatedOutput` respectively must match,
excluding packed tensors. It is also assumed that ops of the same TPU
computation do not have ops outside of the TPU computation that are both inputs
and outputs to the same TPU computation.
This pass takes the TPU computation subgraph, moves it into a
`tf_device.cluster`, and copies over attributes from the associated
`tf.TPUReplicateMetadata` op to the newly created `tf_device.cluster`. If the
computation is replicated (`num_replicas` > 1), the `num_replicas` attribute is
not copied over but instead the `tf_device.cluster` is further wrapped with a
`tf_device.replicate`, and associated `tf.TPUReplicatedInput` and
`tf.TPUReplicatedOutput` ops are replaced as the `tf_device.replicate` operands
and results. Otherwise, the single operands and results of the associated
`tf.TPUReplicatedInput` and `tf.TPUReplicatedOutput` ops are simply forwarded to
the `tf_device.cluster`.
For example, the following non replicated computation:
```mlir
func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
// Metadata op for cluster `cluster` with 1 replica, 1 core per replica and
// with topology `<topology>`.
"tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", num_replicas = 1, num_cores_per_replica = 1, topology = "<topology>", device_assignment = [], padding_map = []} : () -> ()
%replicated_input = "tf.TPUReplicatedInput"(%arg0) : (tensor<i32>) -> tensor<i32>
%identity = "tf.Identity"(%replicated_input) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
%replicated_output = "tf.TPUReplicatedOutput"(%identity) : (tensor<i32>) -> tensor<i32>
return %replicated_output : tensor<i32>
}
```
will be transformed into:
```mlir
func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
%cluster = "tf_device.cluster"() ( {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
tf_device.return %identity : tensor<i32>
}) {_tpu_replicate = "cluster", num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
return %cluster : tensor<i32>
}
```
The following replicated computation:
```mlir
func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
"tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", num_replicas = 2, num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> ()
%replicated_input = "tf.TPUReplicatedInput"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%identity = "tf.Identity"(%replicated_input) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
%replicated_output:2 = "tf.TPUReplicatedOutput"(%identity) : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
return %replicated_output#0, %replicated_output#1 : tensor<i32>, tensor<i32>
}
```
will be transformed into:
```mlir
func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
%replicate:2 = tf_device.replicate([%arg0, %arg1] as %replicated_input) {n = 2 : i32} {
%cluster = "tf_device.cluster"() ( {
%identity = "tf.Identity"(%replicated_input) : (tensor<i32>) -> tensor<i32>
tf_device.return %identity : tensor<i32>
}) {_tpu_replicate = "cluster", num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
tf_device.return %cluster : tensor<i32>
}
return %replicate#0, %replicate#1 : tensor<i32>, tensor<i32>
}
```
### `-tf-tpu-extract-outside-compilation`: Extracts TPU outside compilation computation to a separate tf_device.parallel_execute region.
This pass extracts CPU computation clusters annotated with
`_xla_outside_compilation`, which denotes ops that should run on the CPU/host,
from a TPU cluster. Each outside compilation cluster is moved to a
`tf_device.parallel_execute` region, and the TPU cluster is moved to another
region of the same `tf_device.parallel_execute` op. Communication ops between
device and host are added to pass inputs/outputs to/from the outside compiled
region.
For example, the following tf_device.cluster with an op marked for `xla_outside_compilation`:
```mlir
func @outside_compilation() -> tensor<f32> {
%0 = "tf_device.cluster"() ( {
%1 = "tf.Const"() {_xla_outside_compilation = "0", value = dense<1.0> : tensor<f32>} : () -> (tensor<f32>)
%2 = "tf.Identity"(%1) {_xla_outside_compilation = "0"} : (tensor<f32>) -> (tensor<f32>)
%3 = "tf.AddV2"(%1, %2) : (tensor<f32>, tensor<f32>) -> (tensor<f32>)
tf_device.return %3 : tensor<f32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<f32>
return %0 : tensor<f32>
}
```
will become a tf_device.parallel_execute op with a CPU/host region and
a tf_device.cluster with communication ops to send data to/from device/host:
```mlir
func @outside_compilation() -> tensor<f32> {
%0 = "tf_device.parallel_execute"() ( {
"tf_device.launch"() ( {
%1 = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<3x!tf.string>
%2 = "tf._XlaRecvAtHost"(%1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_args"} : (tensor<3x!tf.string>) -> tensor<f32>
%3 = "tf.Identity"(%2) : (tensor<f32>) -> tensor<f32>
"tf._XlaSendFromHost"(%3, %1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_retvals"} : (tensor<f32>, tensor<3x!tf.string>) -> ()
tf_device.return
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
tf_device.return
}, {
%1 = "tf_device.cluster"() ( {
%2 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%3 = "tf._XlaHostComputeMlir"(%2) {recv_key = "host_compute_channel_0_0_retvals", send_key = "host_compute_channel_0_0_args", tpu_core = 0 : i64} : (tensor<f32>) -> tensor<f32>
%4 = "tf.AddV2"(%2, %3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
tf_device.return %4 : tensor<f32>
}) {device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor<f32>
tf_device.return %1 : tensor<f32>
}) : () -> tensor<f32>
return %0 : tensor<f32>
}
```

View File

@ -0,0 +1,15 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
build --cxxopt=-std=c++14
build --host_cxxopt=-std=c++14

View File

@ -1,4 +1,4 @@
build
llvm-project
llvm-build
bazel-*

View File

@ -570,7 +570,7 @@ cc_library(
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:TensorDialect",
],
)
@ -740,6 +740,7 @@ cc_library(
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:StandardOpsTransforms",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
@ -809,11 +810,11 @@ cc_library(
deps = [
":hlo",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
],
alwayslink = 1,
)

View File

@ -0,0 +1,57 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Workspace for MLIR HLO."""
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
LLVM_COMMIT = "<LLVM_COMMIT>"
LLVM_SHA256 = "<LLVM_SHA256>"
LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT)
http_archive(
name = "llvm-bazel",
strip_prefix = "llvm-bazel-{tag}/llvm-bazel".format(tag = LLVM_BAZEL_TAG),
url = "https://github.com/google/llvm-bazel/archive/{tag}.tar.gz".format(tag = LLVM_BAZEL_TAG),
)
load("@llvm-bazel//:terminfo.bzl", "llvm_terminfo_disable")
load("@llvm-bazel//:zlib.bzl", "llvm_zlib_disable")
load("@llvm-bazel//:configure.bzl", "llvm_configure")
http_archive(
name = "llvm-project-raw",
build_file_content = "#empty",
sha256 = LLVM_SHA256,
strip_prefix = "llvm-project-{commit}".format(commit = LLVM_COMMIT),
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
"https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
],
)
llvm_terminfo_disable(
name = "llvm_terminfo",
)
llvm_zlib_disable(
name = "llvm_zlib",
)
llvm_configure(
name = "llvm-project",
src_path = ".",
src_workspace = "@llvm-project-raw//:WORKSPACE",
)

View File

@ -18,12 +18,12 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

View File

@ -21,13 +21,13 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

View File

@ -146,10 +146,9 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
// Abs supports complex to real, so element type is not guaranteed to match.
def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
[NoSideEffect, SameOperandsAndResultShape],
[NoSideEffect, SameOperandsAndResultShape,
DeclareOpInterfaceMethods<InferTypeOpInterface>],
TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp {
let builders = [
OpBuilderDAG<(ins "Value":$operand)>];
}
def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt",
@ -902,6 +901,7 @@ def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
ConvolutionAttributes.attributes);
let results = (outs HLO_Tensor);
let hasCustomHLOConverter = 1;
}
def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CopyOp {

View File

@ -958,6 +958,17 @@ def HLO_PrecisionConfigAttr:
OptionalAttr<
TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;
def BoolElementsAttr :
ElementsAttrBase<
And<[CPred<"$_self.isa<::mlir::DenseIntOrFPElementsAttr>()">,
CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">]>,
"constant boolean vector/tensor attribute"> {
let storageType = [{ ::mlir::DenseElementsAttr }];
let returnType = [{ ::mlir::DenseElementsAttr }];
let convertFromStorage = "$_self";
}
def ConvolutionAttributes {
dag attributes = (ins
// Default value: one for each of the spatial dimension.
@ -968,6 +979,8 @@ def ConvolutionAttributes {
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
// Default value: one for each of the spatial dimension.
OptionalAttr<BoolElementsAttr>:$window_reversal,
ConvDimensionNumbers:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
@ -983,6 +996,14 @@ class BASE_HLO_ConvOp {
See https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
}];
code extraClassDeclaration = [{
bool hasWindowReversal() {
auto reversal = window_reversalAttr();
return reversal && llvm::any_of(reversal.getBoolValues(),
[](bool v) { return v; });
}
}];
}
class BASE_HLO_CopyOp {

View File

@ -19,8 +19,8 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
// Order matters, this .inc header is not self-contained, and relies on the

View File

@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
namespace mlir {

View File

@ -25,12 +25,12 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

View File

@ -19,8 +19,8 @@
#define THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
// Order matters, this .inc header is not self-contained, and relies on the

View File

@ -21,7 +21,17 @@ include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td"
def ConvolutionBackendConfigAttr : StructAttr<"ConvolutionBackendConfig",
LHLO_GPU_Dialect, [
StructFieldAttr<"algorithm", I64Attr>,
StructFieldAttr<"tensor_ops_enabled", BoolAttr>]> {
StructFieldAttr<"tensor_ops_enabled", BoolAttr>,
// The following 3 attributes describe the layout as an array of integers
// that list the dimensions in minor-to-major order similar to XLA's layout
// representation. operand_0_layout and operand_1_layout describe the layouts
// of the first 2 operands of the convolution, and result_layout describes
// the layout of the primary output operand of the convolution.
// Note: Not using names like input_layout or filter_layout as `input` may be
// an input operand (for ConvForward) but output for ConvBackward.
StructFieldAttr<"operand_0_layout", I64ArrayAttr>,
StructFieldAttr<"operand_1_layout", I64ArrayAttr>,
StructFieldAttr<"result_layout", I64ArrayAttr>]> {
let description = "GPU Convolution backend configuration";
}
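The minor-to-major layout convention referenced in the comments above can be made concrete with a short sketch. The snippet below is illustrative only (the shape, the NCHW naming, and the helper are made up for this example, not part of the change); it shows how a minor-to-major list turns into element strides.
```python
# Illustrative sketch of XLA-style minor-to-major layouts (not part of this change).
# layout[0] is the fastest-varying (minor-most) dimension, layout[-1] the slowest.

def strides_from_layout(shape, minor_to_major):
    """Compute element strides for `shape` given a minor-to-major layout list."""
    strides = [0] * len(shape)
    running = 1
    for dim in minor_to_major:  # walk from minor-most to major-most dimension
        strides[dim] = running
        running *= shape[dim]
    return strides

# Hypothetical NCHW buffer of shape [8, 32, 16, 16]:
# layout [3, 2, 1, 0] keeps the last dimension contiguous (row-major),
# layout [1, 3, 2, 0] would instead make the channel dimension minor-most.
print(strides_from_layout([8, 32, 16, 16], [3, 2, 1, 0]))  # [8192, 256, 16, 1]
print(strides_from_layout([8, 32, 16, 16], [1, 3, 2, 0]))  # [8192, 1, 512, 32]
```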

View File

@ -22,12 +22,12 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

View File

@ -65,12 +65,16 @@ MAP_HLO_TO_LHLO(MinOp);
MAP_HLO_TO_LHLO(MulOp);
MAP_HLO_TO_LHLO(NegOp);
MAP_HLO_TO_LHLO(NotOp);
MAP_HLO_TO_LHLO(OrOp);
MAP_HLO_TO_LHLO(RealOp);
MAP_HLO_TO_LHLO(ReduceOp);
MAP_HLO_TO_LHLO(ReshapeOp);
MAP_HLO_TO_LHLO(RemOp);
MAP_HLO_TO_LHLO(RsqrtOp);
MAP_HLO_TO_LHLO(SelectOp);
MAP_HLO_TO_LHLO(ShiftLeftOp);
MAP_HLO_TO_LHLO(ShiftRightArithmeticOp);
MAP_HLO_TO_LHLO(ShiftRightLogicalOp);
MAP_HLO_TO_LHLO(SignOp);
MAP_HLO_TO_LHLO(SinOp);
MAP_HLO_TO_LHLO(SliceOp);
@ -78,6 +82,7 @@ MAP_HLO_TO_LHLO(SqrtOp);
MAP_HLO_TO_LHLO(SubOp);
MAP_HLO_TO_LHLO(TanhOp);
MAP_HLO_TO_LHLO(TransposeOp);
MAP_HLO_TO_LHLO(XorOp);
#undef MAP_HLO_TO_LHLO

View File

@ -37,6 +37,7 @@ template <>
struct LhloToScalarOp<lmhlo::AddOp> {
using FOp = ::mlir::AddFOp;
using IOp = ::mlir::AddIOp;
using COp = ::mlir::AddCFOp;
};
template <>
struct LhloToScalarOp<lmhlo::CompareOp> {
@ -62,20 +63,18 @@ template <>
struct LhloToScalarOp<lmhlo::SubOp> {
using FOp = ::mlir::SubFOp;
using IOp = ::mlir::SubIOp;
};
template <typename LhloBinaryOpTy>
struct ScalarOp {
using FOp = typename LhloToScalarOp<LhloBinaryOpTy>::FOp;
using IOp = typename LhloToScalarOp<LhloBinaryOpTy>::IOp;
using COp = ::mlir::SubCFOp;
};
// Alias for the map from LHLO binary op type to STD floating-point op type.
template <typename LhloOp>
using ScalarFOp = typename ScalarOp<LhloOp>::FOp;
using ScalarFOp = typename LhloToScalarOp<LhloOp>::FOp;
// Alias for the map from LHLO binary op type to STD integer op type.
template <typename LhloOp>
using ScalarIOp = typename ScalarOp<LhloOp>::IOp;
using ScalarIOp = typename LhloToScalarOp<LhloOp>::IOp;
// Alias for the map from LHLO binary op type to STD complex op type.
template <typename LhloOp>
using ScalarCOp = typename LhloToScalarOp<LhloOp>::COp;
template <typename... Args>
struct MapLhloOpToStdScalarOpImpl {
@ -143,6 +142,16 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
}
return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::AddOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::AddOp>,
FloatType, ScalarFOp<lmhlo::AddOp>,
ComplexType, ScalarCOp<lmhlo::AddOp>>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
@ -172,7 +181,7 @@ inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
StringRef comparison_direction) {
return llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
.Case("EQ", CmpFPredicate::OEQ)
.Case("NE", CmpFPredicate::ONE)
.Case("NE", CmpFPredicate::UNE)
.Case("GE", CmpFPredicate::OGE)
.Case("GT", CmpFPredicate::OGT)
.Case("LE", CmpFPredicate::OLE)
@ -481,6 +490,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::OrOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::OrOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
ArrayRef<Type> result_types,
@ -498,6 +516,30 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftLeftOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::ShiftLeftOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightArithmeticOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::SignedShiftRightOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightLogicalOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::UnsignedShiftRightOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
ArrayRef<Type> result_types,
@ -506,14 +548,22 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
Type element_type = getElementTypeOrSelf(args.front().getType());
if (auto float_type = element_type.dyn_cast<FloatType>()) {
bool ignored;
APFloat one_apfloat(1.0f);
one_apfloat.convert(float_type.getFloatSemantics(),
APFloat::rmNearestTiesToEven, &ignored);
Value one = b->create<mlir::ConstantFloatOp>(loc, one_apfloat, float_type);
APFloat zero_apfloat(0.0f);
zero_apfloat.convert(float_type.getFloatSemantics(),
APFloat::rmNearestTiesToEven, &ignored);
Value zero =
b->create<mlir::ConstantFloatOp>(loc, zero_apfloat, float_type);
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
one = b->create<::mlir::SplatOp>(loc, vec_type, one);
zero = b->create<::mlir::SplatOp>(loc, vec_type, zero);
}
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
Value ne0_i1 =
b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, args[0], zero);
Value ne0_float = b->create<::mlir::UIToFPOp>(loc, ne0_i1, zero.getType());
Value copy_sign =
b->create<::mlir::CopySignOp>(loc, result_types, ne0_float, args[0]);
auto is_nan =
b->create<::mlir::CmpFOp>(loc, CmpFPredicate::UNO, args[0], args[0]);
return b->create<::mlir::SelectOp>(loc, is_nan, args[0], copy_sign);
} else if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
// sign(x) = x == 0 ? 0 : ((x s>> 31) | 1)
Value zero =
@ -547,6 +597,17 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc,
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::SubOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::SubOp>,
FloatType, ScalarFOp<lmhlo::SubOp>,
ComplexType, ScalarCOp<lmhlo::SubOp>>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
ArrayRef<Type> result_types,
@ -556,6 +617,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::XorOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::XOrOp>{}(
loc, result_types, args, b);
}
} // namespace impl
struct HloOpToStdScalarOp {
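As a side note on the new floating-point branch of the sign lowering above: it maps zero to (signed) zero and forwards NaN unchanged, instead of the old copysign(1.0, x) sequence, which returned ±1 for both. The following plain-Python sketch of the same formula (purely illustrative, not tied to any MLIR API) shows the intended behavior:
```python
import math

def mhlo_sign(x: float) -> float:
    """Mirrors the lowering above: select(isnan(x), x, copysign(float(x != 0), x))."""
    if math.isnan(x):                 # the final select forwards NaN inputs unchanged
        return x
    ne0 = 1.0 if x != 0.0 else 0.0    # cmpf "one" against zero, then uitofp
    return math.copysign(ne0, x)      # keeps the sign bit, so sign(-0.0) == -0.0

assert mhlo_sign(3.5) == 1.0
assert mhlo_sign(-2.0) == -1.0
assert mhlo_sign(0.0) == 0.0
assert math.copysign(1.0, mhlo_sign(-0.0)) == -1.0
assert math.isnan(mhlo_sign(float("nan")))
```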

View File

@ -52,6 +52,11 @@ void PopulateGatherToTorchIndexSelectPatterns(
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
MLIRContext *ctx);
// Collection of rewrite patterns for lowering of dynamic HLOs to LHLO dialect.
void populateDynamicHLOToLHLOConversionPattern(
MLIRContext *context, BufferizeTypeConverter *converter,
OwningRewritePatternList *patterns, bool insert_copy = true);
// Collection of rewrite patterns for lowering of HLO to LHLO dialect.
void populateHLOToLHLOConversionPattern(MLIRContext *context,
BufferizeTypeConverter *converter,

View File

@ -21,8 +21,8 @@ limitations under the License.
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Support/LLVM.h"

View File

@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/BuiltinTypes.h"
namespace mlir {
namespace hlo {

View File

@ -18,8 +18,8 @@ limitations under the License.
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
namespace mlir {
@ -83,6 +83,11 @@ enum ScalarLimit {
// Requires `ty` to be either FloatType or IntegerType.
DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit);
// Given `op_name` from LMHLO, returns the corresponding op name in MHLO.
// Returns empty string if no such op exists.
std::string LmhloToMhloOpName(llvm::StringRef op_name,
mlir::MLIRContext* context);
} // namespace hlo
} // namespace mlir

View File

@ -19,9 +19,9 @@ limitations under the License.
#include "mlir-hlo/utils/broadcast_utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
namespace mlir {

View File

@ -42,6 +42,7 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
@ -51,7 +52,6 @@ limitations under the License.
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
@ -454,18 +454,23 @@ static LogicalResult Verify(DynamicUpdateSliceOp op) {
// AbsOp
//===----------------------------------------------------------------------===//
void AbsOp::build(OpBuilder& builder, OperationState& result, Value operand) {
auto shaped_type = operand.getType().cast<ShapedType>();
Type new_type;
if (!shaped_type.getElementType().isa<ComplexType>()) {
new_type = operand.getType();
} else if (shaped_type.hasRank()) {
new_type = RankedTensorType::get(shaped_type.getShape(), operand.getType());
} else {
new_type = UnrankedTensorType::get(operand.getType());
LogicalResult AbsOp::inferReturnTypes(
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
auto operand_ty = (*operands.begin()).getType().cast<ShapedType>();
Type element_ty = operand_ty.getElementType();
if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
element_ty = complex_ty.getElementType();
}
return AbsOp::build(builder, result, new_type, operand);
Type result_ty;
if (operand_ty.hasRank()) {
result_ty = RankedTensorType::get(operand_ty.getShape(), element_ty);
} else {
result_ty = UnrankedTensorType::get(element_ty);
}
inferredReturnTypes.push_back(result_ty);
return success();
}
//===----------------------------------------------------------------------===//
@ -1879,28 +1884,29 @@ OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
llvm::ArrayRef<int64_t> shape) {
for (int64_t i = index.size() - 1; i >= 0; --i) {
++index[i];
if (index[i] < shape[i]) return true;
if (index[i] < shape[i]) return;
index[i] = 0;
}
return false;
};
// Iterate over all elements of the input tensor and copy it to the correct
// location in the output tensor.
llvm::SmallVector<uint64_t, 8> index(input.getType().getRank(), 0);
do {
uint64_t linear_index = 0;
uint64_t linear_index_multiplyer = 1;
uint64_t num_elements = input.getNumElements();
for (uint64_t operand_idx = 0; operand_idx < num_elements; operand_idx++) {
uint64_t result_idx = 0;
uint64_t idx_multiplyer = 1;
for (int64_t i = index.size() - 1; i >= 0; --i) {
linear_index +=
result_idx +=
(edge_padding_low().getValue<int64_t>({uint64_t(i)}) +
index[i] *
(interior_padding().getValue<int64_t>({uint64_t(i)}) + 1)) *
linear_index_multiplyer;
linear_index_multiplyer *= return_type.getShape()[i];
idx_multiplyer;
idx_multiplyer *= return_type.getDimSize(i);
}
result[linear_index] = input.getValue(index);
} while (next_index(index, input.getType().getShape()));
result[result_idx] = input.getValue(index);
next_index(index, input.getType().getShape());
}
return DenseElementsAttr::get(return_type, result);
}
@ -2332,6 +2338,12 @@ static Attribute FoldSlice(SliceOp* op, I values) {
auto shape = result_type.getShape();
int64_t count = result_type.getNumElements();
if (count == 0) {
return DenseElementsAttr::get<E>(
op->getResult().getType().cast<ShapedType>(),
/*list=*/{});
}
// Compute the striding for each dimension.
llvm::SmallVector<int64_t, 6> sizes;
sizes.reserve(shape.size());

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
@ -39,7 +40,6 @@ limitations under the License.
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
@ -40,7 +41,6 @@ limitations under the License.
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"

View File

@ -112,6 +112,7 @@ add_mlir_library(MhloToStandard
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRTensor
)
add_mlir_library(MhloLhloToLinalg

View File

@ -30,10 +30,10 @@ limitations under the License.
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {

View File

@ -23,10 +23,18 @@ include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.td"
// Unary op patterns.
//===----------------------------------------------------------------------===//
def NonComplexElementType : Type<
CPred<"!$_self.cast<ShapedType>().getElementType().isa<ComplexType>()">,
"Non complex element type">;
// Expand acos to MHLO dialect as follows:
// acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1
// = pi if x == -1
def : Pat<(HLOClient_AcosOp $input),
//
// TODO(hinsu): Support operands with complex element types separately using
// the following formula.
// acos(x) = -(i * log(x + i * sqrt((1 + x) * (1 - x))))
def : Pat<(HLOClient_AcosOp NonComplexElementType:$input),
(HLO_SelectOp
(HLO_CompareOp
$input,
@ -68,7 +76,9 @@ def : Pat<(HLOClient_ConjOp $v),
// Express `sinh` as
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
def : Pat<(HLOClient_SinhOp $input),
// TODO(hinsu): Support operands with complex element types by always using the
// second formula. The compare op below is not legal for complex numbers.
def : Pat<(HLOClient_SinhOp NonComplexElementType:$input),
(HLO_SelectOp
(HLO_CompareOp
(HLO_AbsOp $input),
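The closed forms quoted in the comments above can be sanity-checked numerically. The snippet below is only such a check, with arbitrarily chosen sample points; it is not part of the lowering patterns themselves:
```python
import math

# acos(x) = 2 * atan2(sqrt(1 - x^2), 1 + x)   for x != -1 (the pattern selects pi at x == -1)
for x in (-0.999, -0.5, 0.0, 0.25, 0.9):
    expanded = 2.0 * math.atan2(math.sqrt(1.0 - x * x), 1.0 + x)
    assert math.isclose(expanded, math.acos(x), rel_tol=1e-9)

# sinh(x) = (e^x - e^-x) / 2                      if |x| < 1
#         = e^(x + log(1/2)) - e^(-x + log(1/2))  otherwise
# Both forms are algebraically identical (e^(x + log(1/2)) == e^x / 2); the
# second avoids overflow for inputs near the floating-point limit.
log_half = math.log(0.5)
for x in (-3.0, -0.7, 0.001, 0.5, 2.5):
    direct = (math.exp(x) - math.exp(-x)) / 2.0
    shifted = math.exp(x + log_half) - math.exp(-x + log_half)
    assert math.isclose(direct, math.sinh(x), rel_tol=1e-9)
    assert math.isclose(shifted, math.sinh(x), rel_tol=1e-9)
```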

View File

@ -24,16 +24,17 @@ limitations under the License.
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h"
@ -62,7 +63,7 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
if (shape_element.value() != ShapedType::kDynamicSize) continue;
Value index = rewriter->create<ConstantIndexOp>(loc, shape_element.index());
Value alloc_operand =
rewriter->create<ExtractElementOp>(loc, shape_operand, index);
rewriter->create<tensor::ExtractOp>(loc, shape_operand, index);
if (!alloc_operand.getType().isIndex()) {
alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand,
rewriter->getIndexType());
@ -184,32 +185,64 @@ struct HloToLhloCustomCallOpConverter
// for args and outputs.
const int32_t segments[2] = {static_cast<int32_t>(operands.size()),
static_cast<int32_t>(op->getNumResults())};
lhloOp.setAttr(lhloOp.getOperandSegmentSizeAttr(),
rewriter.getI32VectorAttr(segments));
lhloOp->setAttr(lhloOp.getOperandSegmentSizeAttr(),
rewriter.getI32VectorAttr(segments));
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
return success();
}
};
struct HloToLhloDynamicBroadcastInDimOpConverter
// TODO(pifon): Consider inserting lhlo.copy as in
// HloToLhloDynamicBroadcastInDimOpConverter.
struct HloToLhloDynamicReshapeConverter
: public BaseOpConversion<mhlo::DynamicReshapeOp> {
public:
using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
Type result_type;
if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
result_type =
MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
} else if (auto unranked_type =
op.getType().dyn_cast<UnrankedTensorType>()) {
result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
} else {
return failure();
}
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
op, result_type, adaptor.operand(), adaptor.output_shape());
return success();
}
};
class HloToLhloDynamicBroadcastInDimOpConverter
: public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
public:
using BaseOpConversion<mhlo::DynamicBroadcastInDimOp>::BaseOpConversion;
HloToLhloDynamicBroadcastInDimOpConverter(TypeConverter& converter,
MLIRContext* ctx,
bool insert_copy = true)
: BaseOpConversion<mhlo::DynamicBroadcastInDimOp>(converter, ctx),
insert_copy_(insert_copy) {}
LogicalResult matchAndRewrite(
mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
Value resultBuffer = InsertDynamicAllocAndDealloc(
loc, op.getResult(), op.output_dimensions(), &rewriter);
Value result = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
Value transformed_operand =
InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
rewriter.create<lmhlo::CopyOp>(loc, transformed_operand, resultBuffer);
rewriter.replaceOp(op, {resultBuffer});
if (insert_copy_) {
auto loc = op.getLoc();
Value result_buffer = InsertDynamicAllocAndDealloc(
loc, op.getResult(), op.output_dimensions(), &rewriter);
rewriter.create<lmhlo::CopyOp>(loc, result, result_buffer);
result = result_buffer;
}
rewriter.replaceOp(op, {result});
return success();
}
@ -260,7 +293,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
for (int i = 0; i < result_rank; ++i) {
Value i_val = b->create<ConstantIndexOp>(loc, i);
Value result_dim_size =
b->create<ExtractElementOp>(loc, op.output_dimensions(), i_val);
b->create<tensor::ExtractOp>(loc, op.output_dimensions(), i_val);
if (!result_dim_size.getType().isIndex()) {
result_dim_size =
b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
@ -307,31 +340,10 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
static_strides, llvm::None, sizes, strides);
return transformed_operand;
}
};
struct HloToLhloDynamicReshapeConverter
: public BaseOpConversion<mhlo::DynamicReshapeOp> {
public:
using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
Type result_type;
if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
result_type =
MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
} else if (auto unranked_type =
op.getType().dyn_cast<UnrankedTensorType>()) {
result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
} else {
return failure();
}
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
op, result_type, adaptor.operand(), adaptor.output_shape());
return success();
}
// Keep the copy semantics and allocate a buffer for the result of the memref
// cast.
bool insert_copy_;
};
struct HloToLhloDotGeneralOpConverter
@ -428,7 +440,7 @@ struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
mhlo::ReturnOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
auto& entry_block = op.getParentRegion()->front();
auto& entry_block = op->getParentRegion()->front();
auto num_arguments = entry_block.getNumArguments();
if (operands.size() > num_arguments) {
return op.emitError(
@ -556,6 +568,7 @@ struct HloLegalizeToLhlo
ConversionTarget target(context);
target.addLegalDialect<lmhlo::LmhloDialect>();
target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<tensor::TensorDialect>();
target.addIllegalOp<mlir::TensorLoadOp>();
target.addIllegalOp<mlir::TensorStoreOp>();
target.addIllegalDialect<mhlo::MhloDialect>();
@ -593,15 +606,22 @@ struct HloLegalizeToLhlo
};
} // namespace
void populateDynamicHLOToLHLOConversionPattern(
MLIRContext* context, BufferizeTypeConverter* converter,
OwningRewritePatternList* patterns, bool insert_copy) {
patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(
*converter, context, insert_copy);
patterns->insert<HloToLhloDynamicReshapeConverter>(*converter, context);
}
void populateHLOToLHLOConversionPattern(MLIRContext* context,
BufferizeTypeConverter* converter,
OwningRewritePatternList* patterns) {
populateDynamicHLOToLHLOConversionPattern(context, converter, patterns);
// clang-format off
patterns->insert<
HloToLhloCustomCallOpConverter,
HloToLhloDotGeneralOpConverter,
HloToLhloDynamicBroadcastInDimOpConverter,
HloToLhloDynamicReshapeConverter,
HloToLhloOpConverter<mhlo::AbsOp>,
HloToLhloOpConverter<mhlo::AddOp>,
HloToLhloOpConverter<mhlo::AndOp>,
@ -629,11 +649,15 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
HloToLhloOpConverter<mhlo::MulOp>,
HloToLhloOpConverter<mhlo::NegOp>,
HloToLhloOpConverter<mhlo::NotOp>,
HloToLhloOpConverter<mhlo::OrOp>,
HloToLhloOpConverter<mhlo::RealOp>,
HloToLhloOpConverter<mhlo::RemOp>,
HloToLhloOpConverter<mhlo::RsqrtOp>,
HloToLhloOpConverter<mhlo::ReshapeOp>,
HloToLhloOpConverter<mhlo::SelectOp>,
HloToLhloOpConverter<mhlo::ShiftLeftOp>,
HloToLhloOpConverter<mhlo::ShiftRightArithmeticOp>,
HloToLhloOpConverter<mhlo::ShiftRightLogicalOp>,
HloToLhloOpConverter<mhlo::SignOp>,
HloToLhloOpConverter<mhlo::SinOp>,
HloToLhloOpConverter<mhlo::SliceOp>,
@ -641,6 +665,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
HloToLhloOpConverter<mhlo::SubOp>,
HloToLhloOpConverter<mhlo::TanhOp>,
HloToLhloOpConverter<mhlo::TransposeOp>,
HloToLhloOpConverter<mhlo::XorOp>,
HloToLhloReduceOpConverter,
HloToLhloReturnOpConverter,
HloToLhloTensorLoadOpConverter,

View File

@ -21,12 +21,13 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project
#include "mlir/IR/Block.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
@ -83,7 +84,7 @@ LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) {
// Extract the predicate for checking branching, then branch to the true and
// false regions appropriately.
auto cond_value = builder.create<mlir::ExtractElementOp>(loc, if_op.pred());
auto cond_value = builder.create<mlir::tensor::ExtractOp>(loc, if_op.pred());
builder.create<mlir::CondBranchOp>(loc, cond_value, true_block,
if_op.true_arg(), false_block,
if_op.false_arg());
@ -142,7 +143,7 @@ LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
builder.create<mlir::BranchOp>(loc, cond_block, while_op.getOperand());
// Updates the inlined condition blocks by replacing the return op with an
// extract_element and conditional branch. This changes the block below:
// tensor.extract and conditional branch. This changes the block below:
// ^cond(%0):
// <inlined conditional region>
// "mhlo".return(%1)
@ -150,7 +151,7 @@ LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
// Into:
// ^cond(%0):
// <inlined conditional region>
// %2 = extract_element %1[] : tensor<i1> // Extract the condition value.
// %2 = tensor.extract %1[] : tensor<i1> // Extract the condition value.
// cond_br %2, ^body(%0), ^tail(%0) // Branch.
builder.setInsertionPointToStart(cond_block);
@ -166,7 +167,8 @@ LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
builder.setInsertionPointToEnd(new_block);
auto return_value = return_op.getOperand(0);
auto cond_value = builder.create<mlir::ExtractElementOp>(loc, return_value);
auto cond_value =
builder.create<mlir::tensor::ExtractOp>(loc, return_value);
// Get the body block arguments.
llvm::SmallVector<Value, 4> successor_args(cond_block->args_begin(),

View File

@ -30,12 +30,12 @@ limitations under the License.
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@ -49,17 +49,17 @@ SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) {
}
template <bool isLHLO = true>
Value getResultValue(Operation* op) {
Value GetResultValue(Operation* op) {
return isLHLO ? op->getOperand(op->getNumOperands() - 1) : op->getResult(0);
}
template <bool isLHLO = true>
ShapedType getHloOpResultType(Operation* op) {
return getResultValue<isLHLO>(op).getType().template cast<ShapedType>();
ShapedType GetHloOpResultType(Operation* op) {
return GetResultValue<isLHLO>(op).getType().template cast<ShapedType>();
}
template <bool isLHLO = true>
bool verifyHloOpBufferOrTensorSemantics(Operation* op) {
bool VerifyHloOpBufferOrTensorSemantics(Operation* op) {
auto verify_type = [&](Value val) -> bool {
return (isLHLO && val.getType().isa<MemRefType>()) ||
(!isLHLO && val.getType().isa<RankedTensorType>());
@ -131,6 +131,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
common_indexing_map);
bool failed = false;
auto linalg_op = rewriter.create<linalg::GenericOp>(
loc, op_result_types, inputs, output_buffers,
/*initTensors=*/ValueRange{}, indexing_maps,
@ -141,8 +142,13 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
Value op_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
op, body_result_types,
llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
nested_builder.create<linalg::YieldOp>(loc, op_result);
if (op_result == nullptr) {
failed = true;
} else {
nested_builder.create<linalg::YieldOp>(loc, op_result);
}
});
if (failed) return failure();
rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
return success();
}
@ -243,7 +249,8 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
}
// TODO: LHS dilation for deconvolution not supported yet.
if (op.lhs_dilation()) {
// TODO(jurahul): Window reversal is not supported yet.
if (op.lhs_dilation() || op.hasWindowReversal()) {
return failure();
}
@ -292,8 +299,8 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
LogicalResult matchAndRewrite(
OpTy op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
auto result_type = getHloOpResultType<isLHLO>(op);
if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
auto result_type = GetHloOpResultType<isLHLO>(op);
SmallVector<AffineMap, 2> indexing_maps =
Derived::getIndexingMaps(op, &rewriter);
@ -330,7 +337,7 @@ class BroadcastConverter
ShapedType input_type =
broadcast_op.operand().getType().template cast<ShapedType>();
unsigned input_rank = input_type.getRank();
unsigned nloops = getHloOpResultType<isLHLO>(broadcast_op).getRank();
unsigned nloops = GetHloOpResultType<isLHLO>(broadcast_op).getRank();
// BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
// the input's dimensions.
@ -364,7 +371,7 @@ class HloBroadcastInDimConverter
static SmallVector<AffineMap, 2> getIndexingMaps(
mhlo::BroadcastInDimOp broadcast_op, Builder* b) {
auto result_type = getHloOpResultType<false>(broadcast_op);
auto result_type = GetHloOpResultType<false>(broadcast_op);
auto operand_type =
broadcast_op.operand().getType().template cast<ShapedType>();
unsigned nloops = result_type.getRank();
@ -562,7 +569,7 @@ class TransposeConverter
isLHLO>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
auto result_type =
getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
GetHloOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = result_type.getRank();
SmallVector<AffineExpr, 2> input_exprs;
input_exprs.resize(result_type.getRank());
@ -586,11 +593,11 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
LogicalResult matchAndRewrite(
OpTy reshape_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
return failure();
ShapedType operand_type =
reshape_op.operand().getType().template cast<ShapedType>();
ShapedType result_type = getHloOpResultType<isLHLO>(reshape_op);
ShapedType result_type = GetHloOpResultType<isLHLO>(reshape_op);
if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
return failure();
@ -695,7 +702,7 @@ class IotaConverter : public OpConversionPattern<OpTy> {
LogicalResult matchAndRewrite(
OpTy iota_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
ShapedType result_shaped_type = getHloOpResultType<isLHLO>(iota_op);
ShapedType result_shaped_type = GetHloOpResultType<isLHLO>(iota_op);
if (!result_shaped_type) return failure();
auto result_element_type = result_shaped_type.getElementType();
@ -733,23 +740,37 @@ class IotaConverter : public OpConversionPattern<OpTy> {
}
};
class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
template <typename OpTy>
class ConstConverter : public OpConversionPattern<OpTy> {
public:
using OpConversionPattern<lmhlo::ConstOp>::OpConversionPattern;
using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
lmhlo::ConstOp const_op, ArrayRef<Value> args,
OpTy const_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final {
auto loc = const_op.getLoc();
auto value_attr = const_op.value().cast<DenseElementsAttr>();
Location loc = const_op.getLoc();
auto value_attr = const_op.value().template cast<DenseElementsAttr>();
if (value_attr.getType().getRank() != 0) return failure();
auto std_const_op =
rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
rewriter.create<mlir::AffineStoreOp>(loc, std_const_op,
const_op.getOperand(), ValueRange());
rewriter.eraseOp(const_op);
ReplaceConstOp(loc, const_op, value_attr, rewriter);
return success();
}
private:
void ReplaceConstOp(Location loc, mhlo::ConstOp op,
DenseElementsAttr value_attr,
ConversionPatternRewriter& rewriter) const {
Value std_tensor_const = rewriter.create<mlir::ConstantOp>(loc, value_attr);
rewriter.replaceOp(op, {std_tensor_const});
}
void ReplaceConstOp(Location loc, lmhlo::ConstOp op,
DenseElementsAttr value_attr,
ConversionPatternRewriter& rewriter) const {
Value std_scalar_const =
rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
rewriter.create<mlir::AffineStoreOp>(loc, std_scalar_const, op.getOperand(),
llvm::None);
rewriter.eraseOp(op);
}
};
class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
@ -798,7 +819,8 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
loc, /*resultTensorTypes=*/ArrayRef<Type>{},
/*inputs=*/adaptor.operands(), /*outputBuffers=*/adaptor.out(),
/*initTensors=*/ValueRange{}, maps, types);
linalg_op.region().takeBody(reduce_op.body());
rewriter.inlineRegionBefore(reduce_op.body(), linalg_op.region(),
linalg_op.region().end());
{
OpBuilder::InsertionGuard region_guard(rewriter);
Block* block = linalg_op.getBody();
@ -852,7 +874,7 @@ class ReverseConverter
isLHLO>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
auto result_type =
getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
GetHloOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = result_type.getRank();
SmallVector<AffineExpr, 2> input_exprs;
input_exprs.reserve(nloops);
@ -908,7 +930,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
// clang-format off
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
ConstConverter,
ConstConverter<lmhlo::ConstOp>,
ConvToLinalgConverter,
IotaConverter<lmhlo::IotaOp>,
LhloBroadcastInDimConverter,
@ -927,22 +949,27 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<lmhlo::ExpOp>,
PointwiseToLinalgConverter<lmhlo::FloorOp>,
PointwiseToLinalgConverter<lmhlo::ImagOp>,
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
PointwiseToLinalgConverter<lmhlo::LogOp>,
PointwiseToLinalgConverter<lmhlo::MaxOp>,
PointwiseToLinalgConverter<lmhlo::MinOp>,
PointwiseToLinalgConverter<lmhlo::MulOp>,
PointwiseToLinalgConverter<lmhlo::NegOp>,
PointwiseToLinalgConverter<lmhlo::NotOp>,
PointwiseToLinalgConverter<lmhlo::OrOp>,
PointwiseToLinalgConverter<lmhlo::RealOp>,
PointwiseToLinalgConverter<lmhlo::RemOp>,
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
PointwiseToLinalgConverter<lmhlo::SelectOp>,
PointwiseToLinalgConverter<lmhlo::ShiftLeftOp>,
PointwiseToLinalgConverter<lmhlo::ShiftRightArithmeticOp>,
PointwiseToLinalgConverter<lmhlo::ShiftRightLogicalOp>,
PointwiseToLinalgConverter<lmhlo::SignOp>,
PointwiseToLinalgConverter<lmhlo::SinOp>,
PointwiseToLinalgConverter<lmhlo::SqrtOp>,
PointwiseToLinalgConverter<lmhlo::SubOp>,
PointwiseToLinalgConverter<lmhlo::TanhOp>,
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
PointwiseToLinalgConverter<lmhlo::XorOp>,
ReduceConverter,
ReshapeOpConverter<lmhlo::ReshapeOp>,
ReverseConverter<lmhlo::ReverseOp>,
@ -1024,7 +1051,8 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
patterns
->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
ConstConverter<mhlo::ConstOp>, HloBroadcastInDimConverter,
IotaConverter<mhlo::IotaOp, false>,
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
PointwiseToLinalgConverter<mhlo::AddOp, false>,
PointwiseToLinalgConverter<mhlo::AndOp, false>,
@ -1039,21 +1067,27 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
PointwiseToLinalgConverter<mhlo::FloorOp, false>,
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
PointwiseToLinalgConverter<mhlo::LogOp, false>,
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
PointwiseToLinalgConverter<mhlo::MinOp, false>,
PointwiseToLinalgConverter<mhlo::MulOp, false>,
PointwiseToLinalgConverter<mhlo::NegOp, false>,
PointwiseToLinalgConverter<mhlo::NotOp, false>,
PointwiseToLinalgConverter<mhlo::OrOp, false>,
PointwiseToLinalgConverter<mhlo::RealOp, false>,
PointwiseToLinalgConverter<mhlo::RemOp, false>,
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SelectOp, false>,
PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>,
PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp, false>,
PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>,
PointwiseToLinalgConverter<mhlo::SignOp, false>,
PointwiseToLinalgConverter<mhlo::SinOp, false>,
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SubOp, false>,
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
PointwiseToLinalgConverter<mhlo::XorOp, false>,
ReshapeOpConverter<mhlo::ReshapeOp, false>,
ReverseConverter<mhlo::ReverseOp, false>,
TransposeConverter<mhlo::TransposeOp, false>>(context);

View File

@ -19,8 +19,8 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

View File

@ -30,11 +30,11 @@ limitations under the License.
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@ -437,12 +437,14 @@ class ReduceWindowOpConverter
loc, operand_type.getElementType(), mapped_ivs.in_bounds,
/*withElseRegion=*/true);
OpBuilder then_builder = elem_or_init.getThenBodyBuilder();
OpBuilder then_builder =
elem_or_init.getThenBodyBuilder(rewriter->getListener());
Value elem = then_builder.create<mlir::LoadOp>(
loc, reduce_window_op.operand(), mapped_ivs.ivs);
then_builder.create<scf::YieldOp>(loc, elem);
OpBuilder else_builder = elem_or_init.getElseBodyBuilder();
OpBuilder else_builder =
elem_or_init.getElseBodyBuilder(rewriter->getListener());
else_builder.create<scf::YieldOp>(loc, *window_loop.initVals().begin());
return rewriter->create<scf::ReduceOp>(loc,
@ -617,7 +619,8 @@ class SelectAndScatterOpConverter
// Case when we are inside boundaries of 'arg' and not in the pad area.
{
OpBuilder in_bounds_then_b = if_in_bounds.getThenBodyBuilder();
OpBuilder in_bounds_then_b =
if_in_bounds.getThenBodyBuilder(b->getListener());
auto select_or_init_results = SelectOrInitialize(
s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b);
in_bounds_then_b.create<scf::YieldOp>(loc, select_or_init_results);
@ -625,7 +628,8 @@ class SelectAndScatterOpConverter
// Case when we are in the pad.
{
OpBuilder in_bounds_else_b = if_in_bounds.getElseBodyBuilder();
OpBuilder in_bounds_else_b =
if_in_bounds.getElseBodyBuilder(b->getListener());
in_bounds_else_b.create<scf::YieldOp>(loc, ivs_val_flag.to_vector());
}
@ -651,7 +655,7 @@ class SelectAndScatterOpConverter
// element in boundaries of the operand. Select function has to be computed
// here.
{
OpBuilder if_init_then_b = if_init.getThenBodyBuilder();
OpBuilder if_init_then_b = if_init.getThenBodyBuilder(b->getListener());
auto& lhlo_select = s_and_s_op.select().front();
Value pred =
@ -664,14 +668,14 @@ class SelectAndScatterOpConverter
// Pred == true, therefore pack newly selected ivs, val and init flag back
// to iter_args and return.
{
OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder();
OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder(b->getListener());
if_pred_then_b.create<scf::YieldOp>(
loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
}
// Pred == false, therefore return old iter_args.
{
OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder();
OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder(b->getListener());
if_pred_else_b.create<scf::YieldOp>(loc, ivs_val_flag->to_vector());
}
@ -680,7 +684,7 @@ class SelectAndScatterOpConverter
// Init == false, i.e. only pad was visited before and this is the first
// element in the boundaries of the operand.
{
OpBuilder if_init_else_b = if_init.getElseBodyBuilder();
OpBuilder if_init_else_b = if_init.getElseBodyBuilder(b->getListener());
if_init_else_b.create<scf::YieldOp>(
loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());

View File

@ -23,9 +23,9 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

View File

@ -18,9 +18,10 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
@ -119,7 +120,7 @@ void MatchAndRewrite(WhileOp whileOp) {
auto tensorIndexType = RankedTensorType::get({}, b.getIndexType());
auto getAsIndex = [&](Value val) {
auto loc = whileOp.getLoc();
return b.create<ExtractElementOp>(
return b.create<tensor::ExtractOp>(
loc, b.create<IndexCastOp>(loc, tensorIndexType, val), ValueRange());
};

View File

@ -22,10 +22,10 @@ limitations under the License.
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@ -42,11 +42,11 @@ namespace {
sep fn(SqrtOp) sep fn(TanhOp)
// TODO(herhut): Generate these out of op definitions.
#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep) \
fn(AddOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp) sep fn(MaxOp) \
sep fn(MinOp) sep fn(MulOp) sep fn(PowOp) sep fn(RemOp) \
sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
sep fn(ShiftRightLogicalOp) sep fn(SubOp)
#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep) \
fn(AddOp) sep fn(AndOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp) \
sep fn(MaxOp) sep fn(MinOp) sep fn(MulOp) sep fn(OrOp) sep fn(PowOp) \
sep fn(RemOp) sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)
// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
@ -409,7 +409,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
// Put each subsequent rank specialization inside the else statement of the
// previous one.
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
constexpr int kMaxRankSpecialization = 5;
constexpr int kMaxRankSpecialization = 6;
for (int i = 2; i < kMaxRankSpecialization; i++) {
auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
else_builder, op, greater_rank, i);

View File

@ -18,9 +18,9 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Transforms/DialectConversion.h"

View File

@ -21,8 +21,8 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/StandardTypes.h"
namespace mlir {
namespace hlo {

View File

@ -18,7 +18,7 @@ limitations under the License.
#include "mlir-hlo/utils/convert_op_folder.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
namespace mlir {

View File

@ -132,5 +132,13 @@ DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit) {
llvm_unreachable("unsupported type");
}
std::string LmhloToMhloOpName(llvm::StringRef op_name,
mlir::MLIRContext *context) {
assert(op_name.startswith("lmhlo.") && "Expected an LMHLO op");
std::string mhlo_op_name(op_name.drop_front(1));
if (context->isOperationRegistered(mhlo_op_name)) return mhlo_op_name;
return "";
}
} // namespace hlo
} // namespace mlir
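The helper above relies on the two dialect prefixes differing only by the leading `l` (`lmhlo.add` vs. `mhlo.add`). A minimal sketch of the same string mapping, with the registration lookup omitted, is:
```python
# Minimal sketch of the name mapping (the isOperationRegistered check is omitted;
# the real helper returns "" when the mhlo op is not registered in the context).
def lmhlo_to_mhlo_op_name(op_name: str) -> str:
    assert op_name.startswith("lmhlo."), "Expected an LMHLO op"
    return op_name[1:]          # drop the leading 'l': "lmhlo.add" -> "mhlo.add"

assert lmhlo_to_mhlo_op_name("lmhlo.add") == "mhlo.add"
```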

View File

@ -327,6 +327,15 @@ func @slice_2D_fold_vertical() -> tensor<4x1xi64> {
return %1 : tensor<4x1xi64>
}
// CHECK-LABEL: slice_zero_elements
func @slice_zero_elements() -> tensor<0xi64> {
%0 = mhlo.constant dense<> : tensor<0xi64>
// CHECK: %[[CONST:.*]] = mhlo.constant dense<> : tensor<0xi64>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[0]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<0xi64>) -> (tensor<0xi64>)
// CHECK: return %[[CONST]] : tensor<0xi64>
return %1 : tensor<0xi64>
}
// CHECK-LABEL: slice_unknown_shape
func @slice_unknown_shape(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32>
@ -1506,6 +1515,14 @@ func @pad_fold() -> tensor<4x5xi32> {
// CHECK-SAME: ]> : tensor<4x5xi32>
}
func @pad_fold_zero_elements() -> tensor<3xi32> {
%0 = mhlo.constant dense<> : tensor<0xi32>
%1 = mhlo.constant dense<7> : tensor<i32>
%2 = "mhlo.pad"(%0, %1) {edge_padding_high = dense<3> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<0xi32>, tensor<i32>) -> tensor<3xi32>
return %2 : tensor<3xi32>
// CHECK: mhlo.constant dense<7> : tensor<3xi32>
}
// CHECK-LABEL: @identity_broadcast_reshape
func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
%0 = "mhlo.broadcast"(%arg0) {

View File

@ -170,24 +170,31 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) -> index {
return %rank : index
}
// CHECK: %[[SHAPE:.*]] = tensor_from_elements
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
// CHECK: %[[EL0:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
// CHECK: %[[EL1:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[EL2:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
// CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref<?x?xf32> to memref<?x?x?xf32, #map>
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
// CHECK: dealloc %[[RESULT]] : memref<?x?x?xf32>
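The CHECK lines above encode the stride trick behind the updated dynamic-broadcast lowering: a dimension that is expanded (operand extent 1, result extent larger) gets stride 0 in the memref_reinterpret_cast, so every result index along it reads the same operand element. A self-contained C++ sketch of that indexing rule, with a hypothetical helper name and row-major layout assumed; it is not code from the pass:

#include <cstdint>
#include <vector>

// Hypothetical helper: compute the strides a broadcast view would use.
// Expanded dimensions (operand extent 1, result extent > 1) get stride 0.
std::vector<int64_t> BroadcastStrides(const std::vector<int64_t>& operand_shape,
                                      const std::vector<int64_t>& result_shape) {
  // Row-major strides of the operand.
  std::vector<int64_t> operand_strides(operand_shape.size(), 1);
  for (int i = static_cast<int>(operand_shape.size()) - 2; i >= 0; --i)
    operand_strides[i] = operand_strides[i + 1] * operand_shape[i + 1];

  // Leading result dims the operand lacks are always expanded (stride 0).
  std::vector<int64_t> strides(result_shape.size(), 0);
  const size_t offset = result_shape.size() - operand_shape.size();
  for (size_t i = 0; i < operand_shape.size(); ++i)
    strides[offset + i] =
        (operand_shape[i] == 1 && result_shape[offset + i] != 1)
            ? 0
            : operand_strides[i];
  return strides;
}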
@ -316,6 +323,20 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// -----
// CHECK-LABEL: func @and
func @and(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
%result: memref<2x2xi32>) {
%tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
%tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
%tensor_result = "mhlo.and"(%tensor_operand0, %tensor_operand1)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// CHECK: "lmhlo.and"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// CHECK-LABEL: func @ceil
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
@ -389,6 +410,20 @@ func @not(%operand: memref<2x2xi32>, %result: memref<2x2xi32>) {
// -----
// CHECK-LABEL: func @or
func @or(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
%result: memref<2x2xi32>) {
%tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
%tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
%tensor_result = "mhlo.or"(%tensor_operand0, %tensor_operand1)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// CHECK: "lmhlo.or"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// CHECK-LABEL: func @rsqrt
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
@ -425,6 +460,48 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// -----
// CHECK-LABEL: func @shift_left
func @shift_left(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi32>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xi32>
%tensor_rhs = tensor_load %rhs : memref<2x2xi32>
%tensor_result = "mhlo.shift_left"(%tensor_lhs, %tensor_rhs)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// CHECK: "lmhlo.shift_left"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// CHECK-LABEL: func @shift_right_arithmetic
func @shift_right_arithmetic(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi32>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xi32>
%tensor_rhs = tensor_load %rhs : memref<2x2xi32>
%tensor_result = "mhlo.shift_right_arithmetic"(%tensor_lhs, %tensor_rhs)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// CHECK: "lmhlo.shift_right_arithmetic"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// CHECK-LABEL: func @shift_right_logical
func @shift_right_logical(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi32>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xi32>
%tensor_rhs = tensor_load %rhs : memref<2x2xi32>
%tensor_result = "mhlo.shift_right_logical"(%tensor_lhs, %tensor_rhs)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// CHECK: "lmhlo.shift_right_logical"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// CHECK-LABEL: func @tanh
func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
@ -438,7 +515,8 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// -----
// CHECK-LABEL: func @remainder
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
@ -450,6 +528,20 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
// -----
// CHECK-LABEL: func @xor
func @xor(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
%result: memref<2x2xi32>) {
%tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
%tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
%tensor_result = "mhlo.xor"(%tensor_operand0, %tensor_operand1)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// CHECK: "lmhlo.xor"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// Dynamic shape binary element-wise operation.
// CHECK-LABEL: func @add_dyn
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
@ -462,9 +554,9 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
@ -485,9 +577,9 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) {
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()

View File

@ -29,6 +29,18 @@ func @integer_add(%lhs: tensor<2x2xi32>,
// -----
// CHECK-LABEL: complex_add
func @complex_add(%lhs: tensor<2x2xcomplex<f32>>,
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
// CHECK: linalg.generic
// CHECK: addcf
%0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
return %0 : tensor<2x2xcomplex<f32>>
}
// -----
// CHECK-LABEL: func @float_mul
func @float_mul(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
@ -112,6 +124,18 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
// -----
// CHECK-LABEL: complex_sub
func @complex_sub(%lhs: tensor<2x2xcomplex<f32>>,
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
// CHECK: linalg.generic
// CHECK: subcf
%0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
return %0 : tensor<2x2xcomplex<f32>>
}
// -----
// CHECK-LABEL: func @float_abs
func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
@ -194,6 +218,30 @@ func @integer_and(%lhs: tensor<2x2xi32>,
// -----
// CHECK-LABEL: func @integer_or
func @integer_or(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: or
%0 = "mhlo.or"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @integer_xor
func @integer_xor(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: xor
%0 = "mhlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @float_cmp
func @float_cmp(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
@ -208,6 +256,20 @@ func @float_cmp(%lhs: tensor<2x2xf32>,
// -----
// CHECK-LABEL: func @float_cmp_ne
func @float_cmp_ne(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "NE"}
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
return %0 : tensor<2x2xi1>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = cmpf "une", %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
// -----
// CHECK-LABEL: func @int_cmp
func @int_cmp(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi1> {
@ -630,3 +692,56 @@ func @iota() -> tensor<7x10xf32> {
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32
// -----
func @shift_left(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
%result = "mhlo.shift_left"(%lhs, %rhs)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %result : tensor<2x2xi32>
}
// CHECK-LABEL: func @shift_left
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = shift_left %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
func @shift_right_arithmetic(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
%result = "mhlo.shift_right_arithmetic"(%lhs, %rhs)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %result : tensor<2x2xi32>
}
// CHECK-LABEL: func @shift_right_arithmetic
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = shift_right_signed %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
func @shift_right_logical(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
%result = "mhlo.shift_right_logical"(%lhs, %rhs)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %result : tensor<2x2xi32>
}
// CHECK-LABEL: func @shift_right_logical
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = shift_right_unsigned %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
// CHECK-LABEL: func @constant
func @constant() {
%result = "mhlo.constant"() {
value = dense<10> : tensor<i32>
} : () -> (tensor<i32>)
return
}
// CHECK: %[[CONSTANT:.*]] = constant dense<10> : tensor<i32>

View File

@ -262,17 +262,34 @@ func @addUnrankedUnranked(
// CHECK-NEXT: %[[C5:.*]] = constant 5 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C5]] : index
// Handle rank 5 specialization
// CHECK-NEXT: assert %[[GREATEST_RANK_IS_5]]
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
// CHECK-NEXT: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C6:.*]] = constant 6 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index
// CHECK-NEXT: assert %[[GREATEST_RANK_IS_6]]
// Handle rank 6 specialization
// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_6]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_65:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
// CHECK-NEXT: }
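The rewritten block above changes the rank-specialization control flow: the rank-5 case is now an explicit scf.if branch, and the assert moves into the final rank-6 branch. A minimal sketch of that shape, with a hypothetical function name and illustrating the branching only:

#include <cassert>
#include <cstdint>

// Rank 5 gets its own guarded branch; only the last (rank 6) branch asserts.
int64_t SpecializeGreatestRank(int64_t greatest_rank) {
  if (greatest_rank == 5) {
    return 5;  // reshape both operands to rank 5, broadcast, add
  }
  assert(greatest_rank == 6 && "ranks above 6 are not specialized");
  return 6;    // reshape both operands to rank 6, broadcast, add
}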

View File

@ -5,7 +5,7 @@ func @while(%arg0: tensor<i64>) -> tensor<i64> {
//CHECK: br ^bb1(%arg0 : tensor<i64>)
//CHECK: ^bb1([[VAL0:%.+]]: tensor<i64>):
//CHECK: [[VAL1:%.+]] = "mhlo.compare"([[VAL0]], [[VAL0]])
//CHECK: [[VAL2:%.+]] = extract_element [[VAL1]][] : tensor<i1>
//CHECK: [[VAL2:%.+]] = tensor.extract [[VAL1]][] : tensor<i1>
//CHECK: cond_br [[VAL2]], ^bb2([[VAL0]] : tensor<i64>), ^bb3([[VAL0]] : tensor<i64>)
//CHECK: ^bb2([[VAL3:%.+]]: tensor<i64>):
//CHECK: [[VAL4:%.+]] = mhlo.add [[VAL3]], [[VAL3]]
@ -33,7 +33,7 @@ func @conditional(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: [[VAL0:%.+]] = "mhlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
%0 = "mhlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: [[VAL1:%.+]] = extract_element [[VAL0]][] : tensor<i1>
// CHECK: [[VAL1:%.+]] = tensor.extract [[VAL0]][] : tensor<i1>
// CHECK: cond_br [[VAL1]], ^bb1(%arg0 : tensor<f32>), ^bb2(%arg0 : tensor<f32>)
%1 = "mhlo.if"(%0, %arg0, %arg0) ( {
@ -63,7 +63,7 @@ func @while_with_multiple_blocks_in_body(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK: br ^[[COND_ENTRY:.+]](%arg0 : tensor<i64>)
// CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
// CHECK: %1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK: %2 = extract_element %1[] : tensor<i1>
// CHECK: %2 = tensor.extract %1[] : tensor<i1>
// CHECK: cond_br %2, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
// CHECK: ^[[BODY_ENTRY]](%3: tensor<i64>):
// CHECK: br ^[[BODY_SUCC:.+]](%3 : tensor<i64>)
@ -95,7 +95,7 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK: br ^[[COND_SUCC:.+]](%0 : tensor<i64>)
// CHECK: ^[[COND_SUCC]](%1: tensor<i64>):
// CHECK: %2 = "mhlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK: %3 = extract_element %2[] : tensor<i1>
// CHECK: %3 = tensor.extract %2[] : tensor<i1>
// CHECK: cond_br %3, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
// CHECK: ^[[BODY_ENTRY]](%4: tensor<i64>):
// CHECK: br ^[[COND_ENTRY]](%4 : tensor<i64>)
@ -118,7 +118,7 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-LABEL: func @conditional_with_multiple_blocks(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
func @conditional_with_multiple_blocks(%arg0: tensor<f32>, %arg1: tensor<f32>, %pred: tensor<i1>) -> tensor<f32> {
// CHECK: %0 = extract_element %arg2[] : tensor<i1>
// CHECK: %0 = tensor.extract %arg2[] : tensor<i1>
// CHECK: cond_br %0, ^[[THEN_ENTRY:.+]](%arg0 : tensor<f32>), ^[[ELSE_ENTRY:.+]](%arg1 : tensor<f32>)
// CHECK: ^[[THEN_ENTRY]](%1: tensor<f32>):
// CHECK: br ^[[THEN_SUCC:.+]](%1 : tensor<f32>)

View File

@ -30,9 +30,9 @@ func @lt_loop(%arg0: tensor<4xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg
// CHECK: %[[VAL_11:.*]] = constant dense<0> : tensor<i32>
// CHECK: %[[VAL_12:.*]] = constant dense<1000> : tensor<i32>
// CHECK: %[[VAL_14:.*]] = index_cast %[[VAL_11]] : tensor<i32> to tensor<index>
// CHECK: %[[VAL_15:.*]] = extract_element %[[VAL_14]][] : tensor<index>
// CHECK: %[[VAL_15:.*]] = tensor.extract %[[VAL_14]][] : tensor<index>
// CHECK: %[[VAL_16:.*]] = index_cast %[[VAL_12]] : tensor<i32> to tensor<index>
// CHECK: %[[VAL_17:.*]] = extract_element %[[VAL_16]][] : tensor<index>
// CHECK: %[[VAL_17:.*]] = tensor.extract %[[VAL_16]][] : tensor<index>
// CHECK: %[[VAL_18:.*]] = index_cast %[[VAL_10]] : tensor<i32> to tensor<index>
// CHECK: %[[VAL_19:.*]] = extract_element %[[VAL_18]][] : tensor<index>
// CHECK: %[[VAL_19:.*]] = tensor.extract %[[VAL_18]][] : tensor<index>
// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_19]] iter_args(%[[VAL_22:.*]] = %[[VAL_9]], %[[VAL_23:.*]] = %[[VAL_12]])

View File

@ -594,8 +594,12 @@ func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[CST:.*]] = constant 1.000000e+00 : f32
// CHECK-NEXT: %[[RESULT:.*]] = copysign %[[CST]], %[[OPERAND_IN]] : f32
// CHECK-NEXT: %[[CST_0:.*]] = constant 0.000000e+00 : f32
// CHECK-NEXT: %[[NE_0:.*]] = cmpf "one", %[[OPERAND_IN]], %[[CST_0]] : f32
// CHECK-NEXT: %[[NE_0_FLOAT:.*]] = uitofp %[[NE_0]] : i1 to f32
// CHECK-NEXT: %[[SIGN:.*]] = copysign %[[NE_0_FLOAT]], %[[OPERAND_IN]] : f32
// CHECK-NEXT: %[[CMP:.*]] = cmpf "uno", %[[OPERAND_IN]], %[[OPERAND_IN]] : f32
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[OPERAND_IN]], %[[SIGN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
@ -607,8 +611,12 @@ func @sign_bf16(%input: memref<2x2xbf16>, %result: memref<2x2xbf16>) {
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: bf16, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[CST:.*]] = constant 1.000000e+00 : bf16
// CHECK-NEXT: %[[RESULT:.*]] = copysign %[[CST]], %[[OPERAND_IN]] : bf16
// CHECK-NEXT: %[[CST_0:.*]] = constant 0.000000e+00 : bf16
// CHECK-NEXT: %[[NE_0:.*]] = cmpf "one", %[[OPERAND_IN]], %[[CST_0]] : bf16
// CHECK-NEXT: %[[NE_0_FLOAT:.*]] = uitofp %[[NE_0]] : i1 to bf16
// CHECK-NEXT: %[[SIGN:.*]] = copysign %[[NE_0_FLOAT]], %[[OPERAND_IN]] : bf16
// CHECK-NEXT: %[[CMP:.*]] = cmpf "uno", %[[OPERAND_IN]], %[[OPERAND_IN]] : bf16
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[OPERAND_IN]], %[[SIGN]] : bf16
// CHECK-NEXT: linalg.yield %[[RESULT]] : bf16
// -----
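The new CHECK sequences for sign and sign_bf16 encode a NaN-aware lowering: sign(x) = copysign(x != 0 ? 1 : 0, x), with NaN inputs passed through unchanged by the final select. A reference sketch of that formula in C++ (hypothetical function name, f32 case only):

#include <cmath>

float SignLowered(float x) {
  const bool ne_zero = !std::isnan(x) && x != 0.0f;  // cmpf "one" x, 0.0
  const float magnitude = ne_zero ? 1.0f : 0.0f;     // uitofp i1 -> f32
  const float sign = std::copysign(magnitude, x);    // copysign
  return std::isnan(x) ? x : sign;                   // cmpf "uno" + select
}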

Some files were not shown because too many files have changed in this diff.